mirror of
https://github.com/JackHopkins/factorio-learning-environment.git
synced 2025-09-06 13:23:58 +00:00
131 lines
4.5 KiB
Python
131 lines
4.5 KiB
Python
import asyncio
|
|
import json
|
|
import multiprocessing
|
|
import os
|
|
|
|
import gym
|
|
import importlib.resources
|
|
from dotenv import load_dotenv
|
|
from fle.env.gym_env.config import GymEvalConfig, GymRunConfig
|
|
from fle.env.gym_env.observation_formatter import BasicObservationFormatter
|
|
from fle.env.gym_env.system_prompt_formatter import SystemPromptFormatter
|
|
from fle.env.gym_env.registry import get_environment_info, list_available_environments
|
|
from fle.env.gym_env.trajectory_runner import GymTrajectoryRunner
|
|
|
|
from fle.agents.gym_agent import GymAgent
|
|
from fle.commons.db_client import create_db_client, get_next_version
|
|
from fle.eval.tasks import TaskFactory
|
|
from fle.env.utils.controller_loader.system_prompt_generator import (
|
|
SystemPromptGenerator,
|
|
)
|
|
|
|
# Runs at import time so that values from a .env file (python-dotenv) are
# available in the environment before any of the functions below execute.
load_dotenv()
|
|
|
|
|
|
def get_validated_run_configs(run_config_location: str) -> list[GymRunConfig]:
    """Read run configurations from a JSON file and validate their env IDs.

    The file is expected to hold a JSON array of objects; each object is
    expanded as keyword arguments to ``GymRunConfig``.

    Args:
        run_config_location: Path to the JSON run-config file.

    Returns:
        One ``GymRunConfig`` per entry, in file order.

    Raises:
        ValueError: If any entry names an environment ID that is not present
            in the registry.
    """
    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(run_config_location, "r", encoding="utf-8") as f:
        run_configs_raw = json.load(f)
        run_configs = [GymRunConfig(**config) for config in run_configs_raw]

    # Validate that all environment IDs exist in the registry
    available_envs = list_available_environments()
    for run_config in run_configs:
        if run_config.env_id not in available_envs:
            raise ValueError(
                f"Environment ID '{run_config.env_id}' not found in registry. Available environments: {available_envs}"
            )

    return run_configs
|
|
|
|
|
|
def run_process(run_idx: int, config: GymEvalConfig):
    """Synchronous entry point for one evaluation run.

    Builds the ``run_trajectory`` coroutine for the given run index and
    configuration and blocks until ``asyncio.run`` has driven it to completion.
    """
    trajectory = run_trajectory(run_idx, config)
    asyncio.run(trajectory)
|
|
|
|
|
|
async def run_trajectory(run_idx: int, config: GymEvalConfig):
    """Run a single gym evaluation trajectory.

    Creates a database client, builds the gym environment named by
    ``config.env_id`` and runs a ``GymTrajectoryRunner`` against it. Trajectory
    logs are directed to ``.fle/trajectory_logs/v<config.version>``.

    Args:
        run_idx: Index of this run; passed to ``gym.make`` as ``run_idx`` and
            to the runner as ``process_id``.
        config: Evaluation configuration for this run.
    """
    db_client = await create_db_client()
    try:
        gym_env = gym.make(config.env_id, run_idx=run_idx)

        log_dir = os.path.join(".fle", "trajectory_logs", f"v{config.version}")
        runner = GymTrajectoryRunner(
            config=config,
            gym_env=gym_env,
            db_client=db_client,
            log_dir=log_dir,
            process_id=run_idx,
        )
        await runner.run()
    finally:
        # Release the database client even if environment creation or the
        # run itself raised; otherwise the client would be leaked.
        await db_client.cleanup()
|
|
|
|
|
|
async def main(config_path):
    """Start one evaluation process per run configuration and wait for them.

    For every validated run configuration this looks up the environment in the
    registry, creates the task and one ``GymAgent`` (plus agent card) per agent
    slot, wraps everything in a ``GymEvalConfig`` and launches ``run_process``
    in its own ``multiprocessing.Process``. Returns once every process has
    been joined.

    Raises:
        ValueError: If the registry returns no info for a run's environment ID
            (or if config validation fails).
    """
    run_configs = get_validated_run_configs(config_path)

    # First version number to hand out to runs that do not pin their own.
    base_version = await get_next_version()

    workers = []
    for run_idx, run_config in enumerate(run_configs):
        # Registry lookup for the environment this run targets.
        env_info = get_environment_info(run_config.env_id)
        if env_info is None:
            raise ValueError(f"Could not get environment info for {run_config.env_id}")

        task = TaskFactory.create_task(env_info["task_config_path"])
        env_root = importlib.resources.files("fle") / "env"
        generator = SystemPromptGenerator(str(env_root))

        # Build each agent together with its agent card (a2a support).
        agents = []
        agent_cards = []
        num_agents = env_info["num_agents"]
        for agent_idx in range(num_agents):
            agent = GymAgent(
                model=run_config.model,
                system_prompt=generator.generate_for_agent(
                    agent_idx=agent_idx, num_agents=num_agents
                ),
                task=task,
                agent_idx=agent_idx,
                observation_formatter=BasicObservationFormatter(include_research=False),
                system_prompt_formatter=SystemPromptFormatter(),
            )
            agents.append(agent)
            agent_cards.append(agent.get_agent_card())

        # An explicit version in the run config wins. Otherwise offset the
        # base version by this run's position in the list; the offset grows
        # by one per run whether or not it ends up being used.
        if run_config.version is None:
            version = base_version + run_idx
        else:
            version = run_config.version

        # Eval config carries the agent cards for a2a support.
        config = GymEvalConfig(
            agents=agents,
            version=version,
            version_description=f"model:{run_config.model}\ntype:{task.task_key}\nnum_agents:{num_agents}",
            task=task,
            agent_cards=agent_cards,
            env_id=run_config.env_id,
        )
        # Ensure agent cards are properly set for a2a functionality
        assert config.agent_cards is not None

        worker = multiprocessing.Process(target=run_process, args=(run_idx, config))
        worker.start()
        workers.append(worker)

    # Block until every launched process has finished.
    for worker in workers:
        worker.join()
|