Merge branch 'main' into cursor/check-agent-urls-in-docs-claude-4.5-opus-high-8510
@@ -22,6 +22,7 @@ from browser_use.agent.cloud_events import (
)
from browser_use.agent.message_manager.utils import save_conversation
from browser_use.llm.base import BaseChatModel
from browser_use.llm.exceptions import ModelProviderError, ModelRateLimitError
from browser_use.llm.messages import BaseMessage, ContentPartImageParam, ContentPartTextParam, UserMessage
from browser_use.tokens.service import TokenCost
@@ -166,6 +167,7 @@ class Agent(Generic[Context, AgentStructuredOutput]):
        demo_mode: bool | None = None,
        max_history_items: int | None = None,
        page_extraction_llm: BaseChatModel | None = None,
        fallback_llm: BaseChatModel | None = None,
        use_judge: bool = True,
        ground_truth: str | None = None,
        judge_llm: BaseChatModel | None = None,
@@ -283,6 +285,11 @@ class Agent(Generic[Context, AgentStructuredOutput]):
        self.task = self._enhance_task_with_schema(task, output_model_schema)
        self.llm = llm
        self.judge_llm = judge_llm

        # Fallback LLM configuration
        self._fallback_llm: BaseChatModel | None = fallback_llm
        self._using_fallback_llm: bool = False
        self._original_llm: BaseChatModel = llm  # Store original for reference
        self.directly_open_url = directly_open_url
        self.include_recent_events = include_recent_events
        self._url_shortening_limit = _url_shortening_limit
@@ -539,6 +546,16 @@ class Agent(Generic[Context, AgentStructuredOutput]):
        assert self.browser_session is not None, 'BrowserSession is not set up'
        return self.browser_session.browser_profile

    @property
    def is_using_fallback_llm(self) -> bool:
        """Check if the agent is currently using the fallback LLM."""
        return self._using_fallback_llm

    @property
    def current_llm_model(self) -> str:
        """Get the model name of the currently active LLM."""
        return self.llm.model if hasattr(self.llm, 'model') else 'unknown'

    async def _check_and_update_downloads(self, context: str = '') -> None:
        """Check for new downloads and update available file paths."""
        if not self.has_downloads_path:
@@ -1357,6 +1374,68 @@ class Agent(Generic[Context, AgentStructuredOutput]):
        except ValidationError:
            # Just re-raise - Pydantic's validation errors are already descriptive
            raise
        except (ModelRateLimitError, ModelProviderError) as e:
            # Check if we can switch to a fallback LLM
            if not self._try_switch_to_fallback_llm(e):
                # No fallback available, re-raise the original error
                raise
            # Retry with the fallback LLM
            return await self.get_model_output(input_messages)

    def _try_switch_to_fallback_llm(self, error: ModelRateLimitError | ModelProviderError) -> bool:
        """
        Attempt to switch to a fallback LLM after a rate limit or provider error.

        Returns True if successfully switched to a fallback, False if no fallback available.
        Once switched, the agent will use the fallback LLM for the rest of the run.
        """
        # Already using fallback - can't switch again
        if self._using_fallback_llm:
            self.logger.warning(
                f'⚠️ Fallback LLM also failed ({type(error).__name__}: {error.message}), no more fallbacks available'
            )
            return False

        # Check if error is retryable (rate limit, auth errors, or server errors)
        # 401: API key invalid/expired - fallback to different provider
        # 402: Insufficient credits/payment required - fallback to different provider
        # 429: Rate limit exceeded
        # 500, 502, 503, 504: Server errors
        retryable_status_codes = {401, 402, 429, 500, 502, 503, 504}
        is_retryable = isinstance(error, ModelRateLimitError) or (
            hasattr(error, 'status_code') and error.status_code in retryable_status_codes
        )

        if not is_retryable:
            return False

        # Check if we have a fallback LLM configured
        if self._fallback_llm is None:
            self.logger.warning(f'⚠️ LLM error ({type(error).__name__}: {error.message}) but no fallback_llm configured')
            return False

        self._log_fallback_switch(error, self._fallback_llm)

        # Switch to the fallback LLM
        self.llm = self._fallback_llm
        self._using_fallback_llm = True

        # Register the fallback LLM for token cost tracking
        self.token_cost_service.register_llm(self._fallback_llm)

        return True

    def _log_fallback_switch(self, error: ModelRateLimitError | ModelProviderError, fallback: BaseChatModel) -> None:
        """Log when switching to a fallback LLM."""
        original_model = self._original_llm.model if hasattr(self._original_llm, 'model') else 'unknown'
        fallback_model = fallback.model if hasattr(fallback, 'model') else 'unknown'
        error_type = type(error).__name__
        status_code = getattr(error, 'status_code', 'N/A')

        self.logger.warning(
            f'⚠️ Primary LLM ({original_model}) failed with {error_type} (status={status_code}), '
            f'switching to fallback LLM ({fallback_model})'
        )

    async def _log_agent_run(self) -> None:
        """Log the agent run"""
@@ -15,6 +15,7 @@ import httpx
from pydantic import BaseModel

from browser_use.llm.base import BaseChatModel
from browser_use.llm.exceptions import ModelProviderError, ModelRateLimitError
from browser_use.llm.messages import BaseMessage
from browser_use.llm.views import ChatInvokeCompletion
from browser_use.observability import observe
@@ -240,7 +241,7 @@ class ChatBrowserUse(BaseChatModel):
        return response.json()

    def _raise_http_error(self, e: httpx.HTTPStatusError) -> None:
-        """Raise a ValueError with appropriate error message for HTTP errors."""
+        """Raise appropriate ModelProviderError for HTTP errors."""
        error_detail = ''
        try:
            error_data = e.response.json()
@@ -248,12 +249,18 @@ class ChatBrowserUse(BaseChatModel):
        except Exception:
            error_detail = str(e)

-        if e.response.status_code == 401:
-            raise ValueError(f'Invalid API key. {error_detail}')
-        elif e.response.status_code == 402:
-            raise ValueError(f'Insufficient credits. {error_detail}')
+        status_code = e.response.status_code
+
+        if status_code == 401:
+            raise ModelProviderError(message=f'Invalid API key. {error_detail}', status_code=401, model=self.name)
+        elif status_code == 402:
+            raise ModelProviderError(message=f'Insufficient credits. {error_detail}', status_code=402, model=self.name)
+        elif status_code == 429:
+            raise ModelRateLimitError(message=f'Rate limit exceeded. {error_detail}', status_code=429, model=self.name)
+        elif status_code in {500, 502, 503, 504}:
+            raise ModelProviderError(message=f'Server error. {error_detail}', status_code=status_code, model=self.name)
        else:
-            raise ValueError(f'API request failed: {error_detail}')
+            raise ModelProviderError(message=f'API request failed: {error_detail}', status_code=status_code, model=self.name)

    def _serialize_message(self, message: BaseMessage) -> dict:
        """Serialize a message to JSON format."""
@@ -17,6 +17,9 @@ mode: "wide"
- `vision_detail_level` (default: `'auto'`): Screenshot detail level - `'low'`, `'high'`, or `'auto'`
- `page_extraction_llm`: Separate LLM model for page content extraction. You can choose a small & fast model because it only needs to extract text from the page (default: same as `llm`)

### Fallback & Resilience
- `fallback_llm`: Backup LLM to use when the primary LLM fails. The primary LLM will first exhaust its own retry logic (typically 5 attempts with exponential backoff), and only then switch to the fallback. Triggers on rate limits (429), authentication errors (401), payment/credit errors (402), or server errors (500, 502, 503, 504). Once switched, the fallback is used for the rest of the run. [Example](https://github.com/browser-use/browser-use/blob/main/examples/features/fallback_model.py)
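
  A condensed sketch of the linked example (the provider classes and model names come from that example and are illustrative; any two chat models work):

  ```python
  from browser_use import Agent
  from browser_use.llm import ChatAnthropic, ChatOpenAI

  agent = Agent(
      task='Find the browser-use repository on GitHub',  # illustrative task
      llm=ChatAnthropic(model='claude-sonnet-4-0'),      # primary model
      fallback_llm=ChatOpenAI(model='gpt-4o'),           # used only after the primary's retries are exhausted
  )

  # After agent.run(), agent.is_using_fallback_llm and agent.current_llm_model
  # report whether the switch happened and which model is currently active.
  ```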

### Actions & Behavior
- `initial_actions`: List of actions to run before the main task without LLM. [Example](https://github.com/browser-use/browser-use/blob/main/examples/features/initial_actions.py)
- `max_actions_per_step` (default: `4`): Maximum actions per step, e.g. for form filling the agent can output 4 fields at once. We execute the actions until the page changes.
examples/features/fallback_model.py (new file, 55 lines)
@@ -0,0 +1,55 @@
"""
|
||||
Example: Using a fallback LLM model.
|
||||
|
||||
When the primary LLM fails with rate limits (429), authentication errors (401),
|
||||
payment/credit errors (402), or server errors (500, 502, 503, 504), the agent
|
||||
automatically switches to the fallback model and continues execution.
|
||||
|
||||
Note: The primary LLM will first exhaust its own retry logic (typically 5 attempts
|
||||
with exponential backoff) before the fallback is triggered. This means transient errors
|
||||
are handled by the provider's built-in retries, and the fallback only kicks in when
|
||||
the provider truly can't recover.
|
||||
|
||||
This is useful for:
|
||||
- High availability: Keep your agent running even when one provider has issues
|
||||
- Cost optimization: Use a cheaper model as fallback when the primary is rate limited
|
||||
- Multi-provider resilience: Switch between OpenAI, Anthropic, Google, etc.
|
||||
|
||||
@dev You need to add OPENAI_API_KEY and ANTHROPIC_API_KEY to your environment variables.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
from browser_use import Agent
|
||||
from browser_use.llm import ChatAnthropic, ChatOpenAI
|
||||
|
||||
llm = ChatAnthropic(model='claude-sonnet-4-0')
|
||||
fallback_llm = ChatOpenAI(model='gpt-4o')
|
||||
|
||||
agent = Agent(
|
||||
task='Go to github.com and find the browser-use repository',
|
||||
llm=llm,
|
||||
fallback_llm=fallback_llm,
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
result = await agent.run()
|
||||
print(result)
|
||||
|
||||
# You can check if fallback was used:
|
||||
if agent.is_using_fallback_llm:
|
||||
print('Note: Agent switched to fallback LLM during execution')
|
||||
print(f'Current model: {agent.current_llm_model}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
tests/ci/test_fallback_llm.py (new file, 439 lines)
@@ -0,0 +1,439 @@
"""
|
||||
Tests for the fallback_llm feature in Agent.
|
||||
|
||||
Tests verify that when the primary LLM fails with rate limit (429) or server errors (503, 502, 500, 504),
|
||||
the agent automatically switches to the fallback LLM and continues execution.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from browser_use.agent.views import AgentOutput
|
||||
from browser_use.llm import BaseChatModel
|
||||
from browser_use.llm.exceptions import ModelProviderError, ModelRateLimitError
|
||||
from browser_use.llm.views import ChatInvokeCompletion
|
||||
from browser_use.tools.service import Tools
|
||||
|
||||
|
||||
def create_mock_llm(
    model_name: str = 'mock-llm',
    should_fail: bool = False,
    fail_with: type[Exception] | None = None,
    fail_status_code: int = 429,
    fail_message: str = 'Rate limit exceeded',
) -> BaseChatModel:
    """Create a mock LLM for testing.

    Args:
        model_name: Name of the mock model
        should_fail: If True, the LLM will raise an exception
        fail_with: Exception type to raise (ModelRateLimitError or ModelProviderError)
        fail_status_code: HTTP status code for the error
        fail_message: Error message
    """
    tools = Tools()
    ActionModel = tools.registry.create_action_model()
    AgentOutputWithActions = AgentOutput.type_with_custom_actions(ActionModel)

    llm = AsyncMock(spec=BaseChatModel)
    llm.model = model_name
    llm._verified_api_keys = True
    llm.provider = 'mock'
    llm.name = model_name
    llm.model_name = model_name

    default_done_action = """
    {
        "thinking": "null",
        "evaluation_previous_goal": "Successfully completed the task",
        "memory": "Task completed",
        "next_goal": "Task completed",
        "action": [
            {
                "done": {
                    "text": "Task completed successfully",
                    "success": true
                }
            }
        ]
    }
    """

    async def mock_ainvoke(*args, **kwargs):
        if should_fail:
            if fail_with == ModelRateLimitError:
                raise ModelRateLimitError(message=fail_message, status_code=fail_status_code, model=model_name)
            elif fail_with == ModelProviderError:
                raise ModelProviderError(message=fail_message, status_code=fail_status_code, model=model_name)
            else:
                raise Exception(fail_message)

        output_format = kwargs.get('output_format')
        if output_format is None:
            return ChatInvokeCompletion(completion=default_done_action, usage=None)
        else:
            parsed = output_format.model_validate_json(default_done_action)
            return ChatInvokeCompletion(completion=parsed, usage=None)

    llm.ainvoke.side_effect = mock_ainvoke

    return llm

class TestFallbackLLMParameter:
    """Test fallback_llm parameter initialization."""

    def test_fallback_llm_none_by_default(self):
        """Verify fallback_llm defaults to None."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        agent = Agent(task='Test task', llm=primary)

        assert agent._fallback_llm is None
        assert agent._using_fallback_llm is False
        assert agent._original_llm is primary

    def test_fallback_llm_single_model(self):
        """Test passing a fallback LLM."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        assert agent._fallback_llm is fallback
        assert agent._using_fallback_llm is False

    def test_public_properties(self):
        """Test the public properties for fallback status."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        # Before fallback
        assert agent.is_using_fallback_llm is False
        assert agent.current_llm_model == 'primary-model'

        # Trigger fallback
        error = ModelRateLimitError(message='Rate limit', status_code=429, model='primary')
        agent._try_switch_to_fallback_llm(error)

        # After fallback
        assert agent.is_using_fallback_llm is True
        assert agent.current_llm_model == 'fallback-model'

class TestFallbackLLMSwitching:
    """Test the fallback switching logic in _try_switch_to_fallback_llm."""

    def test_switch_on_rate_limit_error(self):
        """Test that agent switches to fallback on ModelRateLimitError."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        error = ModelRateLimitError(message='Rate limit exceeded', status_code=429, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is True
        assert agent.llm is fallback
        assert agent._using_fallback_llm is True

    def test_switch_on_503_error(self):
        """Test that agent switches to fallback on 503 Service Unavailable."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        error = ModelProviderError(message='Service unavailable', status_code=503, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is True
        assert agent.llm is fallback
        assert agent._using_fallback_llm is True

    def test_switch_on_500_error(self):
        """Test that agent switches to fallback on 500 Internal Server Error."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        error = ModelProviderError(message='Internal server error', status_code=500, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is True
        assert agent.llm is fallback

    def test_switch_on_502_error(self):
        """Test that agent switches to fallback on 502 Bad Gateway."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        error = ModelProviderError(message='Bad gateway', status_code=502, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is True
        assert agent.llm is fallback

    def test_no_switch_on_400_error(self):
        """Test that agent does NOT switch on 400 Bad Request (not retryable)."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        error = ModelProviderError(message='Bad request', status_code=400, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is False
        assert agent.llm is primary  # Still using primary
        assert agent._using_fallback_llm is False

    def test_switch_on_401_error(self):
        """Test that agent switches to fallback on 401 Unauthorized (API key error)."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        error = ModelProviderError(message='Invalid API key', status_code=401, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is True
        assert agent.llm is fallback
        assert agent._using_fallback_llm is True

    def test_switch_on_402_error(self):
        """Test that agent switches to fallback on 402 Payment Required (insufficient credits)."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        error = ModelProviderError(message='Insufficient credits', status_code=402, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is True
        assert agent.llm is fallback
        assert agent._using_fallback_llm is True

    def test_no_switch_when_no_fallback_configured(self):
        """Test that agent returns False when no fallback is configured."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        agent = Agent(task='Test task', llm=primary)

        error = ModelRateLimitError(message='Rate limit exceeded', status_code=429, model='primary-model')
        result = agent._try_switch_to_fallback_llm(error)

        assert result is False
        assert agent.llm is primary

    def test_no_switch_when_already_using_fallback(self):
        """Test that agent doesn't switch again when already using fallback."""
        from browser_use import Agent

        primary = create_mock_llm('primary-model')
        fallback = create_mock_llm('fallback-model')

        agent = Agent(task='Test task', llm=primary, fallback_llm=fallback)

        # First switch succeeds
        error = ModelRateLimitError(message='Rate limit', status_code=429, model='primary')
        result = agent._try_switch_to_fallback_llm(error)
        assert result is True
        assert agent.llm is fallback

        # Second switch fails - already using fallback
        result = agent._try_switch_to_fallback_llm(error)
        assert result is False
        assert agent.llm is fallback  # Still on fallback

class TestFallbackLLMIntegration:
    """Integration tests for fallback LLM behavior in get_model_output."""

    def _create_failing_mock_llm(
        self,
        model_name: str,
        fail_with: type[Exception],
        fail_status_code: int = 429,
        fail_message: str = 'Rate limit exceeded',
    ) -> BaseChatModel:
        """Create a mock LLM that always fails with the specified error."""
        llm = AsyncMock(spec=BaseChatModel)
        llm.model = model_name
        llm._verified_api_keys = True
        llm.provider = 'mock'
        llm.name = model_name
        llm.model_name = model_name

        async def mock_ainvoke(*args, **kwargs):
            if fail_with == ModelRateLimitError:
                raise ModelRateLimitError(message=fail_message, status_code=fail_status_code, model=model_name)
            elif fail_with == ModelProviderError:
                raise ModelProviderError(message=fail_message, status_code=fail_status_code, model=model_name)
            else:
                raise Exception(fail_message)

        llm.ainvoke.side_effect = mock_ainvoke
        return llm

    def _create_succeeding_mock_llm(self, model_name: str, agent) -> BaseChatModel:
        """Create a mock LLM that succeeds and returns a valid AgentOutput."""
        llm = AsyncMock(spec=BaseChatModel)
        llm.model = model_name
        llm._verified_api_keys = True
        llm.provider = 'mock'
        llm.name = model_name
        llm.model_name = model_name

        default_done_action = """
        {
            "thinking": "null",
            "evaluation_previous_goal": "Successfully completed the task",
            "memory": "Task completed",
            "next_goal": "Task completed",
            "action": [
                {
                    "done": {
                        "text": "Task completed successfully",
                        "success": true
                    }
                }
            ]
        }
        """

        # Capture the agent reference for use in the closure
        captured_agent = agent

        async def mock_ainvoke(*args, **kwargs):
            # Get the output format from kwargs and use it to parse
            output_format = kwargs.get('output_format')
            if output_format is not None:
                parsed = output_format.model_validate_json(default_done_action)
                return ChatInvokeCompletion(completion=parsed, usage=None)
            # Fallback: use the agent's AgentOutput type
            parsed = captured_agent.AgentOutput.model_validate_json(default_done_action)
            return ChatInvokeCompletion(completion=parsed, usage=None)

        llm.ainvoke.side_effect = mock_ainvoke
        return llm

    @pytest.mark.asyncio
    async def test_get_model_output_switches_to_fallback_on_rate_limit(self, browser_session):
        """Test that get_model_output automatically switches to fallback on rate limit."""
        from browser_use import Agent

        # Create agent first with a working mock LLM
        placeholder = create_mock_llm('placeholder')
        agent = Agent(task='Test task', llm=placeholder, browser_session=browser_session)

        # Create a failing primary and succeeding fallback
        primary = self._create_failing_mock_llm(
            'primary-model',
            fail_with=ModelRateLimitError,
            fail_status_code=429,
            fail_message='Rate limit exceeded',
        )
        fallback = self._create_succeeding_mock_llm('fallback-model', agent)

        # Replace the LLM and set up fallback
        agent.llm = primary
        agent._original_llm = primary
        agent._fallback_llm = fallback

        from browser_use.llm.messages import BaseMessage, UserMessage

        messages: list[BaseMessage] = [UserMessage(content='Test message')]

        # This should switch to fallback and succeed
        result = await agent.get_model_output(messages)

        assert result is not None
        assert agent.llm is fallback
        assert agent._using_fallback_llm is True

    @pytest.mark.asyncio
    async def test_get_model_output_raises_when_no_fallback(self, browser_session):
        """Test that get_model_output raises error when no fallback is configured."""
        from browser_use import Agent

        # Create agent first with a working mock LLM
        placeholder = create_mock_llm('placeholder')
        agent = Agent(task='Test task', llm=placeholder, browser_session=browser_session)

        # Replace with failing LLM
        primary = self._create_failing_mock_llm(
            'primary-model',
            fail_with=ModelRateLimitError,
            fail_status_code=429,
            fail_message='Rate limit exceeded',
        )
        agent.llm = primary
        agent._original_llm = primary
        agent._fallback_llm = None  # No fallback

        from browser_use.llm.messages import BaseMessage, UserMessage

        messages: list[BaseMessage] = [UserMessage(content='Test message')]

        # This should raise since no fallback is configured
        with pytest.raises(ModelRateLimitError):
            await agent.get_model_output(messages)

    @pytest.mark.asyncio
    async def test_get_model_output_raises_when_fallback_also_fails(self, browser_session):
        """Test that error is raised when fallback also fails."""
        from browser_use import Agent

        # Create agent first with a working mock LLM
        placeholder = create_mock_llm('placeholder')
        agent = Agent(task='Test task', llm=placeholder, browser_session=browser_session)

        # Both models fail
        primary = self._create_failing_mock_llm('primary', fail_with=ModelRateLimitError, fail_status_code=429)
        fallback = self._create_failing_mock_llm('fallback', fail_with=ModelProviderError, fail_status_code=503)

        agent.llm = primary
        agent._original_llm = primary
        agent._fallback_llm = fallback

        from browser_use.llm.messages import BaseMessage, UserMessage

        messages: list[BaseMessage] = [UserMessage(content='Test message')]

        # Should fail after fallback also fails
        with pytest.raises((ModelRateLimitError, ModelProviderError)):
            await agent.get_model_output(messages)

if __name__ == '__main__':
    pytest.main([__file__, '-v'])
@@ -85,6 +85,7 @@ class TestChatBrowserUseRetries:
    async def test_no_retry_on_401(self, mock_env):
        """Test that 401 errors do NOT trigger retries."""
        from browser_use.llm.browser_use.chat import ChatBrowserUse
        from browser_use.llm.exceptions import ModelProviderError
        from browser_use.llm.messages import UserMessage

        attempt_count = 0
@@ -106,7 +107,7 @@ class TestChatBrowserUseRetries:

        client = ChatBrowserUse(retry_base_delay=0.01)

-        with pytest.raises(ValueError, match='Invalid API key'):
+        with pytest.raises(ModelProviderError, match='Invalid API key'):
            await client.ainvoke([UserMessage(content='test')])

        # Should only attempt once (no retries for 401)
@@ -148,6 +149,7 @@ class TestChatBrowserUseRetries:
    async def test_max_retries_exhausted(self, mock_env):
        """Test that error is raised after max retries exhausted."""
        from browser_use.llm.browser_use.chat import ChatBrowserUse
        from browser_use.llm.exceptions import ModelProviderError
        from browser_use.llm.messages import UserMessage

        attempt_count = 0
@@ -169,7 +171,7 @@ class TestChatBrowserUseRetries:

        client = ChatBrowserUse(max_retries=3, retry_base_delay=0.01)

-        with pytest.raises(ValueError, match='API request failed'):
+        with pytest.raises(ModelProviderError, match='Server error'):
            await client.ainvoke([UserMessage(content='test')])

        # Should have attempted max_retries times