Files
factorio-learning-environment/fle/env/gym_env/run_eval.py
2025-08-26 00:47:13 -07:00

131 lines
4.5 KiB
Python

import asyncio
import json
import multiprocessing
import os
import gym
import importlib.resources
from dotenv import load_dotenv
from fle.env.gym_env.config import GymEvalConfig, GymRunConfig
from fle.env.gym_env.observation_formatter import BasicObservationFormatter
from fle.env.gym_env.system_prompt_formatter import SystemPromptFormatter
from fle.env.gym_env.registry import get_environment_info, list_available_environments
from fle.env.gym_env.trajectory_runner import GymTrajectoryRunner
from fle.agents.gym_agent import GymAgent
from fle.commons.db_client import create_db_client, get_next_version
from fle.eval.tasks import TaskFactory
from fle.env.utils.controller_loader.system_prompt_generator import (
SystemPromptGenerator,
)
# Load variables from a .env file (python-dotenv's default lookup) into the
# process environment at import time, before any run config is read.
# NOTE(review): presumably this supplies credentials/settings consumed by the DB
# client and model agents created below — confirm which keys are expected.
load_dotenv()
def get_validated_run_configs(run_config_location: str) -> list[GymRunConfig]:
    """Read run configurations from a JSON file and validate their environment IDs.

    The file may contain either a list of run-config objects or a single
    run-config object, which is treated as a one-element list.

    Args:
        run_config_location: Path to the JSON run-config file.

    Returns:
        One ``GymRunConfig`` per entry in the file, in file order.

    Raises:
        ValueError: If the top-level JSON value is neither an object nor a list,
            or if any entry names an environment ID missing from the registry.
    """
    # Explicit encoding: the platform default (e.g. a legacy Windows code page)
    # can mis-decode UTF-8 JSON config files.
    with open(run_config_location, "r", encoding="utf-8") as f:
        run_configs_raw = json.load(f)

    # Accept a single config object as shorthand for a one-element list. Before
    # this check, a dict was iterated key-by-key and failed with a confusing
    # TypeError.
    if isinstance(run_configs_raw, dict):
        run_configs_raw = [run_configs_raw]
    elif not isinstance(run_configs_raw, list):
        raise ValueError(
            f"Run config file {run_config_location!r} must contain a JSON object "
            f"or a list of objects, got {type(run_configs_raw).__name__}"
        )

    run_configs = [GymRunConfig(**config) for config in run_configs_raw]

    # Validate that all environment IDs exist in the registry
    available_envs = list_available_environments()
    for run_config in run_configs:
        if run_config.env_id not in available_envs:
            raise ValueError(
                f"Environment ID '{run_config.env_id}' not found in registry. Available environments: {available_envs}"
            )
    return run_configs
def run_process(run_idx: int, config: GymEvalConfig):
    """Synchronous entry point for one evaluation run.

    Builds the ``run_trajectory`` coroutine for this run and drives it to
    completion on a fresh event loop created by ``asyncio.run``.
    """
    trajectory_coro = run_trajectory(run_idx, config)
    asyncio.run(trajectory_coro)
async def run_trajectory(run_idx: int, config: GymEvalConfig):
    """Run one gym evaluation trajectory for ``config`` on the current event loop.

    Creates a DB client and the gym environment registered under
    ``config.env_id``, then hands both to a ``GymTrajectoryRunner``. The runner
    writes its logs under ``.fle/trajectory_logs/v<config.version>``.

    Args:
        run_idx: Index of this run. It is passed to ``gym.make`` and used as the
            runner's ``process_id``.
        config: Evaluation config (agents, task, version, env id) for this run.
    """
    db_client = await create_db_client()
    # Release the DB client even when environment creation or the run itself
    # raises. Previously an exception skipped cleanup and leaked the client.
    try:
        gym_env = gym.make(config.env_id, run_idx=run_idx)
        log_dir = os.path.join(".fle", "trajectory_logs", f"v{config.version}")
        runner = GymTrajectoryRunner(
            config=config,
            gym_env=gym_env,
            db_client=db_client,
            log_dir=log_dir,
            process_id=run_idx,
        )
        await runner.run()
    finally:
        await db_client.cleanup()
