Multiagent gym env and minor fixes (#299)

* everything working

* default fast
This commit is contained in:
Neel Kant
2025-08-13 18:48:32 +02:00
committed by GitHub
parent 3b64e38a13
commit 6042174da4
11 changed files with 80 additions and 53 deletions

View File

@@ -0,0 +1,31 @@
import asyncio
import concurrent.futures
def _run_async_in_new_thread(coro):
"""Run an async coroutine in a new thread with its own event loop"""
def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(coro)
finally:
new_loop.close()
asyncio.set_event_loop(None)
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()
def run_async_safely(coro):
"""Run an async coroutine safely, handling both cases where event loop exists or not"""
try:
# Check if we're already in a running event loop
asyncio.get_running_loop()
# If we get here, there's a running loop, so use thread approach
return _run_async_in_new_thread(coro)
except RuntimeError:
# No running event loop, safe to use asyncio.run()
return asyncio.run(coro)

View File

@@ -18,7 +18,9 @@ class GameState:
research: Optional[ResearchState] = field()
timestamp: float = field(default_factory=time.time)
namespaces: List[bytes] = field(default_factory=list)
agent_messages: List[Dict[str, Any]] = field(default_factory=list)
agent_messages: List[Any] = field(
default_factory=list
) # Can be List[Dict] or List[List[Dict]]
@property
def is_multiagent(self) -> bool:
@@ -28,24 +30,24 @@ class GameState:
def num_agents(self) -> int:
return len(self.inventories)
def parse_agent_messages(data: dict) -> List[Dict[str, Any]]:
def parse_agent_messages(data: dict) -> List[Any]:
agent_messages = data.get("agent_messages", [])
if not isinstance(agent_messages, list):
raise ValueError("agent_messages must be a list")
if agent_messages and not all(isinstance(msg, dict) for msg in agent_messages):
if agent_messages and not all(
isinstance(msg, (dict, list)) for msg in agent_messages
):
for idx, message in enumerate(agent_messages):
if isinstance(message, dict):
continue
elif isinstance(message, list):
# Keep the list as-is, but validate its contents
if len(message) > 0:
if isinstance(message[0], dict):
agent_messages[idx] = message[0]
else:
if not all(isinstance(msg, dict) for msg in message):
raise ValueError(
f"agent_messages[{idx}] must be a dictionary or a list of dictionaries, but got {type(message[0])}"
f"agent_messages[{idx}] contains non-dictionary elements"
)
else:
agent_messages[idx] = {}
# Leave the list unchanged - don't convert to single dict
else:
raise ValueError(
f"agent_messages[{idx}] must be a dictionary or a list of dictionaries, but got {type(message)}"

View File

@@ -14,7 +14,6 @@ class GymRunConfig:
env_id: str # Gym environment ID from registry (e.g., "Factorio-iron_ore_throughput_16-v0")
model: str
version: Optional[int] = None
num_agents: int = 1
exit_on_task_success: bool = True
observation_formatter: Optional[BasicObservationFormatter] = None

View File

@@ -2,10 +2,12 @@ import os
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from fle.env.a2a_instance import A2AFactorioInstance
import gym
import json
from fle.commons.cluster_ips import get_local_container_ips
from fle.commons.asyncio_utils import run_async_safely
from fle.env import FactorioInstance
from fle.env.gym_env.environment import FactorioGymEnv
from fle.eval.tasks import TaskFactory
@@ -52,11 +54,13 @@ class FactorioGymRegistry:
with open(task_file, "r") as f:
task_data = json.load(f)
task_key = task_data.get("config", {}).get("task_key", task_file.stem)
task_type = task_data.get("task_type", "default")
goal_description = task_data.get("config", {}).get(
task_config = task_data.get("config", {})
task_key = task_config.get("task_key", task_file.stem)
task_type = task_config.get("task_type", "default")
goal_description = task_config.get(
"goal_description", f"Task: {task_key}"
)
num_agents = task_config.get("num_agents", 1)
# Register the environment
self.register_environment(
env_id=task_key,
@@ -64,6 +68,7 @@ class FactorioGymRegistry:
task_config_path=str(task_file),
description=goal_description,
task_type=task_type,
num_agents=num_agents,
)
except Exception as e:
@@ -134,32 +139,30 @@ def make_factorio_env(env_spec: GymEnvironmentSpec) -> FactorioGymEnv:
# Create Factorio instance
try:
# Check for external server configuration via environment variables
external_address = os.getenv("FACTORIO_SERVER_ADDRESS")
external_port = os.getenv("FACTORIO_SERVER_PORT")
address = os.getenv("FACTORIO_SERVER_ADDRESS")
tcp_port = os.getenv("FACTORIO_SERVER_PORT")
if external_address and external_port:
# Use external server
instance = FactorioInstance(
address=external_address,
tcp_port=int(external_port),
num_agents=env_spec.num_agents,
)
print(
f"Using external Factorio server at {external_address}:{external_port}"
)
else:
# Fall back to local containers
if not address and not tcp_port:
ips, udp_ports, tcp_ports = get_local_container_ips()
if len(tcp_ports) == 0:
raise RuntimeError("No Factorio containers available")
address, tcp_port = ips[0], tcp_ports[0]
# Use the first available container
instance = FactorioInstance(
address=ips[0],
tcp_port=tcp_ports[0],
num_agents=env_spec.num_agents,
)
print(f"Using local Factorio container at {ips[0]}:{tcp_ports[0]}")
common_kwargs = {
"address": address,
"tcp_port": int(tcp_port),
"num_agents": env_spec.num_agents,
"fast": True,
"cache_scripts": True,
"inventory": {},
"all_technologies_researched": True,
}
print(f"Using local Factorio container at {address}:{tcp_port}")
if env_spec.num_agents > 1:
instance = run_async_safely(A2AFactorioInstance.create(**common_kwargs))
else:
instance = FactorioInstance(**common_kwargs)
instance.speed(10)

View File

@@ -28,15 +28,6 @@ def get_validated_run_configs(run_config_location: str) -> list[GymRunConfig]:
run_configs_raw = json.load(f)
run_configs = [GymRunConfig(**config) for config in run_configs_raw]
# Validate config
num_agents_in_configs = [run_config.num_agents for run_config in run_configs]
if any(num_agents == 1 for num_agents in num_agents_in_configs) and any(
num_agents > 1 for num_agents in num_agents_in_configs
):
raise ValueError(
"Cannot mix single agent and multi agent runs in the same run config file. Please split into separate files."
)
# Validate that all environment IDs exist in the registry
available_envs = list_available_environments()
for run_config in run_configs:
@@ -110,11 +101,10 @@ async def main():
gym_env = gym.make(run_config.env_id)
task = gym_env.unwrapped.task
instance = gym_env.unwrapped.instance
# Create agents and their agent cards
agents = []
agent_cards = []
for agent_idx in range(run_config.num_agents):
for agent_idx in range(instance.num_agents):
system_prompt = instance.get_system_prompt(agent_idx)
agent = GymAgent(
model=run_config.model,
@@ -142,7 +132,7 @@ async def main():
config = GymEvalConfig(
agents=agents,
version=version,
version_description=f"model:{run_config.model}\ntype:{task.task_key}\nnum_agents:{run_config.num_agents}",
version_description=f"model:{run_config.model}\ntype:{task.task_key}\nnum_agents:{instance.num_agents}",
exit_on_task_success=run_config.exit_on_task_success,
task=task,
agent_cards=agent_cards,

2
fle/env/instance.py vendored
View File

@@ -94,7 +94,7 @@ class FactorioInstance:
def __init__(
self,
address=None,
fast=False,
fast=True,
tcp_port=27000,
inventory=None,
cache_scripts=True,

View File

@@ -10,7 +10,8 @@
"trajectory_length": 16,
"holdout_wait_period": 60,
"pre_holdout_wait_period": 60,
"task_key": "iron_plate_throughput_unbounded_steps_show_steps_false",
"task_key": "iron_plate_throughput_multiagent_distrust",
"num_agents": 2,
"show_number_of_steps_left_in_prompt": false}
}

View File

@@ -6,7 +6,8 @@
"trajectory_length": 16,
"holdout_wait_period": 60,
"pre_holdout_wait_period": 60,
"task_key": "iron_plate_throughput_unbounded_steps_show_steps_false",
"task_key": "iron_plate_throughput_multiagent_free",
"num_agents": 2,
"show_number_of_steps_left_in_prompt": false}
}

View File

@@ -10,7 +10,8 @@
"trajectory_length": 16,
"holdout_wait_period": 60,
"pre_holdout_wait_period": 60,
"task_key": "iron_plate_throughput_unbounded_steps_show_steps_false",
"task_key": "iron_plate_throughput_multiagent_impostor",
"num_agents": 2,
"show_number_of_steps_left_in_prompt": false}
}

View File

@@ -26,6 +26,8 @@ class TaskFactory:
}
task_type = input_json["task_type"]
task_config = input_json["config"]
if "num_agents" in task_config:
del task_config["num_agents"]
if task_type in task_type_mapping:
task_class = task_type_mapping[task_type]
return task_class(**task_config)

View File

@@ -77,9 +77,6 @@ class UnboundedThroughputTask(TaskABC):
success=False,
meta={
"achievements": max_achievements,
"nr_of_steps_left": self.trajectory_length
- step_statistics["current_step_id"]
- 1,
},
)