other files

This commit is contained in:
Neel Kant
2025-09-04 13:08:04 -07:00
parent afe6ee00cd
commit 233e3d1384
5 changed files with 751 additions and 17 deletions

View File

@@ -165,6 +165,38 @@ manager = SweepManager(config)
results = await manager.run_sweep()
```
#### Resuming Failed Sweeps
The SweepManager now supports resuming failed sweeps without duplicating completed runs:
```python
# Resume an existing sweep
existing_sweep_id = "my_experiment_20241201_120000_abcd1234"
manager = SweepManager(config, existing_sweep_id=existing_sweep_id)
# Alternative: using class method
manager = SweepManager.resume_sweep(config, existing_sweep_id)
# The manager will automatically:
# - Skip completed runs
# - Retry partial/failed runs
# - Continue with remaining jobs
results = await manager.run_sweep()
```
Enhanced WandB metadata for resumed sweeps includes:
- `sweep_id`: Unique identifier for the sweep
- `is_resume`: Boolean indicating if this is a resumed sweep
- `completion_status`: Completion state of the run (one of "running", "successful", "failed_final", "will_retry")
- `retry_count`: Number of retry attempts for each job
- Tags: Include `sweep:{sweep_id}` for easy filtering
Filter in WandB using:
- `config.sweep_id = "your_sweep_id"`
- `config.completion_status = "successful"`
- `tags: sweep:your_sweep_id`
### ResultsVisualizer
Generate analysis plots:
@@ -197,6 +229,7 @@ See the `examples/` directory for complete usage examples:
- `example_sweep_config.py`: Example sweep configurations
- `analyze_sweep_results.py`: Analysis script with various commands
- `resume_sweep_example.py`: Example of resuming failed sweeps
### Running Examples

View File

@@ -132,7 +132,8 @@ class DatabaseAnalyzer:
THEN achievements_json
ELSE NULL END,
'; '
) as all_achievements
) as all_achievements,
STRING_AGG(DISTINCT meta::text, '; ') as meta_aggregated
FROM programs
WHERE version IN ({version_list})
GROUP BY version, version_description, model, instance
@@ -147,6 +148,53 @@ class DatabaseAnalyzer:
results = await self.db_client.execute_query(query)
return pd.DataFrame(results)
async def get_trajectory_summaries_by_sweep(self, sweep_id: str) -> pd.DataFrame:
    """Get trajectory-level summaries for a specific sweep.

    Selects rows of ``programs`` whose ``meta->>'sweep_id'`` equals
    ``sweep_id``, groups them by (version, version_description, model,
    instance) and aggregates per-group statistics, plus a derived
    ``duration_seconds`` column (last ``created_at`` minus first).

    Args:
        sweep_id: Sweep identifier to filter by. Passed as a bound query
            parameter, never interpolated into the SQL text.

    Returns:
        DataFrame with one row per trajectory (version + instance),
        ordered by version and instance.
    """
    await self.ensure_connection()

    # This is a plain string, not an f-string, so braces must not be doubled:
    # the empty-achievements sentinel is the literal JSON text '{}'.
    query = """
        WITH trajectory_stats AS (
            SELECT
                version,
                version_description,
                model,
                instance,
                MAX(depth) as max_depth,
                MAX(value) as final_reward,
                MAX(raw_reward) as max_raw_reward,
                COUNT(*) as num_steps,
                SUM(token_usage) as total_tokens,
                SUM(completion_token_usage) as total_completion_tokens,
                MIN(created_at) as start_time,
                MAX(created_at) as end_time,
                STRING_AGG(
                    CASE WHEN achievements_json != '{}' AND achievements_json IS NOT NULL
                    THEN achievements_json
                    ELSE NULL END,
                    '; '
                ) as all_achievements,
                STRING_AGG(DISTINCT meta::text, '; ') as meta_aggregated
            FROM programs
            WHERE meta->>'sweep_id' = %s
            GROUP BY version, version_description, model, instance
        )
        SELECT
            *,
            EXTRACT(EPOCH FROM (end_time - start_time)) as duration_seconds
        FROM trajectory_stats
        ORDER BY version, instance
    """

    results = await self.db_client.execute_query(query, (sweep_id,))
    return pd.DataFrame(results)
async def get_model_comparison(
self,
models: List[str],

View File

@@ -0,0 +1,63 @@
"""
Example demonstrating how to resume a failed sweep.
This example shows how to restart a sweep without duplicating completed runs.
"""
from fle.eval.analysis.sweep_manager import SweepManager, SweepConfig
async def resume_sweep_example():
    """Resume a previously started sweep and print its results.

    The configuration must describe the same sweep that was originally
    launched; the sweep ID comes from the original run logs or WandB.
    """
    # Same settings as the original sweep
    original_settings = {
        "name": "my_experiment",
        "models": ["gpt-4", "claude-3-sonnet"],
        "tasks": ["craft_iron_plate", "build_furnace", "mine_coal"],
        "num_trials_per_config": 8,
        "max_concurrent_processes": 4,
        "api_key_config_file": "path/to/api_keys.json",
    }
    config = SweepConfig(**original_settings)

    # ID of the sweep being resumed
    existing_sweep_id = "my_experiment_20241201_120000_abcd1234"

    # Method 1: pass the ID straight to the constructor
    sweep_manager = SweepManager(config, existing_sweep_id=existing_sweep_id)

    # Method 2: the equivalent class method
    # sweep_manager = SweepManager.resume_sweep(config, existing_sweep_id)

    print("Starting sweep resume...")
    results = await sweep_manager.run_sweep()
    print(f"Sweep completed! Results: {results}")
async def new_sweep_example():
    """Start a brand-new sweep (shown for comparison with resuming)."""
    fresh_settings = {
        "name": "my_new_experiment",
        "models": ["gpt-4", "claude-3-sonnet"],
        "tasks": ["craft_iron_plate", "build_furnace"],
        "num_trials_per_config": 4,
    }

    # No existing_sweep_id, so the manager generates a new sweep ID
    sweep_manager = SweepManager(SweepConfig(**fresh_settings))
    await sweep_manager.run_sweep()

    print(f"New sweep completed! Sweep ID: {sweep_manager.sweep_id}")
if __name__ == "__main__":
    # Imported here so the calls below work as soon as one is uncommented;
    # the module itself does not otherwise import asyncio.
    import asyncio  # noqa: F401

    # To resume a sweep:
    # asyncio.run(resume_sweep_example())

    # To start a new sweep:
    # asyncio.run(new_sweep_example())

    print("Example script - uncomment the appropriate line above to run")

View File

@@ -0,0 +1,338 @@
"""
Server management for concurrent evaluations.
Handles allocation and tracking of Factorio server instances across
multiple concurrent evaluation processes.
"""
import threading
from dataclasses import dataclass
from typing import List, Dict, Optional, Set
from datetime import datetime, timedelta
from fle.commons.cluster_ips import get_local_container_ips
@dataclass
class ServerAllocation:
    """Record of one Factorio server being held by one job."""

    server_id: int
    server_address: str
    tcp_port: int
    udp_port: int
    job_id: str
    allocated_at: datetime
    process_id: Optional[int] = None

    def to_dict(self) -> Dict:
        """Return a plain-dict view suitable for logging.

        ``allocated_at`` is rendered as an ISO-8601 string; every other field
        is copied through unchanged.
        """
        field_names = (
            "server_id",
            "server_address",
            "tcp_port",
            "udp_port",
            "job_id",
            "allocated_at",
            "process_id",
        )
        payload = {field: getattr(self, field) for field in field_names}
        payload["allocated_at"] = self.allocated_at.isoformat()
        return payload
class ServerManager:
    """Manages allocation of Factorio servers across concurrent jobs.

    Servers are discovered lazily via ``get_local_container_ips()`` and are
    identified by their index in the discovered list. Every public method
    serialises access to the bookkeeping through an internal
    ``threading.Lock``.
    """

    def __init__(self, max_allocation_time_hours: float = 2.0):
        """Initialize server manager

        Args:
            max_allocation_time_hours: Maximum time to hold a server allocation
        """
        self._lock = threading.Lock()
        self._allocations: Dict[int, "ServerAllocation"] = {}
        self._available_servers: List[int] = []
        self._allocated_servers: Set[int] = set()
        # server_id -> {"address", "tcp_port", "udp_port"}; filled by discovery
        self._server_info: Dict[int, Dict] = {}
        self.max_allocation_time = timedelta(hours=max_allocation_time_hours)
        self._initialized = False

    def _discover_servers(self) -> bool:
        """Discover available Factorio servers (called with lock held)

        State is only committed once the full server table has been built, so
        a failed or malformed discovery leaves the manager untouched.

        Returns:
            True if servers were found, False otherwise
        """
        try:
            ips, udp_ports, tcp_ports = get_local_container_ips()

            if not tcp_ports:
                print("⚠️ No Factorio containers found")
                return False

            # Build into a local first: mismatched list lengths raise here,
            # before any attribute has been modified.
            server_info = {
                i: {
                    "address": ips[i],
                    "tcp_port": tcp_ports[i],
                    "udp_port": udp_ports[i],
                }
                for i in range(len(tcp_ports))
            }
        except Exception as e:
            # Best-effort: discovery failures are reported, not raised
            print(f"❌ Error discovering servers: {e}")
            return False

        self._server_info = server_info
        self._available_servers = list(server_info)
        self._initialized = True

        print(f"🖥️ Discovered {len(server_info)} Factorio servers:")
        for server_id, info in server_info.items():
            print(f" Server {server_id}: {info['address']}:{info['tcp_port']}")

        return True

    def _ensure_discovered(self) -> bool:
        """Run discovery if it has not succeeded yet (called with lock held)"""
        return self._initialized or self._discover_servers()

    def _release_locked(self, server_id: int) -> "ServerAllocation":
        """Drop one allocation from both bookkeeping structures (lock held)"""
        allocation = self._allocations.pop(server_id)
        self._allocated_servers.discard(server_id)
        return allocation

    def initialize(self) -> bool:
        """Initialize server discovery

        Returns:
            True if initialization was successful
        """
        with self._lock:
            return self._ensure_discovered()

    def get_available_server_count(self) -> int:
        """Get number of currently available servers

        Expired allocations are reclaimed first so this count agrees with
        what ``allocate_server`` would actually be able to hand out.
        """
        with self._lock:
            self._ensure_discovered()
            self._cleanup_expired_allocations()
            return len(self._available_servers) - len(self._allocated_servers)

    def get_total_server_count(self) -> int:
        """Get total number of discovered servers"""
        with self._lock:
            self._ensure_discovered()
            return len(self._available_servers)

    def allocate_server(
        self, job_id: str, process_id: Optional[int] = None
    ) -> Optional["ServerAllocation"]:
        """Allocate a server for a job

        Args:
            job_id: Unique identifier for the job
            process_id: Optional process ID for tracking

        Returns:
            ServerAllocation if successful, None if no servers available
        """
        with self._lock:
            # Initialize if needed
            if not self._ensure_discovered():
                return None

            # Clean up expired allocations
            self._cleanup_expired_allocations()

            # Pick the lowest-numbered server that is not currently held
            server_id = next(
                (
                    sid
                    for sid in self._available_servers
                    if sid not in self._allocated_servers
                ),
                None,
            )
            if server_id is None:
                print(
                    f"⚠️ No servers available for job {job_id} (all {len(self._available_servers)} servers allocated)"
                )
                return None

            server_info = self._server_info[server_id]
            allocation = ServerAllocation(
                server_id=server_id,
                server_address=server_info["address"],
                tcp_port=server_info["tcp_port"],
                udp_port=server_info["udp_port"],
                job_id=job_id,
                allocated_at=datetime.now(),
                process_id=process_id,
            )

            self._allocations[server_id] = allocation
            self._allocated_servers.add(server_id)

            print(
                f"🖥️ Allocated server {server_id} ({allocation.server_address}:{allocation.tcp_port}) to job {job_id}"
            )
            return allocation

    def release_server(self, job_id: str) -> bool:
        """Release server allocation for a job

        Every server held under ``job_id`` is released; nothing prevents a
        job ID from being allocated more than once, and releasing only one
        would leak the rest until they expire.

        Args:
            job_id: Job identifier to release

        Returns:
            True if at least one server was found and released
        """
        with self._lock:
            held = [
                sid
                for sid, allocation in self._allocations.items()
                if allocation.job_id == job_id
            ]
            for server_id in held:
                self._release_locked(server_id)
                print(f"🔓 Released server {server_id} from job {job_id}")
            return bool(held)

    def release_server_by_id(self, server_id: int) -> bool:
        """Release server allocation by server ID

        Args:
            server_id: Server ID to release

        Returns:
            True if server was found and released
        """
        with self._lock:
            if server_id not in self._allocations:
                return False
            allocation = self._release_locked(server_id)
            print(
                f"🔓 Released server {server_id} (was allocated to {allocation.job_id})"
            )
            return True

    def _cleanup_expired_allocations(self):
        """Clean up allocations that have been held too long (called with lock held)"""
        current_time = datetime.now()
        # Collect first: the dict must not change size while it is iterated
        expired_servers = [
            server_id
            for server_id, allocation in self._allocations.items()
            if current_time - allocation.allocated_at > self.max_allocation_time
        ]

        for server_id in expired_servers:
            allocation = self._release_locked(server_id)
            print(
                f"⏰ Released expired allocation: server {server_id} (was allocated to {allocation.job_id})"
            )

    def get_allocation_status(self) -> Dict:
        """Get current allocation status

        Returns:
            Dictionary with allocation information
        """
        with self._lock:
            self._ensure_discovered()
            self._cleanup_expired_allocations()

            total = len(self._available_servers)
            allocated = len(self._allocated_servers)
            return {
                "total_servers": total,
                "allocated_servers": allocated,
                "available_servers": total - allocated,
                "allocations": [
                    allocation.to_dict() for allocation in self._allocations.values()
                ],
                "initialized": self._initialized,
            }

    def get_server_assignment_for_job(self, job_id: str) -> Optional[Dict]:
        """Get server assignment for a specific job

        Args:
            job_id: Job identifier

        Returns:
            Dictionary with server info or None if not found
        """
        with self._lock:
            for allocation in self._allocations.values():
                if allocation.job_id == job_id:
                    return {
                        "server_id": allocation.server_id,
                        "address": allocation.server_address,
                        "tcp_port": allocation.tcp_port,
                        "udp_port": allocation.udp_port,
                        "allocated_at": allocation.allocated_at.isoformat(),
                    }
            return None

    def force_release_all(self):
        """Force release all server allocations (emergency cleanup)"""
        with self._lock:
            released_count = len(self._allocations)
            self._allocations.clear()
            self._allocated_servers.clear()
            if released_count > 0:
                print(f"🧹 Force released all {released_count} server allocations")

    def print_status(self):
        """Print current server allocation status"""
        status = self.get_allocation_status()
        print("🖥️ Server Allocation Status:")
        print(f" Total servers: {status['total_servers']}")
        print(f" Available: {status['available_servers']}")
        print(f" Allocated: {status['allocated_servers']}")

        if status["allocations"]:
            print(" Current allocations:")
            for alloc in status["allocations"]:
                # [:19] trims the ISO timestamp to whole seconds
                print(
                    f" Server {alloc['server_id']}: {alloc['job_id']} "
                    f"(since {alloc['allocated_at'][:19]})"
                )
# Module-wide ServerManager, created lazily by get_server_manager()
_global_server_manager: Optional[ServerManager] = None


def get_server_manager() -> ServerManager:
    """Return the shared ServerManager, creating and initializing it on first use."""
    global _global_server_manager
    if _global_server_manager is not None:
        return _global_server_manager

    _global_server_manager = ServerManager()
    _global_server_manager.initialize()
    return _global_server_manager
if __name__ == "__main__":
    # Manual smoke test: discover, allocate two servers, release them again,
    # printing the allocation table after each stage.
    manager = ServerManager()
    manager.initialize()
    manager.print_status()

    demo_jobs = ("test_job_1", "test_job_2")
    # Keep only the jobs that actually received a server
    granted_jobs = [job for job in demo_jobs if manager.allocate_server(job)]
    manager.print_status()

    for job in granted_jobs:
        manager.release_server(job)
    manager.print_status()

View File

@@ -88,20 +88,31 @@ class RunJob:
class SweepManager:
"""Manages large-scale evaluation sweeps"""
"""Manages large-scale evaluation sweeps
def __init__(self, config: SweepConfig):
Supports both new sweeps and resuming existing ones. When resuming, the manager
will automatically skip completed runs and retry partial/failed runs.
"""
def __init__(self, config: SweepConfig, existing_sweep_id: Optional[str] = None):
"""Initialize sweep manager
Args:
config: SweepConfig with sweep parameters
existing_sweep_id: Optional sweep ID to resume an existing sweep
"""
self.config = config
self.is_resuming = existing_sweep_id is not None
if existing_sweep_id:
self.sweep_id = existing_sweep_id
print(f"🔄 Resuming sweep: {self.sweep_id}")
else:
# Generate unique sweep ID with timestamp and short UUID
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
short_uuid = str(uuid.uuid4())[:8]
self.sweep_id = f"{config.name}_{timestamp}_{short_uuid}"
print(f"🆕 Starting new sweep: {self.sweep_id}")
self.jobs: List[RunJob] = []
self.active_processes: Dict[str, multiprocessing.Process] = {}
@@ -139,18 +150,155 @@ class SweepManager:
f"🖥️ Available servers: {total_servers}, Max concurrent: {self.config.max_concurrent_processes}"
)
def generate_jobs(self) -> List[RunJob]:
"""Generate all run jobs for the sweep
# If resuming, we'll load existing state in run_sweep()
@classmethod
def resume_sweep(cls, config: SweepConfig, sweep_id: str) -> "SweepManager":
    """Alternate constructor for resuming an existing sweep.

    Equivalent to calling the constructor with ``existing_sweep_id=sweep_id``.

    Args:
        config: SweepConfig with sweep parameters (should match original config)
        sweep_id: ID of the existing sweep to resume

    Returns:
        SweepManager instance configured to resume the sweep
    """
    resumed_manager = cls(config, existing_sweep_id=sweep_id)
    return resumed_manager
async def load_existing_sweep_state(self) -> Dict[str, Any]:
    """Load existing sweep state from database

    Queries the trajectory summaries recorded for ``self.sweep_id`` and
    classifies each (model, task, trial) as completed or partial. The split
    is a heuristic: a trajectory with more than 5 recorded steps counts as
    completed.

    Database errors are reported and treated as "no existing state" rather
    than raised.

    Returns:
        Dictionary with:
            "completed_runs": {(model, task, trial): version}
            "partial_runs": {(model, task, trial): {"version", "num_steps",
                "final_reward"}}
            "total_existing": number of trajectory rows found (0 when not
                resuming, nothing was found, or loading failed)
    """

    def _empty_state() -> Dict[str, Any]:
        # Same shape as the success path so callers can rely on every key
        return {"completed_runs": {}, "partial_runs": {}, "total_existing": 0}

    if not self.is_resuming:
        return _empty_state()

    print(f"🔍 Loading existing state for sweep: {self.sweep_id}")

    # Initialize database analyzer if not already done
    if not self.database_analyzer:
        self.database_analyzer = DatabaseAnalyzer()

    # Heuristic threshold - might need adjustment based on typical run patterns
    min_steps_for_complete = 5

    try:
        # Get existing trajectories for this sweep
        existing_df = await self.database_analyzer.get_trajectory_summaries_by_sweep(
            self.sweep_id
        )

        if existing_df.empty:
            print(f"⚠️ No existing data found for sweep {self.sweep_id}")
            return _empty_state()

        print(f"📊 Found {len(existing_df)} existing trajectories")

        completed_runs = {}  # key: (model, task, trial), value: version
        partial_runs = {}  # key: (model, task, trial), value: run details

        for _, row in existing_df.iterrows():
            model = row["model"]
            version_desc = row["version_description"]
            version = row["version"]

            # instance is the trial number. generate_jobs keys on plain ints,
            # so normalise where possible and keep the raw value otherwise.
            trial = row["instance"]
            try:
                trial = int(trial)
            except (TypeError, ValueError):
                pass

            # Extract task from version_description, assumed to contain
            # "type:task_name". Non-string values (e.g. a missing description
            # surfaced as NaN) fall through to "unknown_task" instead of
            # raising and discarding all loaded state.
            task = "unknown_task"
            if isinstance(version_desc, str) and version_desc:
                if "type:" in version_desc:
                    task = version_desc.split("type:")[1].split("\n")[0].strip()
                else:
                    # Fallback: use first line of version_description
                    task = version_desc.split("\n")[0].strip()

            run_key = (model, task, trial)

            num_steps = row.get("num_steps", 0) or 0
            final_reward = row.get("final_reward", 0)

            if num_steps > min_steps_for_complete:
                completed_runs[run_key] = version
            else:
                partial_runs[run_key] = {
                    "version": version,
                    "num_steps": num_steps,
                    "final_reward": final_reward,
                }

        print(f"✅ Found {len(completed_runs)} completed runs")
        print(f"⚠️ Found {len(partial_runs)} partial/incomplete runs")

        # Log some examples
        if completed_runs:
            print("Examples of completed runs:")
            for model, task, trial in list(completed_runs)[:3]:
                print(f" - {model} on {task} trial {trial}")

        if partial_runs:
            print("Examples of partial runs (will be retried):")
            for (model, task, trial), info in list(partial_runs.items())[:3]:
                print(
                    f" - {model} on {task} trial {trial} ({info['num_steps']} steps)"
                )

        return {
            "completed_runs": completed_runs,
            "partial_runs": partial_runs,
            "total_existing": len(existing_df),
        }

    except Exception as e:
        # Best-effort: a failed load is reported and treated as empty state
        print(f"❌ Error loading existing sweep state: {e}")
        return _empty_state()
def generate_jobs(
self, existing_state: Optional[Dict[str, Any]] = None
) -> List[RunJob]:
"""Generate all run jobs for the sweep, excluding completed ones if resuming
Args:
existing_state: Optional dictionary with completed and partial run information
Returns:
List of RunJob objects to execute
"""
jobs = []
completed_runs = (
existing_state.get("completed_runs", {}) if existing_state else {}
)
partial_runs = existing_state.get("partial_runs", {}) if existing_state else {}
skipped_count = 0
retry_count = 0
# Generate all combinations of models and tasks
for model, task in itertools.product(self.config.models, self.config.tasks):
for trial in range(self.config.num_trials_per_config):
run_key = (model, task, trial)
job_id = f"{model}_{task}_trial{trial:02d}"
# Skip if already completed
if run_key in completed_runs:
skipped_count += 1
# Still track these for completed_versions list
existing_version = completed_runs[run_key]
if existing_version not in self.completed_versions:
self.completed_versions.append(existing_version)
continue
# Create job (will run for new jobs and retry for partial jobs)
job = RunJob(
job_id=job_id,
model=model,
@@ -158,6 +306,14 @@ class SweepManager:
trial_number=trial,
sweep_id=self.sweep_id,
)
# If this is a partial run, mark it for retry
if run_key in partial_runs:
retry_count += 1
job.retry_count = 1 # Mark that this is a retry
job.status = "pending" # Will be retried
print(f"🔄 Will retry partial run: {job_id}")
jobs.append(job)
# Shuffle if requested to distribute load
@@ -165,6 +321,20 @@ class SweepManager:
random.shuffle(jobs)
self.jobs = jobs
# Log summary
total_expected = (
len(self.config.models)
* len(self.config.tasks)
* self.config.num_trials_per_config
)
print("📋 Job generation summary:")
print(f" Total expected jobs: {total_expected}")
print(f" Already completed: {skipped_count}")
print(f" Will retry (partial): {retry_count}")
print(f" New jobs to run: {len(jobs) - retry_count}")
print(f" Total jobs to execute: {len(jobs)}")
return jobs
async def run_sweep(self) -> Dict[str, Any]:
@@ -180,10 +350,14 @@ class SweepManager:
self.start_time = datetime.now()
# Generate all jobs
self.generate_jobs()
# Load existing state if resuming
existing_state = await self.load_existing_sweep_state()
# Initialize database analyzer for monitoring
# Generate all jobs (excluding completed ones if resuming)
self.generate_jobs(existing_state)
# Initialize database analyzer for monitoring (if not already initialized)
if not self.database_analyzer:
self.database_analyzer = DatabaseAnalyzer()
# Get starting version numbers
@@ -293,7 +467,17 @@ class SweepManager:
# Log to WandB if enabled
if self.wandb_logger:
logger = self.wandb_logger.create_run_logger(
job.job_id, job.model, job.task, job.version
job.job_id,
job.model,
job.task,
job.version,
config={
"sweep_id": self.sweep_id,
"is_resume": self.is_resuming,
"retry_count": job.retry_count,
"trial_number": job.trial_number,
},
tags=[f"sweep:{self.sweep_id}"],
)
logger.log_metrics(
{
@@ -301,6 +485,10 @@ class SweepManager:
"job/server_id": server_allocation.server_id,
"job/server_address": server_allocation.server_address,
"job/server_port": server_allocation.tcp_port,
"job/sweep_id": self.sweep_id,
"job/is_resume": self.is_resuming,
"job/retry_count": job.retry_count,
"job/completion_status": "running",
}
)
@@ -418,13 +606,52 @@ class SweepManager:
if self.wandb_logger:
logger = self.wandb_logger.get_logger(job_id)
if logger:
logger.log_metrics({"job/status": "completed"})
logger.log_metrics(
{
"job/status": "completed",
"job/completion_status": "successful",
"job/exit_code": process.exitcode,
"job/duration_minutes": (
job.end_time - job.start_time
).total_seconds()
/ 60
if job.start_time
else 0,
"job/final_retry_count": job.retry_count,
}
)
else:
job.status = "failed"
job.error_message = f"Process exited with code {process.exitcode}"
print(f"❌ Job {job_id} failed with exit code {process.exitcode}")
# Log failure to WandB
if self.wandb_logger:
logger = self.wandb_logger.get_logger(job_id)
if logger:
will_retry = (
self.config.retry_failed_runs
and job.retry_count < self.config.max_retries
)
logger.log_metrics(
{
"job/status": "failed",
"job/completion_status": "will_retry"
if will_retry
else "failed_final",
"job/exit_code": process.exitcode,
"job/error_message": job.error_message,
"job/duration_minutes": (
job.end_time - job.start_time
).total_seconds()
/ 60
if job.start_time
else 0,
"job/retry_count_at_failure": job.retry_count,
}
)
# Retry if configured and retries available
if (
self.config.retry_failed_runs
@@ -635,7 +862,13 @@ class SweepManager:
"all_models",
"final_summary",
0,
config=report["sweep_config"],
config={
**report["sweep_config"],
"sweep_id": self.sweep_id,
"is_resume": self.is_resuming,
"sweep_completion_status": "completed",
},
tags=[f"sweep:{self.sweep_id}", "sweep_summary"],
)
summary_logger.log_sweep_summary(
@@ -649,6 +882,25 @@ class SweepManager:
# Log model comparison table
summary_logger.log_model_comparison_table(results_by_model)
# Log sweep completion metrics
summary_logger.log_metrics(
{
"sweep/completion_status": "completed",
"sweep/is_resume": self.is_resuming,
"sweep/total_completed_jobs": report["execution_summary"][
"completed_jobs"
],
"sweep/total_failed_jobs": report["execution_summary"][
"failed_jobs"
],
"sweep/success_rate": report["execution_summary"][
"completed_jobs"
]
/ max(report["execution_summary"]["total_jobs"], 1),
"sweep/final_timestamp": datetime.now().timestamp(),
}
)
# Save report to file if output directory specified
if self.config.output_dir:
report_path = (