mirror of
https://github.com/JackHopkins/factorio-learning-environment.git
synced 2025-09-06 13:23:58 +00:00
Multiagent gym env and minor fixes (#299)
* everything working * default fast
This commit is contained in:
31
fle/commons/asyncio_utils.py
Normal file
31
fle/commons/asyncio_utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
def _run_async_in_new_thread(coro):
    """Run a coroutine to completion on a fresh event loop in a worker thread.

    This lets synchronous code drive ``coro`` even when the calling thread
    already has a running event loop (where ``asyncio.run`` would refuse to
    start). The calling thread blocks until the coroutine finishes.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever ``coro`` returns.

    Raises:
        Any exception raised by ``coro`` is re-raised in the calling thread.
    """
    # asyncio.run() owns the whole loop lifecycle: it creates a new loop,
    # cancels tasks the coroutine left pending, finalizes async generators
    # and closes the loop. Closing a loop by hand after run_until_complete()
    # skips that cleanup and leaves pending tasks to be destroyed un-run.
    # Only one job is ever submitted, so one worker thread is enough.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        # .result() blocks until the worker is done and re-raises its error.
        return executor.submit(asyncio.run, coro).result()
|
||||
|
||||
|
||||
def run_async_safely(coro):
    """Run a coroutine from synchronous code, with or without a running loop.

    If the calling thread has no running event loop, the coroutine is run
    directly with ``asyncio.run``. If a loop is already running,
    ``asyncio.run`` cannot be used in this thread, so the coroutine is handed
    to ``_run_async_in_new_thread`` instead.

    Args:
        coro: The coroutine object to execute.

    Returns:
        Whatever ``coro`` returns.

    Raises:
        Any exception raised by ``coro``, unchanged.
    """
    # Keep the try body limited to the probe. get_running_loop() signals
    # "no loop" with RuntimeError; if the coroutine were executed inside the
    # same try, a RuntimeError raised by the coroutine itself would be
    # mistaken for "no loop" and trigger a second, invalid run attempt that
    # hides the real error.
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if loop_is_running:
        # A loop is already running in this thread: use the thread approach.
        return _run_async_in_new_thread(coro)

    # No running event loop, safe to use asyncio.run().
    return asyncio.run(coro)
|
@@ -18,7 +18,9 @@ class GameState:
|
||||
research: Optional[ResearchState] = field()
|
||||
timestamp: float = field(default_factory=time.time)
|
||||
namespaces: List[bytes] = field(default_factory=list)
|
||||
agent_messages: List[Dict[str, Any]] = field(default_factory=list)
|
||||
agent_messages: List[Any] = field(
|
||||
default_factory=list
|
||||
) # Can be List[Dict] or List[List[Dict]]
|
||||
|
||||
@property
|
||||
def is_multiagent(self) -> bool:
|
||||
@@ -28,24 +30,24 @@ class GameState:
|
||||
def num_agents(self) -> int:
|
||||
return len(self.inventories)
|
||||
|
||||
def parse_agent_messages(data: dict) -> List[Dict[str, Any]]:
|
||||
def parse_agent_messages(data: dict) -> List[Any]:
|
||||
agent_messages = data.get("agent_messages", [])
|
||||
if not isinstance(agent_messages, list):
|
||||
raise ValueError("agent_messages must be a list")
|
||||
if agent_messages and not all(isinstance(msg, dict) for msg in agent_messages):
|
||||
if agent_messages and not all(
|
||||
isinstance(msg, (dict, list)) for msg in agent_messages
|
||||
):
|
||||
for idx, message in enumerate(agent_messages):
|
||||
if isinstance(message, dict):
|
||||
continue
|
||||
elif isinstance(message, list):
|
||||
# Keep the list as-is, but validate its contents
|
||||
if len(message) > 0:
|
||||
if isinstance(message[0], dict):
|
||||
agent_messages[idx] = message[0]
|
||||
else:
|
||||
if not all(isinstance(msg, dict) for msg in message):
|
||||
raise ValueError(
|
||||
f"agent_messages[{idx}] must be a dictionary or a list of dictionaries, but got {type(message[0])}"
|
||||
f"agent_messages[{idx}] contains non-dictionary elements"
|
||||
)
|
||||
else:
|
||||
agent_messages[idx] = {}
|
||||
# Leave the list unchanged - don't convert to single dict
|
||||
else:
|
||||
raise ValueError(
|
||||
f"agent_messages[{idx}] must be a dictionary or a list of dictionaries, but got {type(message)}"
|
||||
|
1
fle/env/gym_env/config.py
vendored
1
fle/env/gym_env/config.py
vendored
@@ -14,7 +14,6 @@ class GymRunConfig:
|
||||
env_id: str # Gym environment ID from registry (e.g., "Factorio-iron_ore_throughput_16-v0")
|
||||
model: str
|
||||
version: Optional[int] = None
|
||||
num_agents: int = 1
|
||||
exit_on_task_success: bool = True
|
||||
observation_formatter: Optional[BasicObservationFormatter] = None
|
||||
|
||||
|
51
fle/env/gym_env/registry.py
vendored
51
fle/env/gym_env/registry.py
vendored
@@ -2,10 +2,12 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fle.env.a2a_instance import A2AFactorioInstance
|
||||
import gym
|
||||
import json
|
||||
|
||||
from fle.commons.cluster_ips import get_local_container_ips
|
||||
from fle.commons.asyncio_utils import run_async_safely
|
||||
from fle.env import FactorioInstance
|
||||
from fle.env.gym_env.environment import FactorioGymEnv
|
||||
from fle.eval.tasks import TaskFactory
|
||||
@@ -52,11 +54,13 @@ class FactorioGymRegistry:
|
||||
with open(task_file, "r") as f:
|
||||
task_data = json.load(f)
|
||||
|
||||
task_key = task_data.get("config", {}).get("task_key", task_file.stem)
|
||||
task_type = task_data.get("task_type", "default")
|
||||
goal_description = task_data.get("config", {}).get(
|
||||
task_config = task_data.get("config", {})
|
||||
task_key = task_config.get("task_key", task_file.stem)
|
||||
task_type = task_config.get("task_type", "default")
|
||||
goal_description = task_config.get(
|
||||
"goal_description", f"Task: {task_key}"
|
||||
)
|
||||
num_agents = task_config.get("num_agents", 1)
|
||||
# Register the environment
|
||||
self.register_environment(
|
||||
env_id=task_key,
|
||||
@@ -64,6 +68,7 @@ class FactorioGymRegistry:
|
||||
task_config_path=str(task_file),
|
||||
description=goal_description,
|
||||
task_type=task_type,
|
||||
num_agents=num_agents,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -134,32 +139,30 @@ def make_factorio_env(env_spec: GymEnvironmentSpec) -> FactorioGymEnv:
|
||||
# Create Factorio instance
|
||||
try:
|
||||
# Check for external server configuration via environment variables
|
||||
external_address = os.getenv("FACTORIO_SERVER_ADDRESS")
|
||||
external_port = os.getenv("FACTORIO_SERVER_PORT")
|
||||
address = os.getenv("FACTORIO_SERVER_ADDRESS")
|
||||
tcp_port = os.getenv("FACTORIO_SERVER_PORT")
|
||||
|
||||
if external_address and external_port:
|
||||
# Use external server
|
||||
instance = FactorioInstance(
|
||||
address=external_address,
|
||||
tcp_port=int(external_port),
|
||||
num_agents=env_spec.num_agents,
|
||||
)
|
||||
print(
|
||||
f"Using external Factorio server at {external_address}:{external_port}"
|
||||
)
|
||||
else:
|
||||
# Fall back to local containers
|
||||
if not address and not tcp_port:
|
||||
ips, udp_ports, tcp_ports = get_local_container_ips()
|
||||
if len(tcp_ports) == 0:
|
||||
raise RuntimeError("No Factorio containers available")
|
||||
address, tcp_port = ips[0], tcp_ports[0]
|
||||
|
||||
# Use the first available container
|
||||
instance = FactorioInstance(
|
||||
address=ips[0],
|
||||
tcp_port=tcp_ports[0],
|
||||
num_agents=env_spec.num_agents,
|
||||
)
|
||||
print(f"Using local Factorio container at {ips[0]}:{tcp_ports[0]}")
|
||||
common_kwargs = {
|
||||
"address": address,
|
||||
"tcp_port": int(tcp_port),
|
||||
"num_agents": env_spec.num_agents,
|
||||
"fast": True,
|
||||
"cache_scripts": True,
|
||||
"inventory": {},
|
||||
"all_technologies_researched": True,
|
||||
}
|
||||
|
||||
print(f"Using local Factorio container at {address}:{tcp_port}")
|
||||
if env_spec.num_agents > 1:
|
||||
instance = run_async_safely(A2AFactorioInstance.create(**common_kwargs))
|
||||
else:
|
||||
instance = FactorioInstance(**common_kwargs)
|
||||
|
||||
instance.speed(10)
|
||||
|
||||
|
14
fle/env/gym_env/run_eval.py
vendored
14
fle/env/gym_env/run_eval.py
vendored
@@ -28,15 +28,6 @@ def get_validated_run_configs(run_config_location: str) -> list[GymRunConfig]:
|
||||
run_configs_raw = json.load(f)
|
||||
run_configs = [GymRunConfig(**config) for config in run_configs_raw]
|
||||
|
||||
# Validate config
|
||||
num_agents_in_configs = [run_config.num_agents for run_config in run_configs]
|
||||
if any(num_agents == 1 for num_agents in num_agents_in_configs) and any(
|
||||
num_agents > 1 for num_agents in num_agents_in_configs
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot mix single agent and multi agent runs in the same run config file. Please split into separate files."
|
||||
)
|
||||
|
||||
# Validate that all environment IDs exist in the registry
|
||||
available_envs = list_available_environments()
|
||||
for run_config in run_configs:
|
||||
@@ -110,11 +101,10 @@ async def main():
|
||||
gym_env = gym.make(run_config.env_id)
|
||||
task = gym_env.unwrapped.task
|
||||
instance = gym_env.unwrapped.instance
|
||||
|
||||
# Create agents and their agent cards
|
||||
agents = []
|
||||
agent_cards = []
|
||||
for agent_idx in range(run_config.num_agents):
|
||||
for agent_idx in range(instance.num_agents):
|
||||
system_prompt = instance.get_system_prompt(agent_idx)
|
||||
agent = GymAgent(
|
||||
model=run_config.model,
|
||||
@@ -142,7 +132,7 @@ async def main():
|
||||
config = GymEvalConfig(
|
||||
agents=agents,
|
||||
version=version,
|
||||
version_description=f"model:{run_config.model}\ntype:{task.task_key}\nnum_agents:{run_config.num_agents}",
|
||||
version_description=f"model:{run_config.model}\ntype:{task.task_key}\nnum_agents:{instance.num_agents}",
|
||||
exit_on_task_success=run_config.exit_on_task_success,
|
||||
task=task,
|
||||
agent_cards=agent_cards,
|
||||
|
2
fle/env/instance.py
vendored
2
fle/env/instance.py
vendored
@@ -94,7 +94,7 @@ class FactorioInstance:
|
||||
def __init__(
|
||||
self,
|
||||
address=None,
|
||||
fast=False,
|
||||
fast=True,
|
||||
tcp_port=27000,
|
||||
inventory=None,
|
||||
cache_scripts=True,
|
||||
|
@@ -10,7 +10,8 @@
|
||||
"trajectory_length": 16,
|
||||
"holdout_wait_period": 60,
|
||||
"pre_holdout_wait_period": 60,
|
||||
"task_key": "iron_plate_throughput_unbounded_steps_show_steps_false",
|
||||
"task_key": "iron_plate_throughput_multiagent_distrust",
|
||||
"num_agents": 2,
|
||||
"show_number_of_steps_left_in_prompt": false}
|
||||
|
||||
}
|
@@ -6,7 +6,8 @@
|
||||
"trajectory_length": 16,
|
||||
"holdout_wait_period": 60,
|
||||
"pre_holdout_wait_period": 60,
|
||||
"task_key": "iron_plate_throughput_unbounded_steps_show_steps_false",
|
||||
"task_key": "iron_plate_throughput_multiagent_free",
|
||||
"num_agents": 2,
|
||||
"show_number_of_steps_left_in_prompt": false}
|
||||
|
||||
}
|
@@ -10,7 +10,8 @@
|
||||
"trajectory_length": 16,
|
||||
"holdout_wait_period": 60,
|
||||
"pre_holdout_wait_period": 60,
|
||||
"task_key": "iron_plate_throughput_unbounded_steps_show_steps_false",
|
||||
"task_key": "iron_plate_throughput_multiagent_impostor",
|
||||
"num_agents": 2,
|
||||
"show_number_of_steps_left_in_prompt": false}
|
||||
|
||||
}
|
@@ -26,6 +26,8 @@ class TaskFactory:
|
||||
}
|
||||
task_type = input_json["task_type"]
|
||||
task_config = input_json["config"]
|
||||
if "num_agents" in task_config:
|
||||
del task_config["num_agents"]
|
||||
if task_type in task_type_mapping:
|
||||
task_class = task_type_mapping[task_type]
|
||||
return task_class(**task_config)
|
||||
|
@@ -77,9 +77,6 @@ class UnboundedThroughputTask(TaskABC):
|
||||
success=False,
|
||||
meta={
|
||||
"achievements": max_achievements,
|
||||
"nr_of_steps_left": self.trajectory_length
|
||||
- step_statistics["current_step_id"]
|
||||
- 1,
|
||||
},
|
||||
)
|
||||
|
||||
|
Reference in New Issue
Block a user