mirror of https://github.com/JackHopkins/factorio-learning-environment.git
synced 2025-09-06 13:23:58 +00:00
llm_factory (#290)
* first iteration
* change to support openai api endpoints
* Refactor APIFactory to use OpenAI-compatible endpoints
  - Unified all providers to use OpenAI client format
  - Eliminated provider-specific conditional branches
  - Simplified provider detection using dict ordering
  - Removed unused parameters and added missing return
  - 90% reduction in code complexity
* Further simplify APIFactory
  - Remove redundant MODELS_WITH_IMAGE_SUPPORT array
  - Use provider config supports_images instead
  - Inline _prepare_messages logic
  - Extract _get_reasoning_length helper
  - Add missing default return
  - 20+ line reduction while maintaining functionality
* remove comment
* Inline reasoning length logic
  - Remove _get_reasoning_length helper method
  - Inline reasoning effort logic in o1/o3 handling
  - Keep code simpler and more direct
* add provider sorting for openrouter to get fastest throughput
* add nitro
* add usage tracking
* usage
* undo changes that added logging
* update config paths
* remove offset
* offset
* Aug 20, 2025 at 20:25
* fix run_idx port offset
* make sure there is keyerror if no port
* fix
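The net effect of the squashed commits above is that every provider goes through a single OpenAI-compatible code path, selected from the model name. As a rough usage sketch (the import path and the exact keyword set accepted by acall are assumptions read off the diff below, not a documented API):

import asyncio

from fle.agents.llm.api_factory import APIFactory  # module path assumed

async def demo():
    # A provider key embedded in the model name picks the endpoint;
    # "open-router-..." models are routed through OpenRouter and the
    # prefix is stripped by that provider's model_transform.
    api = APIFactory(model="open-router-deepseek/deepseek-chat")
    response = await api.acall(
        messages=[
            {"role": "system", "content": "You are a Factorio agent."},
            {"role": "user", "content": "Plan the next action."},
        ],
        max_tokens=256,
        temperature=0.3,
    )
    print(response.choices[0].message.content)

asyncio.run(demo())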
@@ -16,6 +16,8 @@ FLE_DB_TYPE="sqlite"
# SQLite Configuration (used when FLE_DB_TYPE=sqlite or as fallback)
# If not set, defaults to .fle/data.db
SQLITE_DB_FILE=".fle/data.db"
# Offset for the port of the Factorio server when running multiple shells
PORT_OFFSET=0

# PostgreSQL Configuration (only needed when FLE_DB_TYPE=postgres)
SKILLS_DB_HOST=XXX
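Per the "make sure there is keyerror if no port" commit bullet, the registry change further down reads this variable strictly rather than defaulting. A minimal sketch of that contract:

import os

# Strict read, mirroring fle/env/gym_env/registry.py below: if PORT_OFFSET
# is absent from the environment (or this .env file), this raises KeyError
# at import time instead of silently running every shell against container 0.
PORT_OFFSET = int(os.environ["PORT_OFFSET"])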
@@ -1,534 +1,95 @@
import asyncio
import os

import anthropic
from openai import AsyncOpenAI, OpenAI
from openai import AsyncOpenAI
from tenacity import retry, wait_exponential

from fle.agents.llm.metrics import timing_tracker, track_timing_async
from fle.agents.llm.utils import (
    format_messages_for_anthropic,
    format_messages_for_openai,
    has_image_content,
    merge_contiguous_messages,
    remove_whitespace_blocks,
)


class NoRetryAsyncOpenAI(AsyncOpenAI):
    """Wrapper around AsyncOpenAI that always sets max_retries=0"""

    def __init__(self, **kwargs):
        kwargs["max_retries"] = 0
        super().__init__(**kwargs)


class APIFactory:
    # Models that support image input
    MODELS_WITH_IMAGE_SUPPORT = [
        # Claude models with vision
        "claude-3-opus",
        "claude-3-sonnet",
        "claude-3-haiku",
        "claude-3-5-sonnet",
        "claude-3-7-sonnet",
        "claude-3.7-sonnet",
        # OpenAI models with vision
        "gpt-4-vision",
        "gpt-4-turbo",
        "gpt-4o",
        "gpt-4-1106-vision-preview",
    ]
    # Provider configurations
    PROVIDERS = {
        "open-router": {
            "base_url": "https://openrouter.ai/api/v1",
            "api_key": "OPEN_ROUTER_API_KEY",
            "model_transform": lambda m: m.replace("open-router-", ""),
        },
        "claude": {
            "base_url": "https://api.anthropic.com/v1",
            "api_key": "ANTHROPIC_API_KEY",
        },
        "deepseek": {
            "base_url": "https://api.deepseek.com",
            "api_key": "DEEPSEEK_API_KEY",
        },
        "gemini": {
            "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
            "api_key": "GEMINI_API_KEY",
        },
        "together": {
            "base_url": "https://api.together.xyz/v1",
            "api_key": "TOGETHER_API_KEY",
        },
        "openai": {
            "base_url": "https://api.openai.com/v1",
            "api_key": "OPENAI_API_KEY",
        },
    }

    def __init__(self, model: str, beam: int = 1):
        self.model = model
        self.beam = beam

    def _is_model_image_compatible(self, model: str) -> bool:
        """
        Check if the model supports image inputs, accounting for model version suffixes.

        Examples:
            'claude-3.5-sonnet-20241022' -> matches 'claude-3.5-sonnet'
            'gpt-4o-2024-05-13' -> matches 'gpt-4o'
        """
        # Normalize the model name to lowercase
        model_lower = model.lower()

        # First check for exact matches
        if model_lower in self.MODELS_WITH_IMAGE_SUPPORT:
            return True

        # Check for models with version number suffixes
        for supported_model in self.MODELS_WITH_IMAGE_SUPPORT:
            if supported_model in model:
                return True

        # Special handling for custom adaptations
        if "vision" in model_lower and any(
            gpt in model_lower for gpt in ["gpt-4", "gpt4"]
        ):
            return True

        return False
    def _get_provider_config(self, model: str) -> dict:
        """Get provider config based on model name"""
        for provider, config in self.PROVIDERS.items():
            if provider in model:
                return config
        raise ValueError(f"No provider found for model: {model}")

    @track_timing_async("llm_api_call")
    @retry(wait=wait_exponential(multiplier=2, min=2, max=15))
    async def acall(self, *args, **kwargs):
        max_tokens = kwargs.get("max_tokens", 2000)
    async def acall(self, **kwargs):
        model_to_use = kwargs.get("model", self.model)
        messages = kwargs.get("messages", [])
        # Get provider config
        provider_config = self._get_provider_config(model_to_use)
        # Apply model transform if specified
        if "model_transform" in provider_config:
            model_to_use = provider_config["model_transform"](model_to_use)
        # Prepare messages for text-only LLMs
        messages = remove_whitespace_blocks(messages)
        messages = merge_contiguous_messages(messages)
        # Create client
        client = AsyncOpenAI(
            base_url=provider_config["base_url"],
            api_key=os.getenv(provider_config["api_key"]),
            max_retries=0,
        )

        # Check for image content
        has_images = has_image_content(messages)
        # Standard API call for all providers
        response = await client.chat.completions.create(
            model=model_to_use,
            messages=messages,
            max_tokens=kwargs.get("max_tokens", 256),
            temperature=kwargs.get("temperature", 0.3),
            logit_bias=kwargs.get("logit_bias"),
            n=kwargs.get("n_samples"),
            stop=kwargs.get("stop_sequences"),
            presence_penalty=kwargs.get("presence_penalty"),
            frequency_penalty=kwargs.get("frequency_penalty"),
            stream=False,
        )

        # Validate image capability if images are present
        if has_images and not self._is_model_image_compatible(model_to_use):
            raise ValueError(
                f"Model {model_to_use} does not support image inputs, but images were provided."
            )

        if "open-router" in model_to_use:
        # Track reasoning tokens if available
        if hasattr(response, "usage") and hasattr(response.usage, "reasoning_tokens"):
            async with timing_tracker.track_async(
                "open_router_api_call", model=model_to_use, llm=True
                "reasoning", model=model_to_use, tokens=response.usage.reasoning_tokens
            ):
                client = NoRetryAsyncOpenAI(
                    base_url="https://openrouter.ai/api/v1",
                    api_key=os.getenv("OPEN_ROUTER_API_KEY"),
                )
                response = await client.chat.completions.create(
                    model=model_to_use.replace("open-router", "").strip("-"),
                    max_tokens=kwargs.get("max_tokens", 256),
                    temperature=kwargs.get("temperature", 0.3),
                    messages=kwargs.get("messages", None),
                    logit_bias=kwargs.get("logit_bias", None),
                    n=kwargs.get("n_samples", None),
                    stop=kwargs.get("stop_sequences", None),
                    stream=False,
                    presence_penalty=kwargs.get("presence_penalty", None),
                    frequency_penalty=kwargs.get("frequency_penalty", None),
                )
                return response
                pass

if "claude" in model_to_use:
|
||||
async with timing_tracker.track_async(
|
||||
"claude_api_call", model=model_to_use, llm=True
|
||||
):
|
||||
# Process system message
|
||||
system_message = ""
|
||||
if messages and messages[0]["role"] == "system":
|
||||
system_message = messages[0]["content"]
|
||||
if isinstance(system_message, list):
|
||||
# Extract just the text parts for system message
|
||||
system_text_parts = []
|
||||
for part in system_message:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
system_text_parts.append(part.get("text", ""))
|
||||
elif isinstance(part, str):
|
||||
system_text_parts.append(part)
|
||||
system_message = "\n".join(system_text_parts)
|
||||
system_message = system_message.strip()
|
||||
|
||||
# If the most recent message is from the assistant and ends with whitespace, clean it
|
||||
if messages and messages[-1]["role"] == "assistant":
|
||||
if isinstance(messages[-1]["content"], str):
|
||||
messages[-1]["content"] = messages[-1]["content"].strip()
|
||||
|
||||
# If the most recent message is from the assistant, add a user message to prompt the assistant
|
||||
if messages and messages[-1]["role"] == "assistant":
|
||||
messages.append({"role": "user", "content": "Success."})
|
||||
|
||||
if not has_images:
|
||||
# For text-only messages, use the standard processing
|
||||
messages = remove_whitespace_blocks(messages)
|
||||
messages = merge_contiguous_messages(messages)
|
||||
|
||||
# Format for Claude API
|
||||
anthropic_messages = []
|
||||
for msg in messages:
|
||||
if msg["role"] != "system": # System message handled separately
|
||||
anthropic_messages.append(
|
||||
{"role": msg["role"], "content": msg["content"]}
|
||||
)
|
||||
else:
|
||||
# For messages with images, use the special formatter
|
||||
anthropic_messages = format_messages_for_anthropic(
|
||||
messages, system_message
|
||||
)
|
||||
|
||||
if not system_message:
|
||||
raise RuntimeError("No system message!!")
|
||||
|
||||
try:
|
||||
client = anthropic.Anthropic(
|
||||
max_retries=0, api_key=os.getenv("ANTHROPIC_API_KEY")
|
||||
)
|
||||
# Use asyncio.to_thread for CPU-bound operations
|
||||
response = await asyncio.to_thread(
|
||||
client.messages.create,
|
||||
temperature=kwargs.get("temperature", 0.7),
|
||||
max_tokens=max_tokens,
|
||||
model=model_to_use,
|
||||
messages=anthropic_messages,
|
||||
system=system_message,
|
||||
stop_sequences=kwargs.get("stop_sequences", ["```END"]),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise
|
||||
|
||||
return response
|
||||
|
||||
elif "deepseek" in model_to_use:
|
||||
if has_images:
|
||||
raise ValueError(
|
||||
"Deepseek models do not support image inputs, but images were provided."
|
||||
)
|
||||
|
||||
async with timing_tracker.track_async(
|
||||
"deepseek_api_call", model=model_to_use, llm=True
|
||||
):
|
||||
client = NoRetryAsyncOpenAI(
|
||||
api_key=os.getenv("DEEPSEEK_API_KEY"),
|
||||
base_url="https://api.deepseek.com",
|
||||
)
|
||||
try:
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
max_tokens=kwargs.get("max_tokens", 256),
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
messages=kwargs.get("messages", None),
|
||||
logit_bias=kwargs.get("logit_bias", None),
|
||||
n=kwargs.get("n_samples", None),
|
||||
stop=kwargs.get("stop_sequences", None),
|
||||
stream=False,
|
||||
presence_penalty=kwargs.get("presence_penalty", None),
|
||||
frequency_penalty=kwargs.get("frequency_penalty", None),
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise
|
||||
|
||||
elif "gemini" in model_to_use:
|
||||
if has_images:
|
||||
raise ValueError(
|
||||
"Gemini integration doesn't support image inputs through this interface."
|
||||
)
|
||||
|
||||
async with timing_tracker.track_async(
|
||||
"gemini_api_call", model=model_to_use, llm=True
|
||||
):
|
||||
client = NoRetryAsyncOpenAI(
|
||||
api_key=os.getenv("GEMINI_API_KEY"),
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
)
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
max_tokens=kwargs.get("max_tokens", 256),
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
messages=kwargs.get("messages", None),
|
||||
n=kwargs.get("n_samples", None),
|
||||
stream=False,
|
||||
)
|
||||
return response
|
||||
|
||||
elif any(model in model_to_use for model in ["llama", "Qwen"]):
|
||||
if has_images:
|
||||
raise ValueError(
|
||||
"Llama and Qwen models do not support image inputs through this interface."
|
||||
)
|
||||
|
||||
async with timing_tracker.track_async(
|
||||
"together_api_call", model=model_to_use, llm=True
|
||||
):
|
||||
client = NoRetryAsyncOpenAI(
|
||||
api_key=os.getenv("TOGETHER_API_KEY"),
|
||||
base_url="https://api.together.xyz/v1",
|
||||
)
|
||||
return await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
max_tokens=kwargs.get("max_tokens", 256),
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
messages=kwargs.get("messages", None),
|
||||
logit_bias=kwargs.get("logit_bias", None),
|
||||
n=kwargs.get("n_samples", None),
|
||||
stop=kwargs.get("stop_sequences", None),
|
||||
stream=False,
|
||||
)
|
||||
|
||||
elif "o1-mini" in model_to_use or "o3-mini" in model_to_use:
|
||||
if has_images:
|
||||
raise ValueError(
|
||||
"Claude o1-mini and o3-mini models do not support image inputs."
|
||||
)
|
||||
|
||||
async with timing_tracker.track_async(
|
||||
"o1_mini_api_call", model=model_to_use, llm=True
|
||||
):
|
||||
client = NoRetryAsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
# replace `max_tokens` with `max_completion_tokens` for OpenAI API
|
||||
if "max_tokens" in kwargs:
|
||||
kwargs.pop("max_tokens")
|
||||
messages = kwargs.get("messages")
|
||||
messages[0]["role"] = "developer"
|
||||
try:
|
||||
reasoning_length = "low"
|
||||
if "med" in model_to_use:
|
||||
reasoning_length = "medium"
|
||||
elif "high" in model_to_use:
|
||||
reasoning_length = "high"
|
||||
model = kwargs.get("model", "o3-mini")
|
||||
if "o3-mini" in model:
|
||||
model = "o3-mini"
|
||||
elif "o1-mini" in model:
|
||||
model = "o1-mini"
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
*args,
|
||||
n=self.beam,
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=False,
|
||||
response_format={"type": "text"},
|
||||
reasoning_effort=reasoning_length,
|
||||
)
|
||||
|
||||
# Track reasoning metrics if available
|
||||
if hasattr(response, "usage") and hasattr(
|
||||
response.usage, "reasoning_tokens"
|
||||
):
|
||||
async with timing_tracker.track_async(
|
||||
"reasoning",
|
||||
model=model_to_use,
|
||||
tokens=response.usage.reasoning_tokens,
|
||||
reasoning_length=reasoning_length,
|
||||
):
|
||||
# This is just a marker for the timing - the actual reasoning happened in the API
|
||||
pass
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
print(e)
|
||||
else:
|
||||
async with timing_tracker.track_async(
|
||||
"openai_api_call", model=model_to_use, llm=True
|
||||
):
|
||||
try:
|
||||
client = NoRetryAsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
assert "messages" in kwargs, (
|
||||
"You must provide a list of messages to the model."
|
||||
)
|
||||
|
||||
if has_images:
|
||||
# Format messages for OpenAI with image support
|
||||
formatted_messages = format_messages_for_openai(messages)
|
||||
else:
|
||||
formatted_messages = messages
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
max_tokens=kwargs.get("max_tokens", 256),
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
messages=formatted_messages,
|
||||
logit_bias=kwargs.get("logit_bias", None),
|
||||
n=kwargs.get("n_samples", None),
|
||||
stop=kwargs.get("stop_sequences", None),
|
||||
stream=False,
|
||||
presence_penalty=kwargs.get("presence_penalty", None),
|
||||
frequency_penalty=kwargs.get("frequency_penalty", None),
|
||||
)
|
||||
|
||||
# Track reasoning metrics if available
|
||||
if hasattr(response, "usage") and hasattr(
|
||||
response.usage, "reasoning_tokens"
|
||||
):
|
||||
async with timing_tracker.track_async(
|
||||
"reasoning",
|
||||
model=model_to_use,
|
||||
tokens=response.usage.reasoning_tokens,
|
||||
):
|
||||
# This is just a marker for the timing - the actual reasoning happened in the API
|
||||
pass
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
print(e)
|
||||
try:
|
||||
client = NoRetryAsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
assert "messages" in kwargs, (
|
||||
"You must provide a list of messages to the model."
|
||||
)
|
||||
|
||||
# Attempt with truncated message history as fallback
|
||||
sys = kwargs.get("messages", None)[0]
|
||||
messages = [sys] + kwargs.get("messages", None)[8:]
|
||||
|
||||
if has_images:
|
||||
# Format messages for OpenAI with image support
|
||||
formatted_messages = format_messages_for_openai(messages)
|
||||
else:
|
||||
formatted_messages = messages
|
||||
|
||||
response = await client.chat.completions.create(
|
||||
model=model_to_use,
|
||||
max_tokens=kwargs.get("max_tokens", 256),
|
||||
temperature=kwargs.get("temperature", 0.3),
|
||||
messages=formatted_messages,
|
||||
logit_bias=kwargs.get("logit_bias", None),
|
||||
n=kwargs.get("n_samples", None),
|
||||
stop=kwargs.get("stop_sequences", None),
|
||||
stream=False,
|
||||
presence_penalty=kwargs.get("presence_penalty", None),
|
||||
frequency_penalty=kwargs.get("frequency_penalty", None),
|
||||
)
|
||||
|
||||
# Track reasoning metrics if available
|
||||
if hasattr(response, "usage") and hasattr(
|
||||
response.usage, "reasoning_tokens"
|
||||
):
|
||||
async with timing_tracker.track_async(
|
||||
"reasoning",
|
||||
model=model_to_use,
|
||||
tokens=response.usage.reasoning_tokens,
|
||||
):
|
||||
# This is just a marker for the timing - the actual reasoning happened in the API
|
||||
pass
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise
|
||||
|
||||
    def call(self, *args, **kwargs):
        # For the synchronous version, we should also implement image support,
        # but I'll leave this method unchanged as the focus is on the async version.
        # The same pattern would be applied here as in acall.
        max_tokens = kwargs.get("max_tokens", 1500)
        model_to_use = kwargs.get("model", self.model)

        messages = kwargs.get("messages", [])
        has_images = self._has_image_content(messages)

        # Validate image capability if images are present
        if has_images and not self._is_model_image_compatible(model_to_use):
            raise ValueError(
                f"Model {model_to_use} does not support image inputs, but images were provided."
            )

        if "claude" in model_to_use:
            # Process system message
            system_message = ""
            if messages and messages[0]["role"] == "system":
                system_message = messages[0]["content"]
                if isinstance(system_message, list):
                    # Extract just the text parts for system message
                    system_text_parts = []
                    for part in system_message:
                        if isinstance(part, dict) and part.get("type") == "text":
                            system_text_parts.append(part.get("text", ""))
                        elif isinstance(part, str):
                            system_text_parts.append(part)
                    system_message = "\n".join(system_text_parts)
                system_message = system_message.strip()

            # Remove final assistant content that ends with trailing whitespace
            if messages[-1]["role"] == "assistant":
                if isinstance(messages[-1]["content"], str):
                    messages[-1]["content"] = messages[-1]["content"].strip()

            # If the most recent message is from the assistant, add a user message to prompt the assistant
            if messages[-1]["role"] == "assistant":
                messages.append({"role": "user", "content": "Success."})

            if not has_images:
                # Standard text processing
                messages = self.remove_whitespace_blocks(messages)
                messages = self.merge_contiguous_messages(messages)

                # Format for Claude API
                anthropic_messages = []
                for msg in messages:
                    if msg["role"] != "system":  # System message handled separately
                        anthropic_messages.append(
                            {"role": msg["role"], "content": msg["content"]}
                        )
            else:
                # Format with image support
                anthropic_messages = self._format_messages_for_anthropic(
                    messages, system_message
                )

            try:
                client = anthropic.Anthropic()
                response = client.messages.create(
                    temperature=kwargs.get("temperature", 0.7),
                    max_tokens=max_tokens,
                    model=model_to_use,
                    messages=anthropic_messages,
                    system=system_message,
                    stop_sequences=kwargs.get("stop_sequences", None),
                )
            except Exception as e:
                print(e)
                raise

            return response

        elif "deepseek" in model_to_use:
            if has_images:
                raise ValueError(
                    "Deepseek models do not support image inputs, but images were provided."
                )

            client = OpenAI(
                api_key=os.getenv("DEEPSEEK_API_KEY"),
                base_url="https://api.deepseek.com",
            )
            response = client.chat.completions.create(
                *args,
                **kwargs,
                model=model_to_use,
                presence_penalty=kwargs.get("presence_penalty", None),
                frequency_penalty=kwargs.get("frequency_penalty", None),
                logit_bias=kwargs.get("logit_bias", None),
                n=kwargs.get("n_samples", None),
                stop=kwargs.get("stop_sequences", None),
                stream=False,
            )
            return response

        elif "o1-mini" in model_to_use:
            if has_images:
                raise ValueError("Claude o1-mini model does not support image inputs.")

            client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            # replace `max_tokens` with `max_completion_tokens` for OpenAI API
            if "max_tokens" in kwargs:
                kwargs.pop("max_tokens")

            return client.chat.completions.create(
                *args, n=self.beam, **kwargs, stream=False
            )
        else:
            client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
            assert "messages" in kwargs, (
                "You must provide a list of messages to the model."
            )

            if has_images:
                # Format messages for OpenAI with image support
                formatted_messages = self._format_messages_for_openai(messages)
            else:
                formatted_messages = messages

            return client.chat.completions.create(
                model=model_to_use,
                max_tokens=kwargs.get("max_tokens", 256),
                temperature=kwargs.get("temperature", 0.3),
                messages=formatted_messages,
                logit_bias=kwargs.get("logit_bias", None),
                n=kwargs.get("n_samples", None),
                stop=kwargs.get("stop_sequences", None),
                stream=False,
            )
            return response
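The "provider detection using dict ordering" from the commit message leans on the fact that Python dicts preserve insertion order (guaranteed since 3.7): _get_provider_config returns the first PROVIDERS key that is a substring of the model name. A standalone sketch of why the ordering matters (provider entries taken from the diff; the model string is illustrative):

# Insertion order doubles as match priority.
PROVIDERS = {
    "open-router": {"base_url": "https://openrouter.ai/api/v1"},
    "deepseek": {"base_url": "https://api.deepseek.com"},
}

def get_provider(model: str) -> str:
    for provider in PROVIDERS:
        if provider in model:
            return provider
    raise ValueError(f"No provider found for model: {model}")

# "deepseek" is also a substring of this model name, but "open-router"
# is inserted first, so the request is routed through OpenRouter and the
# "open-router-" prefix is later stripped by model_transform.
assert get_provider("open-router-deepseek-chat") == "open-router"
assert get_provider("deepseek-chat") == "deepseek"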
fle/env/gym_env/config.py
@@ -27,7 +27,6 @@ class GymEvalConfig:
    task: Optional[TaskABC] = None
    agent_cards: Optional[List[AgentCard]] = None
    env_id: Optional[str] = None  # Gym environment ID for registry-based creation
    instance_id: Optional[int] = None  # Which container to use for this evaluation

    def __post_init__(self):
        if self.task is None and hasattr(self.agents[0], "task"):
fle/env/gym_env/registry.py
@@ -12,6 +12,8 @@ from fle.env import FactorioInstance
from fle.env.gym_env.environment import FactorioGymEnv
from fle.eval.tasks import TaskFactory

PORT_OFFSET = int(os.environ["PORT_OFFSET"])


@dataclass
class GymEnvironmentSpec:
@@ -127,7 +129,7 @@ class FactorioGymRegistry:
_registry = FactorioGymRegistry()


def make_factorio_env(env_spec: GymEnvironmentSpec, instance_id: int) -> FactorioGymEnv:
def make_factorio_env(env_spec: GymEnvironmentSpec, run_idx: int) -> FactorioGymEnv:
    """Factory function to create a Factorio gym environment"""

    # Create task from the task definition
@@ -143,8 +145,16 @@ def make_factorio_env(env_spec: GymEnvironmentSpec, instance_id: int) -> Factori
    ips, udp_ports, tcp_ports = get_local_container_ips()
    if len(tcp_ports) == 0:
        raise RuntimeError("No Factorio containers available")
    address = ips[instance_id]
    tcp_port = tcp_ports[instance_id]

    # Apply port offset for multiple terminal sessions
    container_idx = PORT_OFFSET + run_idx
    if container_idx >= len(tcp_ports):
        raise RuntimeError(
            f"Container index {container_idx} (PORT_OFFSET={PORT_OFFSET} + run_idx={run_idx}) exceeds available containers ({len(tcp_ports)})"
        )

    address = ips[container_idx]
    tcp_port = tcp_ports[container_idx]

    common_kwargs = {
        "address": address,
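A worked example of the new index arithmetic, with an assumed cluster of eight containers: a second terminal session exports PORT_OFFSET=4 so its run indices land on the upper half of the cluster instead of colliding with the first session.

# Hypothetical 8-container cluster; only the list length matters here.
tcp_ports = [34197 + i for i in range(8)]

PORT_OFFSET = 4  # second shell
run_idx = 3
container_idx = PORT_OFFSET + run_idx  # 7 -> last container, still in range

if container_idx >= len(tcp_ports):
    raise RuntimeError(
        f"Container index {container_idx} exceeds available containers ({len(tcp_ports)})"
    )
print(tcp_ports[container_idx])  # shell A (offset 0) would use indices 0-3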
fle/env/gym_env/run_eval.py
@@ -49,7 +49,7 @@ async def run_trajectory(run_idx: int, config: GymEvalConfig):
    """Run a single gym evaluation process"""
    db_client = await create_db_client()

    gym_env = gym.make(config.env_id, instance_id=config.instance_id)
    gym_env = gym.make(config.env_id, run_idx=run_idx)

    log_dir = os.path.join(".fle", "trajectory_logs", f"v{config.version}")
    runner = GymTrajectoryRunner(
@@ -63,16 +63,16 @@ async def run_trajectory(run_idx: int, config: GymEvalConfig):
    await db_client.cleanup()


async def main(run_config, offset):
async def main(config_path):
    # Read and validate run configurations
    run_config = get_validated_run_configs(run_config)
    run_configs = get_validated_run_configs(config_path)
    # Get starting version number for new runs
    base_version = await get_next_version()
    version_offset = 0

    # Create and start processes
    processes = []
    for run_idx, run_config in enumerate(run_config):
    for run_idx, run_config in enumerate(run_configs):
        # Get environment info from registry
        env_info = get_environment_info(run_config.env_id)
        if env_info is None:
@@ -116,7 +116,6 @@ async def main(run_config, offset):
            task=task,
            agent_cards=agent_cards,
            env_id=run_config.env_id,
            instance_id=run_idx + offset,
        )
        # Ensure agent cards are properly set for a2a functionality
        assert config.agent_cards is not None
fle/run.py
@@ -45,7 +45,7 @@ def fle_cluster(args):
def fle_eval(args):
    try:
        config_path = str(Path(args.config))
        asyncio.run(run_eval(config_path, args.offset))
        asyncio.run(run_eval(config_path))
    except Exception as e:
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
@@ -58,7 +58,7 @@ def main():
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  fle eval --config configs/gym_run_config.json --offset n
  fle eval --config configs/gym_run_config.json
  fle cluster [start|stop|restart|help] [-n N] [-s SCENARIO]
""",
    )
@@ -80,12 +80,6 @@ Examples:
    )
    parser_eval = subparsers.add_parser("eval", help="Run experiment")
    parser_eval.add_argument("--config", required=True, help="Path to run config JSON")
    parser_eval.add_argument(
        "--offset",
        type=int,
        required=True,
        help="Offset to add to instance_id selection",
    )
    args = parser.parse_args()
    if args.command:
        fle_init()