Mirror of https://github.com/JackHopkins/factorio-learning-environment.git (synced 2025-09-06 13:23:58 +00:00)
@@ -165,6 +165,38 @@ manager = SweepManager(config)
results = await manager.run_sweep()
```

#### Resuming Failed Sweeps

The SweepManager now supports resuming failed sweeps without duplicating completed runs:

```python
# Resume an existing sweep
existing_sweep_id = "my_experiment_20241201_120000_abcd1234"
manager = SweepManager(config, existing_sweep_id=existing_sweep_id)

# Alternative: using the class method
manager = SweepManager.resume_sweep(config, existing_sweep_id)

# The manager will automatically:
# - Skip completed runs
# - Retry partial/failed runs
# - Continue with remaining jobs
results = await manager.run_sweep()
```

Enhanced WandB metadata for resumed sweeps includes:

- `sweep_id`: Unique identifier for the sweep
- `is_resume`: Boolean indicating whether this is a resumed sweep
- `completion_status`: Tracks run completion ("running", "successful", "failed_final", "will_retry")
- `retry_count`: Number of retry attempts for each job
- Tags: Include `sweep:{sweep_id}` for easy filtering

Filter in WandB using:

- `config.sweep_id = "your_sweep_id"`
- `config.completion_status = "successful"`
- `tags: sweep:your_sweep_id`
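For programmatic filtering, the same keys can be queried through WandB's public API. A minimal sketch, assuming the `wandb` client library and a placeholder `entity/project` path (substitute your own):

```python
# Minimal sketch using the public wandb API; "entity/project" is a placeholder.
import wandb

api = wandb.Api()
runs = api.runs(
    "entity/project",
    filters={
        "config.sweep_id": "your_sweep_id",
        "config.completion_status": "successful",
    },
)
for run in runs:
    print(run.name, run.config.get("sweep_id"))
```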
### ResultsVisualizer

Generate analysis plots:
@@ -197,6 +229,7 @@ See the `examples/` directory for complete usage examples:

- `example_sweep_config.py`: Example sweep configurations
- `analyze_sweep_results.py`: Analysis script with various commands
- `resume_sweep_example.py`: Example of resuming failed sweeps

### Running Examples
@@ -132,7 +132,8 @@ class DatabaseAnalyzer:
                    THEN achievements_json
                    ELSE NULL END,
                    '; '
                ) as all_achievements,
                STRING_AGG(DISTINCT meta::text, '; ') as meta_aggregated
            FROM programs
            WHERE version IN ({version_list})
            GROUP BY version, version_description, model, instance
@@ -147,6 +148,53 @@ class DatabaseAnalyzer:
        results = await self.db_client.execute_query(query)
        return pd.DataFrame(results)

    async def get_trajectory_summaries_by_sweep(self, sweep_id: str) -> pd.DataFrame:
        """Get trajectory-level summaries for a specific sweep

        Args:
            sweep_id: Sweep identifier to filter by

        Returns:
            DataFrame with one row per trajectory (version + instance)
        """
        await self.ensure_connection()

        query = """
            WITH trajectory_stats AS (
                SELECT
                    version,
                    version_description,
                    model,
                    instance,
                    MAX(depth) as max_depth,
                    MAX(value) as final_reward,
                    MAX(raw_reward) as max_raw_reward,
                    COUNT(*) as num_steps,
                    SUM(token_usage) as total_tokens,
                    SUM(completion_token_usage) as total_completion_tokens,
                    MIN(created_at) as start_time,
                    MAX(created_at) as end_time,
                    STRING_AGG(
                        CASE WHEN achievements_json != '{{}}' AND achievements_json IS NOT NULL
                        THEN achievements_json
                        ELSE NULL END,
                        '; '
                    ) as all_achievements,
                    STRING_AGG(DISTINCT meta::text, '; ') as meta_aggregated
                FROM programs
                WHERE meta->>'sweep_id' = %s
                GROUP BY version, version_description, model, instance
            )
            SELECT
                *,
                EXTRACT(EPOCH FROM (end_time - start_time)) as duration_seconds
            FROM trajectory_stats
            ORDER BY version, instance
        """

        results = await self.db_client.execute_query(query, (sweep_id,))
        return pd.DataFrame(results)

    async def get_model_comparison(
        self,
        models: List[str],
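To make the new query concrete, here is a usage sketch; the `DatabaseAnalyzer` import path and the sweep ID are illustrative assumptions, not part of this commit:

```python
# Usage sketch (import path and sweep ID are assumed for illustration).
import asyncio

from fle.eval.analysis.database_analyzer import DatabaseAnalyzer  # path assumed


async def main():
    analyzer = DatabaseAnalyzer()
    # One row per trajectory (version + instance) recorded under this sweep_id
    df = await analyzer.get_trajectory_summaries_by_sweep(
        "my_experiment_20241201_120000_abcd1234"
    )
    print(df[["model", "instance", "num_steps", "final_reward", "duration_seconds"]])


asyncio.run(main())
```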
fle/eval/analysis/examples/resume_sweep_example.py (new file, 63 lines)
@@ -0,0 +1,63 @@
"""
|
||||||
|
Example demonstrating how to resume a failed sweep.
|
||||||
|
|
||||||
|
This example shows how to restart a sweep without duplicating completed runs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fle.eval.analysis.sweep_manager import SweepManager, SweepConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def resume_sweep_example():
|
||||||
|
"""Example of resuming a failed sweep"""
|
||||||
|
|
||||||
|
# Create the same config as the original sweep
|
||||||
|
config = SweepConfig(
|
||||||
|
name="my_experiment",
|
||||||
|
models=["gpt-4", "claude-3-sonnet"],
|
||||||
|
tasks=["craft_iron_plate", "build_furnace", "mine_coal"],
|
||||||
|
num_trials_per_config=8,
|
||||||
|
max_concurrent_processes=4,
|
||||||
|
api_key_config_file="path/to/api_keys.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resume sweep with existing sweep ID
|
||||||
|
# You can get the sweep ID from the original run logs or WandB
|
||||||
|
existing_sweep_id = "my_experiment_20241201_120000_abcd1234"
|
||||||
|
|
||||||
|
# Method 1: Using the constructor
|
||||||
|
sweep_manager = SweepManager(config, existing_sweep_id=existing_sweep_id)
|
||||||
|
|
||||||
|
# Method 2: Using the class method (equivalent)
|
||||||
|
# sweep_manager = SweepManager.resume_sweep(config, existing_sweep_id)
|
||||||
|
|
||||||
|
print("Starting sweep resume...")
|
||||||
|
results = await sweep_manager.run_sweep()
|
||||||
|
|
||||||
|
print(f"Sweep completed! Results: {results}")
|
||||||
|
|
||||||
|
|
||||||
|
async def new_sweep_example():
|
||||||
|
"""Example of starting a new sweep (for comparison)"""
|
||||||
|
|
||||||
|
config = SweepConfig(
|
||||||
|
name="my_new_experiment",
|
||||||
|
models=["gpt-4", "claude-3-sonnet"],
|
||||||
|
tasks=["craft_iron_plate", "build_furnace"],
|
||||||
|
num_trials_per_config=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create new sweep (no existing_sweep_id)
|
||||||
|
sweep_manager = SweepManager(config)
|
||||||
|
|
||||||
|
await sweep_manager.run_sweep()
|
||||||
|
print(f"New sweep completed! Sweep ID: {sweep_manager.sweep_id}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# To resume a sweep:
|
||||||
|
# asyncio.run(resume_sweep_example())
|
||||||
|
|
||||||
|
# To start a new sweep:
|
||||||
|
# asyncio.run(new_sweep_example())
|
||||||
|
|
||||||
|
print("Example script - uncomment the appropriate line above to run")
|
fle/eval/analysis/server_manager.py (new file, 338 lines)
@@ -0,0 +1,338 @@
"""
|
||||||
|
Server management for concurrent evaluations.
|
||||||
|
|
||||||
|
Handles allocation and tracking of Factorio server instances across
|
||||||
|
multiple concurrent evaluation processes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Dict, Optional, Set
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from fle.commons.cluster_ips import get_local_container_ips
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ServerAllocation:
|
||||||
|
"""Tracks allocation of a Factorio server to a job"""
|
||||||
|
|
||||||
|
server_id: int
|
||||||
|
server_address: str
|
||||||
|
tcp_port: int
|
||||||
|
udp_port: int
|
||||||
|
job_id: str
|
||||||
|
allocated_at: datetime
|
||||||
|
process_id: Optional[int] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict:
|
||||||
|
"""Convert to dictionary for logging"""
|
||||||
|
return {
|
||||||
|
"server_id": self.server_id,
|
||||||
|
"server_address": self.server_address,
|
||||||
|
"tcp_port": self.tcp_port,
|
||||||
|
"udp_port": self.udp_port,
|
||||||
|
"job_id": self.job_id,
|
||||||
|
"allocated_at": self.allocated_at.isoformat(),
|
||||||
|
"process_id": self.process_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ServerManager:
|
||||||
|
"""Manages allocation of Factorio servers across concurrent jobs"""
|
||||||
|
|
||||||
|
def __init__(self, max_allocation_time_hours: float = 2.0):
|
||||||
|
"""Initialize server manager
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_allocation_time_hours: Maximum time to hold a server allocation
|
||||||
|
"""
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
self._allocations: Dict[int, ServerAllocation] = {}
|
||||||
|
self._available_servers: List[int] = []
|
||||||
|
self._allocated_servers: Set[int] = set()
|
||||||
|
self.max_allocation_time = timedelta(hours=max_allocation_time_hours)
|
||||||
|
self._initialized = False
|
||||||
|
|
||||||
|
def _discover_servers(self) -> bool:
|
||||||
|
"""Discover available Factorio servers
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if servers were found, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ips, udp_ports, tcp_ports = get_local_container_ips()
|
||||||
|
|
||||||
|
if not tcp_ports:
|
||||||
|
print("⚠️ No Factorio containers found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
self._available_servers = list(range(len(tcp_ports)))
|
||||||
|
self._server_info = {
|
||||||
|
i: {
|
||||||
|
"address": ips[i],
|
||||||
|
"tcp_port": tcp_ports[i],
|
||||||
|
"udp_port": udp_ports[i],
|
||||||
|
}
|
||||||
|
for i in range(len(tcp_ports))
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"🖥️ Discovered {len(tcp_ports)} Factorio servers:")
|
||||||
|
for i, (ip, tcp_port) in enumerate(zip(ips, tcp_ports)):
|
||||||
|
print(f" Server {i}: {ip}:{tcp_port}")
|
||||||
|
|
||||||
|
self._initialized = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Error discovering servers: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def initialize(self) -> bool:
|
||||||
|
"""Initialize server discovery
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if initialization was successful
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
return self._discover_servers()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_available_server_count(self) -> int:
|
||||||
|
"""Get number of currently available servers"""
|
||||||
|
with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
self._discover_servers()
|
||||||
|
return len(self._available_servers) - len(self._allocated_servers)
|
||||||
|
|
||||||
|
def get_total_server_count(self) -> int:
|
||||||
|
"""Get total number of discovered servers"""
|
||||||
|
with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
self._discover_servers()
|
||||||
|
return len(self._available_servers)
|
||||||
|
|
||||||
|
def allocate_server(
|
||||||
|
self, job_id: str, process_id: Optional[int] = None
|
||||||
|
) -> Optional[ServerAllocation]:
|
||||||
|
"""Allocate a server for a job
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Unique identifier for the job
|
||||||
|
process_id: Optional process ID for tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ServerAllocation if successful, None if no servers available
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
# Initialize if needed
|
||||||
|
if not self._initialized:
|
||||||
|
if not self._discover_servers():
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Clean up expired allocations
|
||||||
|
self._cleanup_expired_allocations()
|
||||||
|
|
||||||
|
# Find available server
|
||||||
|
available_servers = [
|
||||||
|
server_id
|
||||||
|
for server_id in self._available_servers
|
||||||
|
if server_id not in self._allocated_servers
|
||||||
|
]
|
||||||
|
|
||||||
|
if not available_servers:
|
||||||
|
print(
|
||||||
|
f"⚠️ No servers available for job {job_id} (all {len(self._available_servers)} servers allocated)"
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Allocate first available server
|
||||||
|
server_id = available_servers[0]
|
||||||
|
server_info = self._server_info[server_id]
|
||||||
|
|
||||||
|
allocation = ServerAllocation(
|
||||||
|
server_id=server_id,
|
||||||
|
server_address=server_info["address"],
|
||||||
|
tcp_port=server_info["tcp_port"],
|
||||||
|
udp_port=server_info["udp_port"],
|
||||||
|
job_id=job_id,
|
||||||
|
allocated_at=datetime.now(),
|
||||||
|
process_id=process_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._allocations[server_id] = allocation
|
||||||
|
self._allocated_servers.add(server_id)
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"🖥️ Allocated server {server_id} ({allocation.server_address}:{allocation.tcp_port}) to job {job_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return allocation
|
||||||
|
|
||||||
|
def release_server(self, job_id: str) -> bool:
|
||||||
|
"""Release server allocation for a job
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Job identifier to release
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if server was found and released
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
# Find allocation by job_id
|
||||||
|
server_id = None
|
||||||
|
for sid, allocation in self._allocations.items():
|
||||||
|
if allocation.job_id == job_id:
|
||||||
|
server_id = sid
|
||||||
|
break
|
||||||
|
|
||||||
|
if server_id is not None:
|
||||||
|
allocation = self._allocations.pop(server_id)
|
||||||
|
self._allocated_servers.remove(server_id)
|
||||||
|
print(f"🔓 Released server {server_id} from job {job_id}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def release_server_by_id(self, server_id: int) -> bool:
|
||||||
|
"""Release server allocation by server ID
|
||||||
|
|
||||||
|
Args:
|
||||||
|
server_id: Server ID to release
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if server was found and released
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if server_id in self._allocations:
|
||||||
|
allocation = self._allocations.pop(server_id)
|
||||||
|
self._allocated_servers.remove(server_id)
|
||||||
|
print(
|
||||||
|
f"🔓 Released server {server_id} (was allocated to {allocation.job_id})"
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _cleanup_expired_allocations(self):
|
||||||
|
"""Clean up allocations that have been held too long (called with lock held)"""
|
||||||
|
current_time = datetime.now()
|
||||||
|
expired_servers = []
|
||||||
|
|
||||||
|
for server_id, allocation in self._allocations.items():
|
||||||
|
if current_time - allocation.allocated_at > self.max_allocation_time:
|
||||||
|
expired_servers.append(server_id)
|
||||||
|
|
||||||
|
for server_id in expired_servers:
|
||||||
|
allocation = self._allocations.pop(server_id)
|
||||||
|
self._allocated_servers.remove(server_id)
|
||||||
|
print(
|
||||||
|
f"⏰ Released expired allocation: server {server_id} (was allocated to {allocation.job_id})"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_allocation_status(self) -> Dict:
|
||||||
|
"""Get current allocation status
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with allocation information
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if not self._initialized:
|
||||||
|
self._discover_servers()
|
||||||
|
|
||||||
|
self._cleanup_expired_allocations()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"total_servers": len(self._available_servers),
|
||||||
|
"allocated_servers": len(self._allocated_servers),
|
||||||
|
"available_servers": len(self._available_servers)
|
||||||
|
- len(self._allocated_servers),
|
||||||
|
"allocations": [
|
||||||
|
allocation.to_dict() for allocation in self._allocations.values()
|
||||||
|
],
|
||||||
|
"initialized": self._initialized,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_server_assignment_for_job(self, job_id: str) -> Optional[Dict]:
|
||||||
|
"""Get server assignment for a specific job
|
||||||
|
|
||||||
|
Args:
|
||||||
|
job_id: Job identifier
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dictionary with server info or None if not found
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
for allocation in self._allocations.values():
|
||||||
|
if allocation.job_id == job_id:
|
||||||
|
return {
|
||||||
|
"server_id": allocation.server_id,
|
||||||
|
"address": allocation.server_address,
|
||||||
|
"tcp_port": allocation.tcp_port,
|
||||||
|
"udp_port": allocation.udp_port,
|
||||||
|
"allocated_at": allocation.allocated_at.isoformat(),
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
def force_release_all(self):
|
||||||
|
"""Force release all server allocations (emergency cleanup)"""
|
||||||
|
with self._lock:
|
||||||
|
released_count = len(self._allocations)
|
||||||
|
self._allocations.clear()
|
||||||
|
self._allocated_servers.clear()
|
||||||
|
|
||||||
|
if released_count > 0:
|
||||||
|
print(f"🧹 Force released all {released_count} server allocations")
|
||||||
|
|
||||||
|
def print_status(self):
|
||||||
|
"""Print current server allocation status"""
|
||||||
|
status = self.get_allocation_status()
|
||||||
|
|
||||||
|
print("🖥️ Server Allocation Status:")
|
||||||
|
print(f" Total servers: {status['total_servers']}")
|
||||||
|
print(f" Available: {status['available_servers']}")
|
||||||
|
print(f" Allocated: {status['allocated_servers']}")
|
||||||
|
|
||||||
|
if status["allocations"]:
|
||||||
|
print(" Current allocations:")
|
||||||
|
for alloc in status["allocations"]:
|
||||||
|
print(
|
||||||
|
f" Server {alloc['server_id']}: {alloc['job_id']} "
|
||||||
|
f"(since {alloc['allocated_at'][:19]})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Global server manager instance
|
||||||
|
_global_server_manager: Optional[ServerManager] = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_server_manager() -> ServerManager:
|
||||||
|
"""Get or create global server manager instance"""
|
||||||
|
global _global_server_manager
|
||||||
|
|
||||||
|
if _global_server_manager is None:
|
||||||
|
_global_server_manager = ServerManager()
|
||||||
|
_global_server_manager.initialize()
|
||||||
|
|
||||||
|
return _global_server_manager
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Test server manager
|
||||||
|
manager = ServerManager()
|
||||||
|
manager.initialize()
|
||||||
|
manager.print_status()
|
||||||
|
|
||||||
|
# Test allocation
|
||||||
|
alloc1 = manager.allocate_server("test_job_1")
|
||||||
|
alloc2 = manager.allocate_server("test_job_2")
|
||||||
|
|
||||||
|
manager.print_status()
|
||||||
|
|
||||||
|
if alloc1:
|
||||||
|
manager.release_server("test_job_1")
|
||||||
|
if alloc2:
|
||||||
|
manager.release_server("test_job_2")
|
||||||
|
|
||||||
|
manager.print_status()
|
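One caller-side pattern worth noting: `allocate_server` holds a slot until `release_server` is called or the two-hour expiry cleanup fires, so a job wrapper should release in a `finally` block. A minimal sketch, where `run_evaluation` and `run_job_with_server` are hypothetical stand-ins, not part of this commit:

```python
# Sketch of a job wrapper around ServerManager; run_evaluation is hypothetical.
from fle.eval.analysis.server_manager import get_server_manager


def run_evaluation(address: str, tcp_port: int) -> None:
    """Hypothetical job body; replace with the real evaluation call."""
    print(f"Evaluating against {address}:{tcp_port}")


def run_job_with_server(job_id: str) -> None:
    manager = get_server_manager()
    allocation = manager.allocate_server(job_id)
    if allocation is None:
        raise RuntimeError(f"No Factorio server available for job {job_id}")
    try:
        run_evaluation(allocation.server_address, allocation.tcp_port)
    finally:
        # Release promptly instead of waiting for the expiry cleanup.
        manager.release_server(job_id)
```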
@@ -88,20 +88,31 @@ class RunJob:


class SweepManager:
    """Manages large-scale evaluation sweeps

    Supports both new sweeps and resuming existing ones. When resuming, the manager
    will automatically skip completed runs and retry partial/failed runs.
    """

    def __init__(self, config: SweepConfig, existing_sweep_id: Optional[str] = None):
        """Initialize sweep manager

        Args:
            config: SweepConfig with sweep parameters
            existing_sweep_id: Optional sweep ID to resume an existing sweep
        """
        self.config = config
        self.is_resuming = existing_sweep_id is not None

        if existing_sweep_id:
            self.sweep_id = existing_sweep_id
            print(f"🔄 Resuming sweep: {self.sweep_id}")
        else:
            # Generate unique sweep ID with timestamp and short UUID
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            short_uuid = str(uuid.uuid4())[:8]
            self.sweep_id = f"{config.name}_{timestamp}_{short_uuid}"
            print(f"🆕 Starting new sweep: {self.sweep_id}")

        self.jobs: List[RunJob] = []
        self.active_processes: Dict[str, multiprocessing.Process] = {}
@@ -139,18 +150,155 @@ class SweepManager:
            f"🖥️ Available servers: {total_servers}, Max concurrent: {self.config.max_concurrent_processes}"
        )

        # If resuming, we'll load existing state in run_sweep()

    @classmethod
    def resume_sweep(cls, config: SweepConfig, sweep_id: str) -> "SweepManager":
        """Create a SweepManager to resume an existing sweep

        Args:
            config: SweepConfig with sweep parameters (should match original config)
            sweep_id: ID of the existing sweep to resume

        Returns:
            SweepManager instance configured to resume the sweep
        """
        return cls(config, existing_sweep_id=sweep_id)

    async def load_existing_sweep_state(self) -> Dict[str, Any]:
        """Load existing sweep state from the database

        Returns:
            Dictionary with information about completed runs
        """
        if not self.is_resuming:
            return {"completed_runs": {}, "partial_runs": {}}

        print(f"🔍 Loading existing state for sweep: {self.sweep_id}")

        # Initialize database analyzer if not already done
        if not self.database_analyzer:
            self.database_analyzer = DatabaseAnalyzer()

        try:
            # Get existing trajectories for this sweep
            existing_df = (
                await self.database_analyzer.get_trajectory_summaries_by_sweep(
                    self.sweep_id
                )
            )

            if existing_df.empty:
                print(f"⚠️ No existing data found for sweep {self.sweep_id}")
                return {"completed_runs": {}, "partial_runs": {}}

            print(f"📊 Found {len(existing_df)} existing trajectories")

            completed_runs = {}  # key: (model, task, trial), value: version
            partial_runs = {}  # key: (model, task, trial), value: {version, status}

            # Group by version to identify complete vs partial runs
            for _, row in existing_df.iterrows():
                model = row["model"]
                version_desc = row["version_description"]
                version = row["version"]
                instance = row["instance"]  # This is the trial number

                # Extract task from version_description,
                # assuming version_description contains "type:task_name"
                task = "unknown_task"
                if version_desc and "type:" in version_desc:
                    task = version_desc.split("type:")[1].split("\n")[0].strip()
                elif version_desc:
                    # Fallback: use the first line of version_description
                    task = version_desc.split("\n")[0].strip()

                # Key for identifying a unique run
                run_key = (model, task, instance)

                # Check if this is a complete trajectory:
                # a trajectory is considered complete if it has steps and ended properly
                num_steps = row.get("num_steps", 0)
                final_reward = row.get("final_reward", 0)

                # Consider a run complete if it has a reasonable number of steps.
                # This is a heuristic - it might need adjustment based on typical run patterns
                is_complete = num_steps > 5  # Adjust this threshold as needed

                if is_complete:
                    completed_runs[run_key] = version
                else:
                    partial_runs[run_key] = {
                        "version": version,
                        "num_steps": num_steps,
                        "final_reward": final_reward,
                    }

            print(f"✅ Found {len(completed_runs)} completed runs")
            print(f"⚠️ Found {len(partial_runs)} partial/incomplete runs")

            # Log some examples
            if completed_runs:
                print("Examples of completed runs:")
                for i, (model, task, trial) in enumerate(
                    list(completed_runs.keys())[:3]
                ):
                    print(f"  - {model} on {task} trial {trial}")

            if partial_runs:
                print("Examples of partial runs (will be retried):")
                for i, (model, task, trial) in enumerate(list(partial_runs.keys())[:3]):
                    info = partial_runs[(model, task, trial)]
                    print(
                        f"  - {model} on {task} trial {trial} ({info['num_steps']} steps)"
                    )

            return {
                "completed_runs": completed_runs,
                "partial_runs": partial_runs,
                "total_existing": len(existing_df),
            }

        except Exception as e:
            print(f"❌ Error loading existing sweep state: {e}")
            return {"completed_runs": {}, "partial_runs": {}}

    def generate_jobs(
        self, existing_state: Optional[Dict[str, Any]] = None
    ) -> List[RunJob]:
        """Generate all run jobs for the sweep, excluding completed ones if resuming

        Args:
            existing_state: Optional dictionary with completed and partial run information

        Returns:
            List of RunJob objects to execute
        """
        jobs = []
        completed_runs = (
            existing_state.get("completed_runs", {}) if existing_state else {}
        )
        partial_runs = existing_state.get("partial_runs", {}) if existing_state else {}

        skipped_count = 0
        retry_count = 0

        # Generate all combinations of models and tasks
        for model, task in itertools.product(self.config.models, self.config.tasks):
            for trial in range(self.config.num_trials_per_config):
                run_key = (model, task, trial)
                job_id = f"{model}_{task}_trial{trial:02d}"

                # Skip if already completed
                if run_key in completed_runs:
                    skipped_count += 1
                    # Still track these for the completed_versions list
                    existing_version = completed_runs[run_key]
                    if existing_version not in self.completed_versions:
                        self.completed_versions.append(existing_version)
                    continue

                # Create job (will run for new jobs and retry for partial jobs)
                job = RunJob(
                    job_id=job_id,
                    model=model,
@@ -158,6 +306,14 @@ class SweepManager:
                    trial_number=trial,
                    sweep_id=self.sweep_id,
                )

                # If this is a partial run, mark it for retry
                if run_key in partial_runs:
                    retry_count += 1
                    job.retry_count = 1  # Mark that this is a retry
                    job.status = "pending"  # Will be retried
                    print(f"🔄 Will retry partial run: {job_id}")

                jobs.append(job)

        # Shuffle if requested to distribute load
@@ -165,6 +321,20 @@ class SweepManager:
            random.shuffle(jobs)

        self.jobs = jobs

        # Log summary
        total_expected = (
            len(self.config.models)
            * len(self.config.tasks)
            * self.config.num_trials_per_config
        )
        print("📋 Job generation summary:")
        print(f"   Total expected jobs: {total_expected}")
        print(f"   Already completed: {skipped_count}")
        print(f"   Will retry (partial): {retry_count}")
        print(f"   New jobs to run: {len(jobs) - retry_count}")
        print(f"   Total jobs to execute: {len(jobs)}")

        return jobs

    async def run_sweep(self) -> Dict[str, Any]:
@@ -180,11 +350,15 @@ class SweepManager:
        self.start_time = datetime.now()

        # Load existing state if resuming
        existing_state = await self.load_existing_sweep_state()

        # Generate all jobs (excluding completed ones if resuming)
        self.generate_jobs(existing_state)

        # Initialize database analyzer for monitoring (if not already initialized)
        if not self.database_analyzer:
            self.database_analyzer = DatabaseAnalyzer()

        # Get starting version numbers
        base_version = await get_next_version()
@@ -293,7 +467,17 @@ class SweepManager:
        # Log to WandB if enabled
        if self.wandb_logger:
            logger = self.wandb_logger.create_run_logger(
                job.job_id,
                job.model,
                job.task,
                job.version,
                config={
                    "sweep_id": self.sweep_id,
                    "is_resume": self.is_resuming,
                    "retry_count": job.retry_count,
                    "trial_number": job.trial_number,
                },
                tags=[f"sweep:{self.sweep_id}"],
            )
            logger.log_metrics(
                {
@@ -301,6 +485,10 @@ class SweepManager:
                    "job/server_id": server_allocation.server_id,
                    "job/server_address": server_allocation.server_address,
                    "job/server_port": server_allocation.tcp_port,
                    "job/sweep_id": self.sweep_id,
                    "job/is_resume": self.is_resuming,
                    "job/retry_count": job.retry_count,
                    "job/completion_status": "running",
                }
            )
@@ -418,13 +606,52 @@ class SweepManager:
            if self.wandb_logger:
                logger = self.wandb_logger.get_logger(job_id)
                if logger:
                    logger.log_metrics(
                        {
                            "job/status": "completed",
                            "job/completion_status": "successful",
                            "job/exit_code": process.exitcode,
                            "job/duration_minutes": (
                                job.end_time - job.start_time
                            ).total_seconds()
                            / 60
                            if job.start_time
                            else 0,
                            "job/final_retry_count": job.retry_count,
                        }
                    )

        else:
            job.status = "failed"
            job.error_message = f"Process exited with code {process.exitcode}"
            print(f"❌ Job {job_id} failed with exit code {process.exitcode}")

            # Log failure to WandB
            if self.wandb_logger:
                logger = self.wandb_logger.get_logger(job_id)
                if logger:
                    will_retry = (
                        self.config.retry_failed_runs
                        and job.retry_count < self.config.max_retries
                    )
                    logger.log_metrics(
                        {
                            "job/status": "failed",
                            "job/completion_status": "will_retry"
                            if will_retry
                            else "failed_final",
                            "job/exit_code": process.exitcode,
                            "job/error_message": job.error_message,
                            "job/duration_minutes": (
                                job.end_time - job.start_time
                            ).total_seconds()
                            / 60
                            if job.start_time
                            else 0,
                            "job/retry_count_at_failure": job.retry_count,
                        }
                    )

            # Retry if configured and retries available
            if (
                self.config.retry_failed_runs
@@ -635,7 +862,13 @@ class SweepManager:
                "all_models",
                "final_summary",
                0,
                config={
                    **report["sweep_config"],
                    "sweep_id": self.sweep_id,
                    "is_resume": self.is_resuming,
                    "sweep_completion_status": "completed",
                },
                tags=[f"sweep:{self.sweep_id}", "sweep_summary"],
            )

            summary_logger.log_sweep_summary(
@@ -649,6 +882,25 @@ class SweepManager:
            # Log model comparison table
            summary_logger.log_model_comparison_table(results_by_model)

            # Log sweep completion metrics
            summary_logger.log_metrics(
                {
                    "sweep/completion_status": "completed",
                    "sweep/is_resume": self.is_resuming,
                    "sweep/total_completed_jobs": report["execution_summary"][
                        "completed_jobs"
                    ],
                    "sweep/total_failed_jobs": report["execution_summary"][
                        "failed_jobs"
                    ],
                    "sweep/success_rate": report["execution_summary"][
                        "completed_jobs"
                    ]
                    / max(report["execution_summary"]["total_jobs"], 1),
                    "sweep/final_timestamp": datetime.now().timestamp(),
                }
            )

        # Save report to file if output directory specified
        if self.config.output_dir:
            report_path = (