mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-03 13:34:36 +00:00
Enhance MCP server tests to reflect updated tool count; improve model fetching with timeout handling in providers
This commit is contained in:
@@ -22,7 +22,7 @@ class TestMCPServer(unittest.IsolatedAsyncioTestCase):
|
||||
server = MCPServer()
|
||||
self.assertIsNotNone(server)
|
||||
self.assertEqual(server.server_info["name"], "gpt4free-mcp-server")
|
||||
self.assertEqual(len(server.tools), 3)
|
||||
self.assertEqual(len(server.tools), 5)
|
||||
self.assertIn('web_search', server.tools)
|
||||
self.assertIn('web_scrape', server.tools)
|
||||
self.assertIn('image_generation', server.tools)
|
||||
@@ -57,7 +57,7 @@ class TestMCPServer(unittest.IsolatedAsyncioTestCase):
|
||||
self.assertEqual(response.id, 2)
|
||||
self.assertIsNotNone(response.result)
|
||||
self.assertIn("tools", response.result)
|
||||
self.assertEqual(len(response.result["tools"]), 3)
|
||||
self.assertEqual(len(response.result["tools"]), 5)
|
||||
|
||||
# Check tool structure
|
||||
tool_names = [tool["name"] for tool in response.result["tools"]]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import unittest
|
||||
from typing import Type
|
||||
import asyncio
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from g4f.models import __models__
|
||||
from g4f.providers.base_provider import BaseProvider, ProviderModelMixin
|
||||
@@ -15,12 +15,15 @@ class TestProviderHasModel(unittest.TestCase):
|
||||
if provider.needs_auth:
|
||||
continue
|
||||
if issubclass(provider, ProviderModelMixin):
|
||||
provider.get_models() # Update models
|
||||
if model.name in provider.model_aliases:
|
||||
model_name = provider.model_aliases[model.name]
|
||||
else:
|
||||
model_name = model.get_long_name()
|
||||
self.provider_has_model(provider, model_name)
|
||||
try:
|
||||
provider.get_models(timeout=5) # Update models
|
||||
if model.name in provider.model_aliases:
|
||||
model_name = provider.model_aliases[model.name]
|
||||
else:
|
||||
model_name = model.get_long_name()
|
||||
self.provider_has_model(provider, model_name)
|
||||
except RequestException:
|
||||
continue
|
||||
|
||||
def provider_has_model(self, provider: Type[BaseProvider], model: str):
|
||||
if provider.__name__ not in self.cache:
|
||||
|
||||
@@ -502,7 +502,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
||||
_models_loaded = False
|
||||
|
||||
@classmethod
|
||||
def get_models(cls) -> list[str]:
|
||||
def get_models(cls, timeout: int = None) -> list[str]:
|
||||
if not cls._models_loaded and has_curl_cffi:
|
||||
cache_file = cls.get_cache_file()
|
||||
args = {}
|
||||
@@ -516,7 +516,7 @@ class LMArena(AsyncGeneratorProvider, ProviderModelMixin, AuthFileMixin):
|
||||
args = {}
|
||||
if not args:
|
||||
return cls.models
|
||||
response = curl_cffi.get(f"{cls.url}/?mode=direct", **args)
|
||||
response = curl_cffi.get(f"{cls.url}/?mode=direct", **args, timeout=timeout)
|
||||
if response.ok:
|
||||
for line in response.text.splitlines():
|
||||
if "initialModels" in line:
|
||||
|
||||
@@ -31,7 +31,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
max_tokens: int = None
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, api_key: str = None, api_base: str = None) -> list[str]:
|
||||
def get_models(cls, api_key: str = None, api_base: str = None, timeout: int = None) -> list[str]:
|
||||
if not cls.models:
|
||||
try:
|
||||
if api_base is None:
|
||||
@@ -42,7 +42,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
api_key = AuthManager.load_api_key(cls)
|
||||
if cls.models_needs_auth and not api_key:
|
||||
raise MissingAuthError('Add a "api_key"')
|
||||
response = requests.get(f"{api_base}/models", headers=cls.get_headers(False, api_key), verify=cls.ssl)
|
||||
response = requests.get(f"{api_base}/models", headers=cls.get_headers(False, api_key), verify=cls.ssl, timeout=timeout)
|
||||
raise_for_status(response)
|
||||
data = response.json()
|
||||
data = data.get("data", data.get("models")) if isinstance(data, dict) else data
|
||||
|
||||
@@ -18,8 +18,6 @@ from dataclasses import dataclass
|
||||
|
||||
from ..debug import enable_logging
|
||||
|
||||
enable_logging()
|
||||
|
||||
from .tools import MarkItDownTool, TextToAudioTool, WebSearchTool, WebScrapeTool, ImageGenerationTool
|
||||
from .tools import WebSearchTool, WebScrapeTool, ImageGenerationTool
|
||||
|
||||
@@ -214,6 +212,8 @@ class MCPServer:
|
||||
sys.stderr.write("Error: aiohttp is required for HTTP transport\n")
|
||||
sys.stderr.write("Install it with: pip install aiohttp\n")
|
||||
sys.exit(1)
|
||||
|
||||
enable_logging()
|
||||
|
||||
async def handle_mcp_request(request: web.Request) -> web.Response:
|
||||
nonlocal origin
|
||||
|
||||
@@ -213,7 +213,7 @@ gpt_4o = VisionModel(
|
||||
gpt_4o_mini = Model(
|
||||
name = 'gpt-4o-mini',
|
||||
base_provider = 'OpenAI',
|
||||
best_provider = IterListProvider([Chatai, OIVSCodeSer2, Startnest, OpenaiChat, OIVSCodeSer0501])
|
||||
best_provider = IterListProvider([Chatai, OIVSCodeSer2, Startnest, OpenaiChat])
|
||||
)
|
||||
|
||||
gpt_4o_mini_audio = AudioModel(
|
||||
|
||||
Reference in New Issue
Block a user