diff --git a/fle/commons/db_client.py b/fle/commons/db_client.py index 080af71d..43d5d459 100644 --- a/fle/commons/db_client.py +++ b/fle/commons/db_client.py @@ -862,3 +862,11 @@ async def create_db_client( ) else: raise Exception(f"Invalid database type: {db_type}") + + +async def get_next_version() -> int: + """Get next available version number""" + db_client = await create_db_client() + version = await db_client.get_largest_version() + await db_client.cleanup() + return version + 1 \ No newline at end of file diff --git a/fle/env/gym_env/run_eval.py b/fle/env/gym_env/run_eval.py index 3fd5aad2..659aa150 100644 --- a/fle/env/gym_env/run_eval.py +++ b/fle/env/gym_env/run_eval.py @@ -14,8 +14,7 @@ from fle.env.gym_env.trajectory_runner import GymTrajectoryRunner from fle.agents.gym_agent import GymAgent from fle.commons.cluster_ips import get_local_container_ips -from fle.commons.db_client import create_db_client -from fle.eval.algorithms.independent import get_next_version +from fle.commons.db_client import create_db_client, get_next_version from fle.eval.tasks import TaskFactory from fle.env.utils.controller_loader.system_prompt_generator import ( SystemPromptGenerator, @@ -43,7 +42,7 @@ def get_validated_run_configs(run_config_location: str) -> list[GymRunConfig]: ips, udp_ports, tcp_ports = get_local_container_ips() if len(tcp_ports) < len(run_configs): raise ValueError( - f"Not enough containers for {len(run_configs)} runs. Only {len(ips)} containers available." + f"Not enough containers for {len(run_configs)} runs. Only {len(tcp_ports)} containers available." ) return run_configs