def _create_agents(run_config, task, generator, num_agents):
    """Create ``num_agents`` GymAgents for one run and collect their agent cards.

    Returns:
        An ``(agents, agent_cards)`` tuple of equal-length lists, ordered by
        agent index.
    """
    agents = []
    agent_cards = []
    for agent_idx in range(num_agents):
        # Each agent's prompt is generated from its index and the total agent count.
        system_prompt = generator.generate_for_agent(
            agent_idx=agent_idx, num_agents=num_agents
        )
        agent = GymAgent(
            model=run_config.model,
            system_prompt=system_prompt,
            task=task,
            agent_idx=agent_idx,
            observation_formatter=BasicObservationFormatter(include_research=False),
            system_prompt_formatter=SystemPromptFormatter(),
        )
        agents.append(agent)
        # Create agent card for a2a support
        agent_cards.append(agent.get_agent_card())
    return agents, agent_cards


async def main(config_path):
    """Evaluate every run listed in a run-config file, one process per run.

    All ``GymEvalConfig`` objects are built first. Only then is a
    ``multiprocessing.Process`` (target ``run_process``) started for each run.
    This coroutine then waits for every process to exit.

    Args:
        config_path: Path to the JSON run-config file.

    Raises:
        ValueError: Propagated from ``get_validated_run_configs``, or raised if
            environment info for a run cannot be resolved. Raised before any
            process has been started.
        RuntimeError: If a built eval config has no agent cards.
    """
    # Read and validate run configurations
    run_configs = get_validated_run_configs(config_path)
    # Get starting version number for new runs
    base_version = await get_next_version()
    # Hoisted out of the per-run loop: its only input is the installed fle
    # package path. NOTE(review): assumes the generator keeps no per-run state.
    # The original already reused one instance for all agents within a run.
    generator = SystemPromptGenerator(str(importlib.resources.files("fle") / "env"))

    # Phase 1: build every eval config before starting any process. This way a
    # setup failure in a later run cannot leave earlier runs already executing.
    eval_configs = []
    # version_offset advances for every run, including runs with an explicit
    # version (same numbering as before).
    for version_offset, run_config in enumerate(run_configs):
        # Get environment info from registry
        env_info = get_environment_info(run_config.env_id)
        if env_info is None:
            raise ValueError(f"Could not get environment info for {run_config.env_id}")
        task = TaskFactory.create_task(env_info["task_config_path"])
        num_agents = env_info["num_agents"]
        agents, agent_cards = _create_agents(run_config, task, generator, num_agents)

        # An explicit version in the run config wins; otherwise auto-assign.
        version = (
            run_config.version
            if run_config.version is not None
            else base_version + version_offset
        )
        # Create eval config with agent cards for a2a support
        config = GymEvalConfig(
            agents=agents,
            version=version,
            version_description=f"model:{run_config.model}\ntype:{task.task_key}\nnum_agents:{num_agents}",
            task=task,
            agent_cards=agent_cards,
            env_id=run_config.env_id,
        )
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if config.agent_cards is None:
            raise RuntimeError(
                f"Eval config for {run_config.env_id} has no agent cards; "
                "a2a support requires them"
            )
        eval_configs.append(config)

    # Phase 2: one process per run; run_idx is the run's position in the file.
    processes = []
    for run_idx, config in enumerate(eval_configs):
        p = multiprocessing.Process(target=run_process, args=(run_idx, config))
        p.start()
        processes.append(p)

    # Wait for all processes to complete. join() blocks this event loop, which
    # is fine as long as no other coroutines need to run on it meanwhile.
    for p in processes:
        p.join()