mirror of https://github.com/JackHopkins/factorio-learning-environment.git
synced 2025-09-06 13:23:58 +00:00
@@ -165,6 +165,38 @@ manager = SweepManager(config)

results = await manager.run_sweep()
```

#### Resuming Failed Sweeps

The SweepManager now supports resuming failed sweeps without duplicating completed runs:

```python
# Resume an existing sweep
existing_sweep_id = "my_experiment_20241201_120000_abcd1234"
manager = SweepManager(config, existing_sweep_id=existing_sweep_id)

# Alternative: using the class method
manager = SweepManager.resume_sweep(config, existing_sweep_id)

# The manager will automatically:
# - Skip completed runs
# - Retry partial/failed runs
# - Continue with remaining jobs
results = await manager.run_sweep()
```

Enhanced WandB metadata for resumed sweeps includes:

- `sweep_id`: Unique identifier for the sweep
- `is_resume`: Boolean indicating whether this is a resumed sweep
- `completion_status`: Tracks run completion ("running", "successful", "failed_final", "will_retry")
- `retry_count`: Number of retry attempts for each job
- Tags: include `sweep:{sweep_id}` for easy filtering

Filter in WandB using:

- `config.sweep_id = "your_sweep_id"`
- `config.completion_status = "successful"`
- `tags: sweep:your_sweep_id`
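
As an illustration (not part of this repo), a minimal sketch of pulling a sweep's runs through the WandB public API; the entity/project path is a placeholder:

```python
import wandb

api = wandb.Api()

# Filter runs by the sweep_id stored in each run's config.
runs = api.runs(
    "my-entity/my-project",  # placeholder path, substitute your own
    filters={"config.sweep_id": "my_experiment_20241201_120000_abcd1234"},
)

# completion_status is also logged as a job/* metric, so it can be
# read back from each run's summary.
successful = [
    run for run in runs
    if run.summary.get("job/completion_status") == "successful"
]
print(f"{len(successful)} successful runs in this sweep")
```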

### ResultsVisualizer

Generate analysis plots:

@@ -197,6 +229,7 @@ See the `examples/` directory for complete usage examples:

- `example_sweep_config.py`: Example sweep configurations
- `analyze_sweep_results.py`: Analysis script with various commands
- `resume_sweep_example.py`: Example of resuming failed sweeps

### Running Examples

@@ -132,7 +132,8 @@ class DatabaseAnalyzer:
                     THEN achievements_json
                     ELSE NULL END,
                 '; '
-            ) as all_achievements
+            ) as all_achievements,
+            STRING_AGG(DISTINCT meta::text, '; ') as meta_aggregated
         FROM programs
         WHERE version IN ({version_list})
         GROUP BY version, version_description, model, instance
@@ -147,6 +148,53 @@ class DatabaseAnalyzer:
         results = await self.db_client.execute_query(query)
         return pd.DataFrame(results)
 
+    async def get_trajectory_summaries_by_sweep(self, sweep_id: str) -> pd.DataFrame:
+        """Get trajectory-level summaries for a specific sweep
+
+        Args:
+            sweep_id: Sweep identifier to filter by
+
+        Returns:
+            DataFrame with one row per trajectory (version + instance)
+        """
+        await self.ensure_connection()
+
+        query = """
+            WITH trajectory_stats AS (
+                SELECT
+                    version,
+                    version_description,
+                    model,
+                    instance,
+                    MAX(depth) as max_depth,
+                    MAX(value) as final_reward,
+                    MAX(raw_reward) as max_raw_reward,
+                    COUNT(*) as num_steps,
+                    SUM(token_usage) as total_tokens,
+                    SUM(completion_token_usage) as total_completion_tokens,
+                    MIN(created_at) as start_time,
+                    MAX(created_at) as end_time,
+                    STRING_AGG(
+                        CASE WHEN achievements_json != '{{}}' AND achievements_json IS NOT NULL
+                            THEN achievements_json
+                            ELSE NULL END,
+                        '; '
+                    ) as all_achievements,
+                    STRING_AGG(DISTINCT meta::text, '; ') as meta_aggregated
+                FROM programs
+                WHERE meta->>'sweep_id' = %s
+                GROUP BY version, version_description, model, instance
+            )
+            SELECT
+                *,
+                EXTRACT(EPOCH FROM (end_time - start_time)) as duration_seconds
+            FROM trajectory_stats
+            ORDER BY version, instance
+        """
+
+        results = await self.db_client.execute_query(query, (sweep_id,))
+        return pd.DataFrame(results)
 
     async def get_model_comparison(
         self,
         models: List[str],
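
A minimal usage sketch for the new query method, assuming an async context and a configured database client; the import path is assumed (the diff does not show the module's filename), and the sweep ID is the placeholder from the README example:

```python
import asyncio

# Import path assumed, not shown in the diff.
from fle.eval.analysis.database_analyzer import DatabaseAnalyzer


async def main() -> None:
    analyzer = DatabaseAnalyzer()
    df = await analyzer.get_trajectory_summaries_by_sweep(
        "my_experiment_20241201_120000_abcd1234"
    )
    # One row per trajectory (version + instance), per the docstring above.
    print(df[["model", "instance", "num_steps", "final_reward"]].head())


asyncio.run(main())
```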
fle/eval/analysis/examples/resume_sweep_example.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""
Example demonstrating how to resume a failed sweep.

This example shows how to restart a sweep without duplicating completed runs.
"""

import asyncio  # needed for the asyncio.run(...) calls below

from fle.eval.analysis.sweep_manager import SweepManager, SweepConfig


async def resume_sweep_example():
    """Example of resuming a failed sweep"""

    # Create the same config as the original sweep
    config = SweepConfig(
        name="my_experiment",
        models=["gpt-4", "claude-3-sonnet"],
        tasks=["craft_iron_plate", "build_furnace", "mine_coal"],
        num_trials_per_config=8,
        max_concurrent_processes=4,
        api_key_config_file="path/to/api_keys.json",
    )

    # Resume sweep with existing sweep ID.
    # You can get the sweep ID from the original run logs or WandB.
    existing_sweep_id = "my_experiment_20241201_120000_abcd1234"

    # Method 1: Using the constructor
    sweep_manager = SweepManager(config, existing_sweep_id=existing_sweep_id)

    # Method 2: Using the class method (equivalent)
    # sweep_manager = SweepManager.resume_sweep(config, existing_sweep_id)

    print("Starting sweep resume...")
    results = await sweep_manager.run_sweep()

    print(f"Sweep completed! Results: {results}")


async def new_sweep_example():
    """Example of starting a new sweep (for comparison)"""

    config = SweepConfig(
        name="my_new_experiment",
        models=["gpt-4", "claude-3-sonnet"],
        tasks=["craft_iron_plate", "build_furnace"],
        num_trials_per_config=4,
    )

    # Create new sweep (no existing_sweep_id)
    sweep_manager = SweepManager(config)

    await sweep_manager.run_sweep()
    print(f"New sweep completed! Sweep ID: {sweep_manager.sweep_id}")


if __name__ == "__main__":
    # To resume a sweep:
    # asyncio.run(resume_sweep_example())

    # To start a new sweep:
    # asyncio.run(new_sweep_example())

    print("Example script - uncomment the appropriate line above to run")

fle/eval/analysis/server_manager.py (new file, 338 lines)
@@ -0,0 +1,338 @@
"""
Server management for concurrent evaluations.

Handles allocation and tracking of Factorio server instances across
multiple concurrent evaluation processes.
"""

import threading
from dataclasses import dataclass
from typing import List, Dict, Optional, Set
from datetime import datetime, timedelta

from fle.commons.cluster_ips import get_local_container_ips


@dataclass
class ServerAllocation:
    """Tracks allocation of a Factorio server to a job"""

    server_id: int
    server_address: str
    tcp_port: int
    udp_port: int
    job_id: str
    allocated_at: datetime
    process_id: Optional[int] = None

    def to_dict(self) -> Dict:
        """Convert to dictionary for logging"""
        return {
            "server_id": self.server_id,
            "server_address": self.server_address,
            "tcp_port": self.tcp_port,
            "udp_port": self.udp_port,
            "job_id": self.job_id,
            "allocated_at": self.allocated_at.isoformat(),
            "process_id": self.process_id,
        }


class ServerManager:
    """Manages allocation of Factorio servers across concurrent jobs"""

    def __init__(self, max_allocation_time_hours: float = 2.0):
        """Initialize server manager

        Args:
            max_allocation_time_hours: Maximum time to hold a server allocation
        """
        self._lock = threading.Lock()
        self._allocations: Dict[int, ServerAllocation] = {}
        self._available_servers: List[int] = []
        self._allocated_servers: Set[int] = set()
        self._server_info: Dict[int, Dict] = {}  # populated by _discover_servers()
        self.max_allocation_time = timedelta(hours=max_allocation_time_hours)
        self._initialized = False

    def _discover_servers(self) -> bool:
        """Discover available Factorio servers

        Returns:
            True if servers were found, False otherwise
        """
        try:
            ips, udp_ports, tcp_ports = get_local_container_ips()

            if not tcp_ports:
                print("⚠️ No Factorio containers found")
                return False

            self._available_servers = list(range(len(tcp_ports)))
            self._server_info = {
                i: {
                    "address": ips[i],
                    "tcp_port": tcp_ports[i],
                    "udp_port": udp_ports[i],
                }
                for i in range(len(tcp_ports))
            }

            print(f"🖥️ Discovered {len(tcp_ports)} Factorio servers:")
            for i, (ip, tcp_port) in enumerate(zip(ips, tcp_ports)):
                print(f"   Server {i}: {ip}:{tcp_port}")

            self._initialized = True
            return True

        except Exception as e:
            print(f"❌ Error discovering servers: {e}")
            return False

    def initialize(self) -> bool:
        """Initialize server discovery

        Returns:
            True if initialization was successful
        """
        with self._lock:
            if not self._initialized:
                return self._discover_servers()
            return True

    def get_available_server_count(self) -> int:
        """Get number of currently available servers"""
        with self._lock:
            if not self._initialized:
                self._discover_servers()
            return len(self._available_servers) - len(self._allocated_servers)

    def get_total_server_count(self) -> int:
        """Get total number of discovered servers"""
        with self._lock:
            if not self._initialized:
                self._discover_servers()
            return len(self._available_servers)

    def allocate_server(
        self, job_id: str, process_id: Optional[int] = None
    ) -> Optional[ServerAllocation]:
        """Allocate a server for a job

        Args:
            job_id: Unique identifier for the job
            process_id: Optional process ID for tracking

        Returns:
            ServerAllocation if successful, None if no servers available
        """
        with self._lock:
            # Initialize if needed
            if not self._initialized:
                if not self._discover_servers():
                    return None

            # Clean up expired allocations
            self._cleanup_expired_allocations()

            # Find available server
            available_servers = [
                server_id
                for server_id in self._available_servers
                if server_id not in self._allocated_servers
            ]

            if not available_servers:
                print(
                    f"⚠️ No servers available for job {job_id} (all {len(self._available_servers)} servers allocated)"
                )
                return None

            # Allocate first available server
            server_id = available_servers[0]
            server_info = self._server_info[server_id]

            allocation = ServerAllocation(
                server_id=server_id,
                server_address=server_info["address"],
                tcp_port=server_info["tcp_port"],
                udp_port=server_info["udp_port"],
                job_id=job_id,
                allocated_at=datetime.now(),
                process_id=process_id,
            )

            self._allocations[server_id] = allocation
            self._allocated_servers.add(server_id)

            print(
                f"🖥️ Allocated server {server_id} ({allocation.server_address}:{allocation.tcp_port}) to job {job_id}"
            )

            return allocation

    def release_server(self, job_id: str) -> bool:
        """Release server allocation for a job

        Args:
            job_id: Job identifier to release

        Returns:
            True if server was found and released
        """
        with self._lock:
            # Find allocation by job_id
            server_id = None
            for sid, allocation in self._allocations.items():
                if allocation.job_id == job_id:
                    server_id = sid
                    break

            if server_id is not None:
                allocation = self._allocations.pop(server_id)
                self._allocated_servers.remove(server_id)
                print(f"🔓 Released server {server_id} from job {job_id}")
                return True

            return False

    def release_server_by_id(self, server_id: int) -> bool:
        """Release server allocation by server ID

        Args:
            server_id: Server ID to release

        Returns:
            True if server was found and released
        """
        with self._lock:
            if server_id in self._allocations:
                allocation = self._allocations.pop(server_id)
                self._allocated_servers.remove(server_id)
                print(
                    f"🔓 Released server {server_id} (was allocated to {allocation.job_id})"
                )
                return True

            return False

    def _cleanup_expired_allocations(self):
        """Clean up allocations that have been held too long (called with lock held)"""
        current_time = datetime.now()
        expired_servers = []

        for server_id, allocation in self._allocations.items():
            if current_time - allocation.allocated_at > self.max_allocation_time:
                expired_servers.append(server_id)

        for server_id in expired_servers:
            allocation = self._allocations.pop(server_id)
            self._allocated_servers.remove(server_id)
            print(
                f"⏰ Released expired allocation: server {server_id} (was allocated to {allocation.job_id})"
            )

    def get_allocation_status(self) -> Dict:
        """Get current allocation status

        Returns:
            Dictionary with allocation information
        """
        with self._lock:
            if not self._initialized:
                self._discover_servers()

            self._cleanup_expired_allocations()

            return {
                "total_servers": len(self._available_servers),
                "allocated_servers": len(self._allocated_servers),
                "available_servers": len(self._available_servers)
                - len(self._allocated_servers),
                "allocations": [
                    allocation.to_dict() for allocation in self._allocations.values()
                ],
                "initialized": self._initialized,
            }

    def get_server_assignment_for_job(self, job_id: str) -> Optional[Dict]:
        """Get server assignment for a specific job

        Args:
            job_id: Job identifier

        Returns:
            Dictionary with server info or None if not found
        """
        with self._lock:
            for allocation in self._allocations.values():
                if allocation.job_id == job_id:
                    return {
                        "server_id": allocation.server_id,
                        "address": allocation.server_address,
                        "tcp_port": allocation.tcp_port,
                        "udp_port": allocation.udp_port,
                        "allocated_at": allocation.allocated_at.isoformat(),
                    }
            return None

    def force_release_all(self):
        """Force release all server allocations (emergency cleanup)"""
        with self._lock:
            released_count = len(self._allocations)
            self._allocations.clear()
            self._allocated_servers.clear()

            if released_count > 0:
                print(f"🧹 Force released all {released_count} server allocations")

    def print_status(self):
        """Print current server allocation status"""
        status = self.get_allocation_status()

        print("🖥️ Server Allocation Status:")
        print(f"   Total servers: {status['total_servers']}")
        print(f"   Available: {status['available_servers']}")
        print(f"   Allocated: {status['allocated_servers']}")

        if status["allocations"]:
            print("   Current allocations:")
            for alloc in status["allocations"]:
                print(
                    f"     Server {alloc['server_id']}: {alloc['job_id']} "
                    f"(since {alloc['allocated_at'][:19]})"
                )


# Global server manager instance
_global_server_manager: Optional[ServerManager] = None


def get_server_manager() -> ServerManager:
    """Get or create global server manager instance"""
    global _global_server_manager

    if _global_server_manager is None:
        _global_server_manager = ServerManager()
        _global_server_manager.initialize()

    return _global_server_manager


if __name__ == "__main__":
    # Test server manager
    manager = ServerManager()
    manager.initialize()
    manager.print_status()

    # Test allocation
    alloc1 = manager.allocate_server("test_job_1")
    alloc2 = manager.allocate_server("test_job_2")

    manager.print_status()

    if alloc1:
        manager.release_server("test_job_1")
    if alloc2:
        manager.release_server("test_job_2")

    manager.print_status()
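
For context, a minimal sketch (not part of the commit) of how a job runner might pair allocation with release using the API above; `run_evaluation` is a hypothetical stand-in for the real evaluation entry point:

```python
from fle.eval.analysis.server_manager import get_server_manager


def run_evaluation(address: str, tcp_port: int) -> None:
    """Hypothetical stand-in for the real evaluation entry point."""
    print(f"Evaluating against {address}:{tcp_port}")


def run_job_with_server(job_id: str) -> bool:
    manager = get_server_manager()
    allocation = manager.allocate_server(job_id)
    if allocation is None:
        return False  # No server free; the caller can queue and retry later.
    try:
        run_evaluation(allocation.server_address, allocation.tcp_port)
        return True
    finally:
        # Always release, even if the evaluation raises, so the server does
        # not stay held until the expiry cleanup kicks in.
        manager.release_server(job_id)
```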
@@ -88,20 +88,31 @@ class RunJob:
 
 
 class SweepManager:
-    """Manages large-scale evaluation sweeps"""
+    """Manages large-scale evaluation sweeps
+
+    Supports both new sweeps and resuming existing ones. When resuming, the manager
+    will automatically skip completed runs and retry partial/failed runs.
+    """
 
-    def __init__(self, config: SweepConfig):
+    def __init__(self, config: SweepConfig, existing_sweep_id: Optional[str] = None):
+        """Initialize sweep manager
+
+        Args:
+            config: SweepConfig with sweep parameters
+            existing_sweep_id: Optional sweep ID to resume an existing sweep
+        """
         self.config = config
+        self.is_resuming = existing_sweep_id is not None
 
-        # Generate unique sweep ID with timestamp and short UUID
-        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-        short_uuid = str(uuid.uuid4())[:8]
-        self.sweep_id = f"{config.name}_{timestamp}_{short_uuid}"
+        if existing_sweep_id:
+            self.sweep_id = existing_sweep_id
+            print(f"🔄 Resuming sweep: {self.sweep_id}")
+        else:
+            # Generate unique sweep ID with timestamp and short UUID
+            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+            short_uuid = str(uuid.uuid4())[:8]
+            self.sweep_id = f"{config.name}_{timestamp}_{short_uuid}"
+            print(f"🆕 Starting new sweep: {self.sweep_id}")
 
         self.jobs: List[RunJob] = []
         self.active_processes: Dict[str, multiprocessing.Process] = {}
@@ -139,18 +150,155 @@ class SweepManager:
             f"🖥️ Available servers: {total_servers}, Max concurrent: {self.config.max_concurrent_processes}"
         )
 
-    def generate_jobs(self) -> List[RunJob]:
-        """Generate all run jobs for the sweep
-
-        Returns:
-            List of RunJob objects
-        """
+        # If resuming, we'll load existing state in run_sweep()
+
+    @classmethod
+    def resume_sweep(cls, config: SweepConfig, sweep_id: str) -> "SweepManager":
+        """Create a SweepManager to resume an existing sweep
+
+        Args:
+            config: SweepConfig with sweep parameters (should match original config)
+            sweep_id: ID of the existing sweep to resume
+
+        Returns:
+            SweepManager instance configured to resume the sweep
+        """
+        return cls(config, existing_sweep_id=sweep_id)
+
+    async def load_existing_sweep_state(self) -> Dict[str, Any]:
+        """Load existing sweep state from database
+
+        Returns:
+            Dictionary with information about completed runs
+        """
+        if not self.is_resuming:
+            return {"completed_runs": {}, "partial_runs": {}}
+
+        print(f"🔍 Loading existing state for sweep: {self.sweep_id}")
+
+        # Initialize database analyzer if not already done
+        if not self.database_analyzer:
+            self.database_analyzer = DatabaseAnalyzer()
+
+        try:
+            # Get existing trajectories for this sweep
+            existing_df = (
+                await self.database_analyzer.get_trajectory_summaries_by_sweep(
+                    self.sweep_id
+                )
+            )
+
+            if existing_df.empty:
+                print(f"⚠️ No existing data found for sweep {self.sweep_id}")
+                return {"completed_runs": {}, "partial_runs": {}}
+
+            print(f"📊 Found {len(existing_df)} existing trajectories")
+
+            completed_runs = {}  # key: (model, task, trial), value: version
+            partial_runs = {}  # key: (model, task, trial), value: {version, status}
+
+            # Group by version to identify complete vs partial runs
+            for _, row in existing_df.iterrows():
+                model = row["model"]
+                version_desc = row["version_description"]
+                version = row["version"]
+                instance = row["instance"]  # This is the trial number
+
+                # Extract task from version_description,
+                # assuming version_description contains "type:task_name"
+                task = "unknown_task"
+                if version_desc and "type:" in version_desc:
+                    task = version_desc.split("type:")[1].split("\n")[0].strip()
+                elif version_desc:
+                    # Fallback: use first line or part of version_description
+                    task = version_desc.split("\n")[0].strip()
+
+                # Key for identifying unique run
+                run_key = (model, task, instance)
+
+                # Check if this is a complete trajectory.
+                # A trajectory is considered complete if it has steps and ended properly.
+                num_steps = row.get("num_steps", 0)
+                final_reward = row.get("final_reward", 0)
+
+                # Consider a run complete if it has a reasonable number of steps.
+                # This is a heuristic - might need adjustment based on typical run patterns.
+                is_complete = num_steps > 5  # Adjust this threshold as needed
+
+                if is_complete:
+                    completed_runs[run_key] = version
+                else:
+                    partial_runs[run_key] = {
+                        "version": version,
+                        "num_steps": num_steps,
+                        "final_reward": final_reward,
+                    }
+
+            print(f"✅ Found {len(completed_runs)} completed runs")
+            print(f"⚠️ Found {len(partial_runs)} partial/incomplete runs")
+
+            # Log some examples
+            if completed_runs:
+                print("Examples of completed runs:")
+                for i, (model, task, trial) in enumerate(
+                    list(completed_runs.keys())[:3]
+                ):
+                    print(f"   - {model} on {task} trial {trial}")
+
+            if partial_runs:
+                print("Examples of partial runs (will be retried):")
+                for i, (model, task, trial) in enumerate(list(partial_runs.keys())[:3]):
+                    info = partial_runs[(model, task, trial)]
+                    print(
+                        f"   - {model} on {task} trial {trial} ({info['num_steps']} steps)"
+                    )
+
+            return {
+                "completed_runs": completed_runs,
+                "partial_runs": partial_runs,
+                "total_existing": len(existing_df),
+            }
+
+        except Exception as e:
+            print(f"❌ Error loading existing sweep state: {e}")
+            return {"completed_runs": {}, "partial_runs": {}}
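
For orientation, an illustrative example (values made up) of the state dictionary returned by `load_existing_sweep_state` above; run keys are `(model, task, trial)` tuples:

```python
{
    "completed_runs": {("gpt-4", "craft_iron_plate", 0): 412},
    "partial_runs": {
        ("gpt-4", "mine_coal", 1): {
            "version": 415,
            "num_steps": 3,
            "final_reward": 0.0,
        }
    },
    "total_existing": 2,
}
```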
+    def generate_jobs(
+        self, existing_state: Optional[Dict[str, Any]] = None
+    ) -> List[RunJob]:
+        """Generate all run jobs for the sweep, excluding completed ones if resuming
+
+        Args:
+            existing_state: Optional dictionary with completed and partial run information
+
+        Returns:
+            List of RunJob objects to execute
+        """
+        jobs = []
+        completed_runs = (
+            existing_state.get("completed_runs", {}) if existing_state else {}
+        )
+        partial_runs = existing_state.get("partial_runs", {}) if existing_state else {}
+
+        skipped_count = 0
+        retry_count = 0
+
+        # Generate all combinations of models and tasks
+        for model, task in itertools.product(self.config.models, self.config.tasks):
+            for trial in range(self.config.num_trials_per_config):
+                run_key = (model, task, trial)
+                job_id = f"{model}_{task}_trial{trial:02d}"
+
+                # Skip if already completed
+                if run_key in completed_runs:
+                    skipped_count += 1
+                    # Still track these for completed_versions list
+                    existing_version = completed_runs[run_key]
+                    if existing_version not in self.completed_versions:
+                        self.completed_versions.append(existing_version)
+                    continue
+
+                # Create job (will run for new jobs and retry for partial jobs)
+                job = RunJob(
+                    job_id=job_id,
+                    model=model,
@@ -158,6 +306,14 @@ class SweepManager:
                     trial_number=trial,
                     sweep_id=self.sweep_id,
                 )
 
+                # If this is a partial run, mark it for retry
+                if run_key in partial_runs:
+                    retry_count += 1
+                    job.retry_count = 1  # Mark that this is a retry
+                    job.status = "pending"  # Will be retried
+                    print(f"🔄 Will retry partial run: {job_id}")
+
                 jobs.append(job)
 
         # Shuffle if requested to distribute load
@@ -165,6 +321,20 @@ class SweepManager:
             random.shuffle(jobs)
 
         self.jobs = jobs
 
+        # Log summary
+        total_expected = (
+            len(self.config.models)
+            * len(self.config.tasks)
+            * self.config.num_trials_per_config
+        )
+        print("📋 Job generation summary:")
+        print(f"   Total expected jobs: {total_expected}")
+        print(f"   Already completed: {skipped_count}")
+        print(f"   Will retry (partial): {retry_count}")
+        print(f"   New jobs to run: {len(jobs) - retry_count}")
+        print(f"   Total jobs to execute: {len(jobs)}")
+
         return jobs
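
For instance, with the config from the README example (2 models × 3 tasks × 8 trials per config), `total_expected` is 48; if 30 of those runs had already completed and 5 were partial, the summary would report 13 new jobs, 5 retries, and 18 jobs to execute.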
     async def run_sweep(self) -> Dict[str, Any]:
@@ -180,11 +350,15 @@ class SweepManager:
 
         self.start_time = datetime.now()
 
-        # Generate all jobs
-        self.generate_jobs()
-
-        # Initialize database analyzer for monitoring
-        self.database_analyzer = DatabaseAnalyzer()
+        # Load existing state if resuming
+        existing_state = await self.load_existing_sweep_state()
+
+        # Generate all jobs (excluding completed ones if resuming)
+        self.generate_jobs(existing_state)
+
+        # Initialize database analyzer for monitoring (if not already initialized)
+        if not self.database_analyzer:
+            self.database_analyzer = DatabaseAnalyzer()
 
         # Get starting version numbers
         base_version = await get_next_version()
@@ -293,7 +467,17 @@
             # Log to WandB if enabled
             if self.wandb_logger:
                 logger = self.wandb_logger.create_run_logger(
-                    job.job_id, job.model, job.task, job.version
+                    job.job_id,
+                    job.model,
+                    job.task,
+                    job.version,
+                    config={
+                        "sweep_id": self.sweep_id,
+                        "is_resume": self.is_resuming,
+                        "retry_count": job.retry_count,
+                        "trial_number": job.trial_number,
+                    },
+                    tags=[f"sweep:{self.sweep_id}"],
                 )
                 logger.log_metrics(
                     {
@@ -301,6 +485,10 @@
                         "job/server_id": server_allocation.server_id,
                         "job/server_address": server_allocation.server_address,
                         "job/server_port": server_allocation.tcp_port,
+                        "job/sweep_id": self.sweep_id,
+                        "job/is_resume": self.is_resuming,
+                        "job/retry_count": job.retry_count,
+                        "job/completion_status": "running",
                     }
                 )
@@ -418,13 +606,52 @@
             if self.wandb_logger:
                 logger = self.wandb_logger.get_logger(job_id)
                 if logger:
-                    logger.log_metrics({"job/status": "completed"})
+                    logger.log_metrics(
+                        {
+                            "job/status": "completed",
+                            "job/completion_status": "successful",
+                            "job/exit_code": process.exitcode,
+                            "job/duration_minutes": (
+                                job.end_time - job.start_time
+                            ).total_seconds()
+                            / 60
+                            if job.start_time
+                            else 0,
+                            "job/final_retry_count": job.retry_count,
+                        }
+                    )
 
         else:
             job.status = "failed"
             job.error_message = f"Process exited with code {process.exitcode}"
             print(f"❌ Job {job_id} failed with exit code {process.exitcode}")
 
+            # Log failure to WandB
+            if self.wandb_logger:
+                logger = self.wandb_logger.get_logger(job_id)
+                if logger:
+                    will_retry = (
+                        self.config.retry_failed_runs
+                        and job.retry_count < self.config.max_retries
+                    )
+                    logger.log_metrics(
+                        {
+                            "job/status": "failed",
+                            "job/completion_status": "will_retry"
+                            if will_retry
+                            else "failed_final",
+                            "job/exit_code": process.exitcode,
+                            "job/error_message": job.error_message,
+                            "job/duration_minutes": (
+                                job.end_time - job.start_time
+                            ).total_seconds()
+                            / 60
+                            if job.start_time
+                            else 0,
+                            "job/retry_count_at_failure": job.retry_count,
+                        }
+                    )
+
             # Retry if configured and retries available
             if (
                 self.config.retry_failed_runs
@@ -635,7 +862,13 @@
                 "all_models",
                 "final_summary",
                 0,
-                config=report["sweep_config"],
+                config={
+                    **report["sweep_config"],
+                    "sweep_id": self.sweep_id,
+                    "is_resume": self.is_resuming,
+                    "sweep_completion_status": "completed",
+                },
+                tags=[f"sweep:{self.sweep_id}", "sweep_summary"],
             )
 
             summary_logger.log_sweep_summary(
@@ -649,6 +882,25 @@
             # Log model comparison table
             summary_logger.log_model_comparison_table(results_by_model)
 
+            # Log sweep completion metrics
+            summary_logger.log_metrics(
+                {
+                    "sweep/completion_status": "completed",
+                    "sweep/is_resume": self.is_resuming,
+                    "sweep/total_completed_jobs": report["execution_summary"][
+                        "completed_jobs"
+                    ],
+                    "sweep/total_failed_jobs": report["execution_summary"][
+                        "failed_jobs"
+                    ],
+                    "sweep/success_rate": report["execution_summary"][
+                        "completed_jobs"
+                    ]
+                    / max(report["execution_summary"]["total_jobs"], 1),
+                    "sweep/final_timestamp": datetime.now().timestamp(),
+                }
+            )
+
         # Save report to file if output directory specified
         if self.config.output_dir:
             report_path = (