mirror of
https://github.com/JackHopkins/factorio-learning-environment.git
synced 2025-09-06 13:23:58 +00:00
sweep id, open-router, 401 debugging
This commit is contained in:
@@ -119,13 +119,14 @@ class GymAgent(AgentABC):
|
||||
agent_idx: Optional[int] = None,
|
||||
observation_formatter: Optional[BasicObservationFormatter] = None,
|
||||
system_prompt_formatter: Optional[SystemPromptFormatter] = None,
|
||||
api_key_config_file: Optional[str] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
instructions = self._get_instructions(system_prompt, task, agent_idx)
|
||||
super().__init__(model, instructions, *args, **kwargs)
|
||||
self.task = task
|
||||
self.api_factory = APIFactory(model)
|
||||
self.api_factory = APIFactory(model, api_key_config_file=api_key_config_file)
|
||||
self.observation_formatter = (
|
||||
observation_formatter or BasicObservationFormatter()
|
||||
)
|
||||
|
@@ -10,12 +10,19 @@ from fle.agents.llm.utils import (
|
||||
remove_whitespace_blocks,
|
||||
)
|
||||
|
||||
try:
|
||||
from fle.eval.analysis.api_key_manager import get_api_key_manager
|
||||
|
||||
API_KEY_MANAGER_AVAILABLE = True
|
||||
except ImportError:
|
||||
API_KEY_MANAGER_AVAILABLE = False
|
||||
# Lazy import to avoid circular dependencies
|
||||
def _get_api_key_manager():
|
||||
"""Lazy import for API key manager to avoid circular imports."""
|
||||
try:
|
||||
from fle.eval.analysis.api_key_manager import get_api_key_manager
|
||||
|
||||
return get_api_key_manager
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
|
||||
API_KEY_MANAGER_AVAILABLE = True # Assume available, handle at runtime
|
||||
|
||||
|
||||
class APIFactory:
|
||||
@@ -64,28 +71,14 @@ class APIFactory:
|
||||
"""
|
||||
self.model = model
|
||||
self.beam = beam
|
||||
self.api_key_config_file = (
|
||||
api_key_config_file # Store for child process reinitialization
|
||||
)
|
||||
self.api_key_manager = None
|
||||
|
||||
# Initialize API key manager if available
|
||||
if API_KEY_MANAGER_AVAILABLE:
|
||||
try:
|
||||
# Check for config file from environment if not provided
|
||||
config_file = (
|
||||
api_key_config_file
|
||||
or os.getenv("FLE_API_KEY_CONFIG_FILE")
|
||||
or os.getenv("API_KEY_CONFIG_FILE")
|
||||
)
|
||||
|
||||
self.api_key_manager = get_api_key_manager(config_file)
|
||||
|
||||
if config_file and self.api_key_manager:
|
||||
logging.info(
|
||||
f"Initialized API key manager with config: {config_file}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to initialize API key manager: {e}")
|
||||
self.api_key_manager = None
|
||||
# Don't initialize API key manager in parent process to avoid multiprocessing pickle issues
|
||||
# It will be initialized lazily in the child process when first needed
|
||||
self.api_key_manager = None
|
||||
|
||||
def _get_provider_config(self, model: str) -> dict:
|
||||
"""Get provider config based on model name"""
|
||||
@@ -103,14 +96,35 @@ class APIFactory:
|
||||
Returns:
|
||||
API key string
|
||||
"""
|
||||
# Try key manager first if available
|
||||
if self.api_key_manager and "key_manager_provider" in provider_config:
|
||||
# Try key manager first if available (reinitialize if needed due to multiprocessing)
|
||||
if "key_manager_provider" in provider_config:
|
||||
key_manager_provider = provider_config["key_manager_provider"]
|
||||
rotated_key = self.api_key_manager.get_key(key_manager_provider)
|
||||
|
||||
if rotated_key:
|
||||
logging.debug(f"Using rotated key for {key_manager_provider}")
|
||||
return rotated_key
|
||||
# Reinitialize API key manager in child process if needed
|
||||
if not self.api_key_manager:
|
||||
try:
|
||||
config_file = (
|
||||
self.api_key_config_file
|
||||
or os.getenv("FLE_API_KEY_CONFIG_FILE")
|
||||
or os.getenv("API_KEY_CONFIG_FILE")
|
||||
)
|
||||
if config_file:
|
||||
get_api_key_manager_func = _get_api_key_manager()
|
||||
if get_api_key_manager_func:
|
||||
self.api_key_manager = get_api_key_manager_func(config_file)
|
||||
logging.info(
|
||||
f"Reinitialized API key manager in child process: {config_file}"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
f"Failed to reinitialize API key manager in child process: {e}"
|
||||
)
|
||||
|
||||
if self.api_key_manager:
|
||||
rotated_key = self.api_key_manager.get_key(key_manager_provider)
|
||||
if rotated_key:
|
||||
logging.debug(f"Using rotated key for {key_manager_provider}")
|
||||
return rotated_key
|
||||
|
||||
# Fallback to environment variable
|
||||
env_var = provider_config["api_key_env"]
|
||||
|
@@ -241,18 +241,17 @@ class PythonParser:
|
||||
|
||||
|
||||
def parse_response(response) -> Optional[Policy]:
|
||||
has_usage = hasattr(response, "usage")
|
||||
prompt_tokens = has_usage and hasattr(response.usage, "prompt_tokens")
|
||||
completion_tokens = has_usage and hasattr(response.usage, "completion_tokens")
|
||||
if hasattr(response, "choices"):
|
||||
choice = response.choices[0]
|
||||
input_tokens = response.usage.prompt_tokens if hasattr(response, "usage") else 0
|
||||
output_tokens = (
|
||||
response.usage.completion_tokens if hasattr(response, "usage") else 0
|
||||
)
|
||||
input_tokens = response.usage.prompt_tokens if prompt_tokens else 0
|
||||
output_tokens = response.usage.completion_tokens if completion_tokens else 0
|
||||
else:
|
||||
choice = response.content[0]
|
||||
input_tokens = response.usage.input_tokens if hasattr(response, "usage") else 0
|
||||
output_tokens = (
|
||||
response.usage.output_tokens if hasattr(response, "usage") else 0
|
||||
)
|
||||
input_tokens = response.usage.input_tokens if prompt_tokens else 0
|
||||
output_tokens = response.usage.output_tokens if completion_tokens else 0
|
||||
|
||||
total_tokens = input_tokens + output_tokens
|
||||
try:
|
||||
|
18
fle/env/gym_env/run_eval.py
vendored
18
fle/env/gym_env/run_eval.py
vendored
@@ -83,16 +83,26 @@ async def run_trajectory(run_idx: int, config: GymEvalConfig):
|
||||
config.version_description.split("model:")[1].split("\n")[0].strip()
|
||||
)
|
||||
|
||||
# Get sweep ID for tagging
|
||||
sweep_id = os.getenv("FLE_SWEEP_ID", "unknown_sweep")
|
||||
|
||||
wandb_logger = WandBLogger(
|
||||
project=os.getenv("WANDB_PROJECT", "factorio-learning-environment"),
|
||||
run_name=f"{model_name}-{task_name}-v{config.version}-trial{run_idx}",
|
||||
tags=["gym_eval", model_name, task_name, f"v{config.version}"],
|
||||
tags=[
|
||||
"gym_eval",
|
||||
model_name,
|
||||
task_name,
|
||||
f"v{config.version}",
|
||||
f"sweep:{sweep_id}",
|
||||
],
|
||||
config={
|
||||
"model": model_name,
|
||||
"task": task_name,
|
||||
"version": config.version,
|
||||
"trial": run_idx,
|
||||
"version_description": config.version_description,
|
||||
"sweep_id": sweep_id,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -140,6 +150,11 @@ async def main(config_path):
|
||||
system_prompt = generator.generate_for_agent(
|
||||
agent_idx=agent_idx, num_agents=num_agents
|
||||
)
|
||||
# Get API key config file from environment (set by sweep_manager)
|
||||
api_key_config_file = os.getenv("FLE_API_KEY_CONFIG_FILE") or os.getenv(
|
||||
"API_KEY_CONFIG_FILE"
|
||||
)
|
||||
|
||||
agent = GymAgent(
|
||||
model=run_config.model,
|
||||
system_prompt=system_prompt,
|
||||
@@ -147,6 +162,7 @@ async def main(config_path):
|
||||
agent_idx=agent_idx,
|
||||
observation_formatter=BasicObservationFormatter(include_research=False),
|
||||
system_prompt_formatter=SystemPromptFormatter(),
|
||||
api_key_config_file=api_key_config_file,
|
||||
)
|
||||
agents.append(agent)
|
||||
|
||||
|
2
fle/env/gym_env/trajectory_runner.py
vendored
2
fle/env/gym_env/trajectory_runner.py
vendored
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
from itertools import product
|
||||
from typing import Any, List, Dict, Optional, Tuple
|
||||
@@ -157,6 +158,7 @@ class GymTrajectoryRunner:
|
||||
"model": self.agents[agent_idx].model,
|
||||
"process_id": self.process_id,
|
||||
"error_occurred": error_occurred,
|
||||
"sweep_id": os.getenv("FLE_SWEEP_ID", "unknown"),
|
||||
},
|
||||
depth=depth,
|
||||
)
|
||||
|
@@ -38,6 +38,7 @@ from .api_key_manager import (
|
||||
get_api_key_manager,
|
||||
create_api_keys_config_template,
|
||||
)
|
||||
from .server_manager import ServerManager, get_server_manager
|
||||
|
||||
# Utility functions
|
||||
from .analysis_utils import (
|
||||
@@ -56,6 +57,7 @@ __all__ = [
|
||||
"SweepConfig",
|
||||
"ResultsVisualizer",
|
||||
"APIKeyManager",
|
||||
"ServerManager",
|
||||
# Utility functions
|
||||
"group_results_by_model",
|
||||
"group_results_by_task",
|
||||
@@ -63,4 +65,5 @@ __all__ = [
|
||||
"get_trajectory_summary",
|
||||
"get_api_key_manager",
|
||||
"create_api_keys_config_template",
|
||||
"get_server_manager",
|
||||
]
|
||||
|
@@ -16,7 +16,11 @@ async def run_small_sweep():
|
||||
name="test_sweep_small",
|
||||
description="Small test sweep with 2 Claude models and 2 tasks",
|
||||
# Models to evaluate
|
||||
models=["claude-sonnet-4-20250514", "claude-opus-4-20250514"],
|
||||
# models=["claude-sonnet-4-20250514", "claude-opus-4-20250514"],
|
||||
models=[
|
||||
"open-router-anthropic/claude-sonnet-4",
|
||||
"open-router-anthropic/claude-opus-4",
|
||||
],
|
||||
# Tasks to evaluate (these should be valid gym environment IDs)
|
||||
tasks=["iron_ore_throughput", "iron_plate_throughput"],
|
||||
# Pass@8 evaluation (3 trials per model-task combination)
|
||||
@@ -34,6 +38,8 @@ async def run_small_sweep():
|
||||
# Retry configuration
|
||||
retry_failed_runs=True,
|
||||
max_retries=2,
|
||||
# Use API key configuration file
|
||||
api_key_config_file="api_keys.json",
|
||||
)
|
||||
|
||||
# Create and run sweep
|
||||
@@ -104,6 +110,106 @@ async def run_large_production_sweep():
|
||||
return results
|
||||
|
||||
|
||||
async def run_openrouter_grok_sweep():
    """Example: Test Grok and other models via OpenRouter.

    Builds a ``SweepConfig`` that mixes OpenRouter-routed models with two
    direct-API models for comparison, runs it through ``SweepManager`` and
    prints a short summary.

    Returns:
        The value produced by awaiting ``SweepManager.run_sweep()``.
    """
    # OpenRouter models (requires OPEN_ROUTER_API_KEY environment variable)
    via_openrouter = [
        "open-router-x-ai/grok-beta",  # Grok Beta via OpenRouter
        "open-router-anthropic/claude-3.5-sonnet",  # Claude via OpenRouter
        "open-router-openai/gpt-4o",  # GPT-4o via OpenRouter
    ]
    # Direct API models, included for comparison
    via_direct_api = [
        "claude-sonnet-4-20250514",  # Direct Anthropic API
        "gpt-4o",  # Direct OpenAI API
    ]

    sweep_settings = dict(
        name="openrouter_grok_evaluation",
        description="Evaluate Grok Beta and other OpenRouter models on Factorio tasks",
        models=[*via_openrouter, *via_direct_api],
        tasks=["iron_ore_throughput", "iron_plate_throughput"],
        num_trials_per_config=3,  # Start with fewer trials for testing
        max_concurrent_processes=3,  # Adjust based on your server count
        # Tracking
        enable_wandb=True,
        wandb_project="openrouter-grok-evaluation",
        # Output and retry behaviour
        output_dir="sweep_results/openrouter_grok_evaluation",
        log_interval_minutes=5,
        retry_failed_runs=True,
        max_retries=2,
    )
    config = SweepConfig(**sweep_settings)

    # Create the manager and run the sweep to completion
    sweep_results = await SweepManager(config).run_sweep()

    print("OpenRouter Grok sweep completed!")
    print(f"Models tested: {', '.join(config.models)}")
    print(f"Results saved to: {config.output_dir}")

    return sweep_results
|
||||
|
||||
|
||||
async def run_openrouter_model_tournament():
    """Example: Large tournament of models via OpenRouter.

    Builds a ``SweepConfig`` covering models from several providers (all
    addressed through OpenRouter model names), runs it through
    ``SweepManager`` and prints a short summary.

    Returns:
        The value produced by awaiting ``SweepManager.run_sweep()``.
    """
    # Comprehensive model comparison via OpenRouter, grouped by provider.
    # Dict insertion order fixes the order of the flattened model list.
    models_by_provider = {
        "xAI": ["open-router-x-ai/grok-beta"],
        "Anthropic": [
            "open-router-anthropic/claude-3.5-sonnet",
            "open-router-anthropic/claude-3-opus",
            "open-router-anthropic/claude-3-haiku",
        ],
        "OpenAI": [
            "open-router-openai/gpt-4o",
            "open-router-openai/gpt-4o-mini",
            "open-router-openai/gpt-4-turbo",
        ],
        "Meta Llama": [
            "open-router-meta-llama/llama-3.1-405b-instruct",
            "open-router-meta-llama/llama-3.1-70b-instruct",
        ],
        "Google": ["open-router-google/gemini-pro-1.5"],
        "Other providers": ["open-router-mistralai/mistral-large"],
    }
    tournament_models = [
        model_name
        for provider_models in models_by_provider.values()
        for model_name in provider_models
    ]
    tournament_tasks = [
        "iron_ore_throughput",
        "iron_plate_throughput",
        "gear_production",
    ]

    config = SweepConfig(
        name="openrouter_model_tournament",
        description="Tournament of top models across multiple providers via OpenRouter",
        models=tournament_models,
        tasks=tournament_tasks,
        num_trials_per_config=5,  # Reasonable for many models
        max_concurrent_processes=4,
        # Tracking and storage
        enable_wandb=True,
        wandb_project="openrouter-model-tournament",
        output_dir="sweep_results/openrouter_model_tournament",
        log_interval_minutes=15,  # Less frequent logging for long runs
        retry_failed_runs=True,
        max_retries=3,
        # Use API key manager for multiple keys if available
        api_key_config_file="api_keys_config.json",
    )

    # Create the manager and run the sweep to completion
    sweep_runner = SweepManager(config)
    tournament_results = await sweep_runner.run_sweep()

    print("OpenRouter model tournament completed!")
    print(f"Models tested: {len(config.models)}")
    print(
        f"Total evaluations: {len(config.models) * len(config.tasks) * config.num_trials_per_config}"
    )
    print(f"Results saved to: {config.output_dir}")

    return tournament_results
|
||||
|
||||
|
||||
def create_custom_sweep_config(
|
||||
models: list, tasks: list, trials: int = 8, name: str = "custom_sweep"
|
||||
) -> SweepConfig:
|
||||
@@ -144,9 +250,26 @@ def create_custom_sweep_config(
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "large":
|
||||
print("Running large production sweep...")
|
||||
asyncio.run(run_large_production_sweep())
|
||||
if len(sys.argv) > 1:
|
||||
sweep_type = sys.argv[1]
|
||||
|
||||
if sweep_type == "large":
|
||||
print("Running large production sweep...")
|
||||
asyncio.run(run_large_production_sweep())
|
||||
elif sweep_type == "grok":
|
||||
print("Running OpenRouter Grok evaluation...")
|
||||
print("Make sure you have set OPEN_ROUTER_API_KEY environment variable!")
|
||||
asyncio.run(run_openrouter_grok_sweep())
|
||||
elif sweep_type == "tournament":
|
||||
print("Running OpenRouter model tournament...")
|
||||
print("Make sure you have set OPEN_ROUTER_API_KEY environment variable!")
|
||||
asyncio.run(run_openrouter_model_tournament())
|
||||
else:
|
||||
print(f"Unknown sweep type: {sweep_type}")
|
||||
print("Available options: small (default), large, grok, tournament")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("Running small test sweep...")
|
||||
asyncio.run(run_small_sweep())
|
||||
|
||||
print("\n🎉 Sweep completed! Check WandB and the output directory for results.")
|
||||
|
@@ -7,6 +7,7 @@ import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
@@ -19,6 +20,7 @@ from fle.commons.db_client import get_next_version
|
||||
from .database_analyzer import DatabaseAnalyzer
|
||||
from .performance_metrics import PerformanceAnalyzer
|
||||
from .wandb_logger import WandBSweepLogger
|
||||
from .server_manager import get_server_manager, ServerManager
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -76,6 +78,7 @@ class RunJob:
|
||||
model: str
|
||||
task: str
|
||||
trial_number: int
|
||||
sweep_id: str
|
||||
version: Optional[int] = None
|
||||
status: str = "pending" # pending, running, completed, failed
|
||||
start_time: Optional[datetime] = None
|
||||
@@ -94,20 +97,48 @@ class SweepManager:
|
||||
config: SweepConfig with sweep parameters
|
||||
"""
|
||||
self.config = config
|
||||
|
||||
# Generate unique sweep ID with timestamp and short UUID
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
short_uuid = str(uuid.uuid4())[:8]
|
||||
self.sweep_id = f"{config.name}_{timestamp}_{short_uuid}"
|
||||
|
||||
self.jobs: List[RunJob] = []
|
||||
self.active_processes: Dict[str, multiprocessing.Process] = {}
|
||||
self.completed_versions: List[int] = []
|
||||
self.start_time: Optional[datetime] = None
|
||||
self.wandb_logger: Optional[WandBSweepLogger] = None
|
||||
self.database_analyzer: Optional[DatabaseAnalyzer] = None
|
||||
self.server_manager: Optional[ServerManager] = None
|
||||
|
||||
if config.output_dir:
|
||||
Path(config.output_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Initialize server manager
|
||||
self.server_manager = get_server_manager()
|
||||
total_servers = self.server_manager.get_total_server_count()
|
||||
|
||||
# Warn if we don't have enough servers for max concurrent processes
|
||||
if total_servers < self.config.max_concurrent_processes:
|
||||
print(
|
||||
f"⚠️ Warning: Only {total_servers} Factorio servers available, "
|
||||
f"but max_concurrent_processes={self.config.max_concurrent_processes}"
|
||||
)
|
||||
print(f" Concurrent processes will be limited to {total_servers}")
|
||||
# Adjust max concurrent processes to available servers
|
||||
self.config.max_concurrent_processes = min(
|
||||
self.config.max_concurrent_processes, total_servers
|
||||
)
|
||||
|
||||
# Initialize WandB if enabled
|
||||
if config.enable_wandb:
|
||||
self.wandb_logger = WandBSweepLogger(config.wandb_project)
|
||||
|
||||
print(f"🆔 Sweep ID: {self.sweep_id}")
|
||||
print(
|
||||
f"🖥️ Available servers: {total_servers}, Max concurrent: {self.config.max_concurrent_processes}"
|
||||
)
|
||||
|
||||
def generate_jobs(self) -> List[RunJob]:
|
||||
"""Generate all run jobs for the sweep
|
||||
|
||||
@@ -120,7 +151,13 @@ class SweepManager:
|
||||
for model, task in itertools.product(self.config.models, self.config.tasks):
|
||||
for trial in range(self.config.num_trials_per_config):
|
||||
job_id = f"{model}_{task}_trial{trial:02d}"
|
||||
job = RunJob(job_id=job_id, model=model, task=task, trial_number=trial)
|
||||
job = RunJob(
|
||||
job_id=job_id,
|
||||
model=model,
|
||||
task=task,
|
||||
trial_number=trial,
|
||||
sweep_id=self.sweep_id,
|
||||
)
|
||||
jobs.append(job)
|
||||
|
||||
# Shuffle if requested to distribute load
|
||||
@@ -214,6 +251,21 @@ class SweepManager:
|
||||
"""
|
||||
print(f"Starting job: {job.job_id} (version {job.version})")
|
||||
|
||||
# Allocate a server for this job
|
||||
server_allocation = self.server_manager.allocate_server(job.job_id)
|
||||
if not server_allocation:
|
||||
job.status = "failed"
|
||||
job.error_message = "No Factorio servers available"
|
||||
job.end_time = datetime.now()
|
||||
print(f"❌ Failed to start job {job.job_id}: No servers available")
|
||||
return
|
||||
|
||||
print(
|
||||
f"🖥️ Allocated server {server_allocation.server_id} "
|
||||
f"({server_allocation.server_address}:{server_allocation.tcp_port}) "
|
||||
f"to job {job.job_id}"
|
||||
)
|
||||
|
||||
# Create run configuration
|
||||
run_config = GymRunConfig(env_id=job.task, model=job.model, version=job.version)
|
||||
|
||||
@@ -229,10 +281,13 @@ class SweepManager:
|
||||
job.job_id,
|
||||
run_config,
|
||||
job.version,
|
||||
job.sweep_id,
|
||||
server_allocation,
|
||||
self.config.api_key_config_file,
|
||||
),
|
||||
)
|
||||
process.start()
|
||||
server_allocation.process_id = process.pid
|
||||
self.active_processes[job.job_id] = process
|
||||
|
||||
# Log to WandB if enabled
|
||||
@@ -240,19 +295,30 @@ class SweepManager:
|
||||
logger = self.wandb_logger.create_run_logger(
|
||||
job.job_id, job.model, job.task, job.version
|
||||
)
|
||||
logger.log_metrics({"job/status": "started"})
|
||||
logger.log_metrics(
|
||||
{
|
||||
"job/status": "started",
|
||||
"job/server_id": server_allocation.server_id,
|
||||
"job/server_address": server_allocation.server_address,
|
||||
"job/server_port": server_allocation.tcp_port,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Release server if process failed to start
|
||||
self.server_manager.release_server(job.job_id)
|
||||
job.status = "failed"
|
||||
job.error_message = str(e)
|
||||
job.end_time = datetime.now()
|
||||
print(f"Failed to start job {job.job_id}: {e}")
|
||||
print(f"❌ Failed to start job {job.job_id}: {e}")
|
||||
|
||||
@staticmethod
|
||||
def run_job_wrapper(
|
||||
job_id: str,
|
||||
run_config: GymRunConfig,
|
||||
version: int,
|
||||
sweep_id: str,
|
||||
server_allocation, # ServerAllocation object
|
||||
api_key_config_file: Optional[str] = None,
|
||||
):
|
||||
"""Wrapper for running a job in a subprocess
|
||||
@@ -261,16 +327,32 @@ class SweepManager:
|
||||
job_id: Unique job identifier
|
||||
run_config: GymRunConfig for this job
|
||||
version: Version number for this job
|
||||
sweep_id: Unique identifier for this sweep
|
||||
server_allocation: ServerAllocation object with server details
|
||||
api_key_config_file: Optional path to API key config file
|
||||
"""
|
||||
try:
|
||||
# Set environment variable for API key config if provided
|
||||
# Set environment variables
|
||||
if api_key_config_file:
|
||||
os.environ["FLE_API_KEY_CONFIG_FILE"] = api_key_config_file
|
||||
|
||||
# Set sweep ID environment variable for database and WandB logging
|
||||
os.environ["FLE_SWEEP_ID"] = sweep_id
|
||||
|
||||
# Set server allocation environment variables
|
||||
os.environ["FACTORIO_SERVER_ADDRESS"] = server_allocation.server_address
|
||||
os.environ["FACTORIO_SERVER_PORT"] = str(server_allocation.tcp_port)
|
||||
os.environ["FLE_SERVER_ID"] = str(server_allocation.server_id)
|
||||
|
||||
# Override PORT_OFFSET to ensure we use the allocated server
|
||||
os.environ["PORT_OFFSET"] = str(server_allocation.server_id)
|
||||
|
||||
# This would be similar to the run_process function in run_eval.py
|
||||
# but adapted for single configurations
|
||||
print(f"Executing job {job_id} with version {version}")
|
||||
print(
|
||||
f"Executing job {job_id} with version {version} (sweep: {sweep_id}) "
|
||||
f"on server {server_allocation.server_id} ({server_allocation.server_address}:{server_allocation.tcp_port})"
|
||||
)
|
||||
|
||||
# Create a temporary config file for this job
|
||||
config_data = [run_config.__dict__]
|
||||
@@ -320,10 +402,17 @@ class SweepManager:
|
||||
|
||||
job.end_time = datetime.now()
|
||||
|
||||
# Release the server allocated to this job
|
||||
server_released = self.server_manager.release_server(job_id)
|
||||
if server_released:
|
||||
print(f"🔓 Released server for completed job {job_id}")
|
||||
else:
|
||||
print(f"⚠️ No server allocation found for job {job_id}")
|
||||
|
||||
if process.exitcode == 0:
|
||||
job.status = "completed"
|
||||
self.completed_versions.append(job.version)
|
||||
print(f"Job {job_id} completed successfully")
|
||||
print(f"✅ Job {job_id} completed successfully")
|
||||
|
||||
# Log completion to WandB
|
||||
if self.wandb_logger:
|
||||
@@ -334,7 +423,7 @@ class SweepManager:
|
||||
else:
|
||||
job.status = "failed"
|
||||
job.error_message = f"Process exited with code {process.exitcode}"
|
||||
print(f"Job {job_id} failed with exit code {process.exitcode}")
|
||||
print(f"❌ Job {job_id} failed with exit code {process.exitcode}")
|
||||
|
||||
# Retry if configured and retries available
|
||||
if (
|
||||
@@ -342,13 +431,14 @@ class SweepManager:
|
||||
and job.retry_count < self.config.max_retries
|
||||
):
|
||||
print(
|
||||
f"Retrying job {job_id} (attempt {job.retry_count + 1}/{self.config.max_retries})"
|
||||
f"🔄 Retrying job {job_id} (attempt {job.retry_count + 1}/{self.config.max_retries})"
|
||||
)
|
||||
job.retry_count += 1
|
||||
job.status = "pending"
|
||||
job.start_time = None
|
||||
job.end_time = None
|
||||
job.error_message = None
|
||||
# Note: Server will be reallocated when job is retried
|
||||
|
||||
async def log_progress_if_needed(self):
|
||||
"""Log progress summary if enough time has elapsed"""
|
||||
@@ -372,6 +462,9 @@ class SweepManager:
|
||||
|
||||
elapsed_time = datetime.now() - self.start_time
|
||||
|
||||
# Get server allocation status
|
||||
server_status = self.server_manager.get_allocation_status()
|
||||
|
||||
print("\n=== Sweep Progress ===")
|
||||
print(f"Total jobs: {total_jobs}")
|
||||
print(f"Completed: {completed} ({completed / total_jobs * 100:.1f}%)")
|
||||
@@ -380,11 +473,16 @@ class SweepManager:
|
||||
print(f"Pending: {pending}")
|
||||
print(f"Elapsed time: {elapsed_time}")
|
||||
|
||||
print("\n🖥️ Server Status:")
|
||||
print(f"Total servers: {server_status['total_servers']}")
|
||||
print(f"Available: {server_status['available_servers']}")
|
||||
print(f"Allocated: {server_status['allocated_servers']}")
|
||||
|
||||
if completed > 0:
|
||||
avg_time_per_job = elapsed_time.total_seconds() / completed
|
||||
eta_seconds = avg_time_per_job * pending
|
||||
eta_hours = eta_seconds / 3600
|
||||
print(f"ETA: {eta_hours:.1f} hours")
|
||||
print(f"\n⏱️ ETA: {eta_hours:.1f} hours")
|
||||
|
||||
# Analyze recent results if we have completed jobs
|
||||
if self.completed_versions and self.database_analyzer:
|
||||
@@ -507,6 +605,7 @@ class SweepManager:
|
||||
|
||||
# Generate comprehensive report
|
||||
report = {
|
||||
"sweep_id": self.sweep_id,
|
||||
"sweep_config": self.config.__dict__,
|
||||
"execution_summary": {
|
||||
"total_jobs": len(self.jobs),
|
||||
|
Reference in New Issue
Block a user