Enhance MCP server tests to reflect updated tool count; improve model fetching with timeout handling in providers

This commit is contained in:
hlohaus
2025-11-02 08:01:20 +01:00
parent 006b8c8d50
commit af56ac0c03
6 changed files with 19 additions and 16 deletions

View File

@@ -22,7 +22,7 @@ class TestMCPServer(unittest.IsolatedAsyncioTestCase):
server = MCPServer()
self.assertIsNotNone(server)
self.assertEqual(server.server_info["name"], "gpt4free-mcp-server")
self.assertEqual(len(server.tools), 3)
self.assertEqual(len(server.tools), 5)
self.assertIn('web_search', server.tools)
self.assertIn('web_scrape', server.tools)
self.assertIn('image_generation', server.tools)
@@ -57,7 +57,7 @@ class TestMCPServer(unittest.IsolatedAsyncioTestCase):
self.assertEqual(response.id, 2)
self.assertIsNotNone(response.result)
self.assertIn("tools", response.result)
self.assertEqual(len(response.result["tools"]), 3)
self.assertEqual(len(response.result["tools"]), 5)
# Check tool structure
tool_names = [tool["name"] for tool in response.result["tools"]]

View File

@@ -1,6 +1,6 @@
import unittest
from typing import Type
import asyncio
from requests.exceptions import RequestException
from g4f.models import __models__
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
@@ -15,12 +15,15 @@ class TestProviderHasModel(unittest.TestCase):
if provider.needs_auth:
continue
if issubclass(provider, ProviderModelMixin):
provider.get_models() # Update models
if model.name in provider.model_aliases:
model_name = provider.model_aliases[model.name]
else:
model_name = model.get_long_name()
self.provider_has_model(provider, model_name)
try:
provider.get_models(timeout=5) # Update models
if model.name in provider.model_aliases:
model_name = provider.model_aliases[model.name]
else:
model_name = model.get_long_name()
self.provider_has_model(provider, model_name)
except RequestException:
continue
def provider_has_model(self, provider: Type[BaseProvider], model: str):
if provider.__name__ not in self.cache:

View File

@@ -502,7 +502,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
_models_loaded = False
@classmethod
def get_models(cls) -> list[str]:
def get_models(cls, timeout: int = None) -> list[str]:
if not cls._models_loaded and has_curl_cffi:
cache_file = cls.get_cache_file()
args = {}
@@ -516,7 +516,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
args = {}
if not args:
return cls.models
response = curl_cffi.get(f"{cls.url}/?mode=direct", **args)
response = curl_cffi.get(f"{cls.url}/?mode=direct", **args, timeout=timeout)
if response.ok:
for line in response.text.splitlines():
if "initialModels" in line:

View File

@@ -31,7 +31,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
max_tokens: int = None
@classmethod
def get_models(cls, api_key: str = None, api_base: str = None) -> list[str]:
def get_models(cls, api_key: str = None, api_base: str = None, timeout: int = None) -> list[str]:
if not cls.models:
try:
if api_base is None:
@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
api_key = AuthManager.load_api_key(cls)
if cls.models_needs_auth and not api_key:
raise MissingAuthError('Add a "api_key"')
response = requests.get(f"{api_base}/models", headers=cls.get_headers(False, api_key), verify=cls.ssl)
response = requests.get(f"{api_base}/models", headers=cls.get_headers(False, api_key), verify=cls.ssl, timeout=timeout)
raise_for_status(response)
data = response.json()
data = data.get("data", data.get("models")) if isinstance(data, dict) else data

View File

@@ -18,8 +18,6 @@ from dataclasses import dataclass
from ..debug import enable_logging
enable_logging()
from .tools import MarkItDownTool, TextToAudioTool, WebSearchTool, WebScrapeTool, ImageGenerationTool
from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
@@ -214,6 +212,8 @@ class MCPServer:
sys.stderr.write("Error: aiohttp is required for HTTP transport\n")
sys.stderr.write("Install it with: pip install aiohttp\n")
sys.exit(1)
enable_logging()
async def handle_mcp_request(request: web.Request) -> web.Response:
nonlocal origin

View File

@@ -213,7 +213,7 @@ gpt_4o = VisionModel(
gpt_4o_mini = Model(
name = 'gpt-4o-mini',
base_provider = 'OpenAI',
best_provider = IterListProvider([Chatai, OIVSCodeSer2, Startnest, OpenaiChat, OIVSCodeSer0501])
best_provider = IterListProvider([Chatai, OIVSCodeSer2, Startnest, OpenaiChat])
)
gpt_4o_mini_audio = AudioModel(