mirror of
https://github.com/xtekky/gpt4free.git
synced 2025-12-03 13:34:36 +00:00
Support audio model in Azure provider
This commit is contained in:
@@ -4,7 +4,7 @@ import os
|
||||
import json
|
||||
|
||||
from ...typing import Messages, AsyncResult
|
||||
from ...errors import MissingAuthError
|
||||
from ...errors import MissingAuthError, ModelNotFoundError
|
||||
from ..template import OpenaiTemplate
|
||||
|
||||
class Azure(OpenaiTemplate):
|
||||
@@ -15,9 +15,20 @@ class Azure(OpenaiTemplate):
|
||||
active_by_default = True
|
||||
login_url = "https://discord.gg/qXA4Wf4Fsm"
|
||||
routes: dict[str, str] = {}
|
||||
audio_models = ["gpt-4o-mini-audio-preview"]
|
||||
model_extra_body = {
|
||||
"gpt-4o-mini-audio-preview": {
|
||||
"audio": {
|
||||
"voice": "alloy",
|
||||
"format": "mp3"
|
||||
},
|
||||
"modalities": ["text", "audio"],
|
||||
"stream": False
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models(cls, **kwargs) -> list[str]:
|
||||
def get_models(cls, api_key: str = None, **kwargs) -> list[str]:
|
||||
routes = os.environ.get("AZURE_ROUTES")
|
||||
if routes:
|
||||
try:
|
||||
@@ -27,7 +38,7 @@ class Azure(OpenaiTemplate):
|
||||
cls.routes = routes
|
||||
if cls.routes:
|
||||
return list(cls.routes.keys())
|
||||
return super().get_models(**kwargs)
|
||||
return super().get_models(api_key=api_key, **kwargs)
|
||||
|
||||
@classmethod
|
||||
async def create_async_generator(
|
||||
@@ -40,6 +51,9 @@ class Azure(OpenaiTemplate):
|
||||
) -> AsyncResult:
|
||||
if not model:
|
||||
model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model)
|
||||
if model in cls.model_extra_body:
|
||||
for key, value in cls.model_extra_body[model].items():
|
||||
kwargs.setdefault(key, value)
|
||||
if not api_key:
|
||||
raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
|
||||
if not api_endpoint:
|
||||
@@ -47,7 +61,7 @@ class Azure(OpenaiTemplate):
|
||||
cls.get_models()
|
||||
api_endpoint = cls.routes.get(model)
|
||||
if cls.routes and not api_endpoint:
|
||||
raise ValueError(f"No API endpoint found for model: {model}")
|
||||
raise ModelNotFoundError(f"No API endpoint found for model: {model}")
|
||||
if not api_endpoint:
|
||||
api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
|
||||
try:
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErr
|
||||
from ...typing import Union, AsyncResult, Messages, MediaListType
|
||||
from ...requests import StreamSession, raise_for_status
|
||||
from ...image import use_aspect_ratio
|
||||
from ...image.copy_images import save_response_media
|
||||
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
|
||||
from ...tools.media import render_messages
|
||||
from ...errors import MissingAuthError, ResponseError
|
||||
@@ -62,7 +63,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
max_tokens: int = None,
|
||||
top_p: float = None,
|
||||
stop: Union[str, list[str]] = None,
|
||||
stream: bool = False,
|
||||
stream: bool = None,
|
||||
prompt: str = None,
|
||||
headers: dict = None,
|
||||
impersonate: str = None,
|
||||
@@ -115,7 +116,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream="audio" not in extra_parameters if stream is None else stream,
|
||||
**extra_parameters,
|
||||
**extra_body
|
||||
)
|
||||
@@ -136,10 +137,18 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
|
||||
yield Usage(**data["usage"])
|
||||
if "choices" in data:
|
||||
choice = next(iter(data["choices"]), None)
|
||||
if choice and "content" in choice["message"] and choice["message"]["content"]:
|
||||
yield choice["message"]["content"].strip()
|
||||
if "tool_calls" in choice["message"]:
|
||||
yield ToolCalls(choice["message"]["tool_calls"])
|
||||
message = choice.get("message", {})
|
||||
if choice and "content" in message and message["content"]:
|
||||
yield message["content"].strip()
|
||||
if "tool_calls" in message:
|
||||
yield ToolCalls(message["tool_calls"])
|
||||
audio = message.get("audio", {})
|
||||
if "data" in audio:
|
||||
async for chunk in save_response_media(audio["data"], prompt, [model, extra_body.get("audio", {}).get("voice")]):
|
||||
yield chunk
|
||||
if "transcript" in audio:
|
||||
yield "\n\n"
|
||||
yield audio["transcript"]
|
||||
if choice and "finish_reason" in choice and choice["finish_reason"] is not None:
|
||||
yield FinishReason(choice["finish_reason"])
|
||||
return
|
||||
|
||||
@@ -23,7 +23,7 @@ def convert_to_provider(provider: str) -> ProviderType:
|
||||
|
||||
def get_model_and_provider(model : Union[Model, str],
|
||||
provider : Union[ProviderType, str, None],
|
||||
stream : bool,
|
||||
stream : bool = False,
|
||||
ignore_working: bool = False,
|
||||
ignore_stream: bool = False,
|
||||
logging: bool = True,
|
||||
|
||||
@@ -149,7 +149,6 @@ class Api:
|
||||
"model": model,
|
||||
"provider": provider,
|
||||
"messages": messages,
|
||||
"stream": True,
|
||||
"ignore_stream": True,
|
||||
**kwargs
|
||||
}
|
||||
@@ -166,8 +165,6 @@ class Api:
|
||||
try:
|
||||
model, provider_handler = get_model_and_provider(
|
||||
kwargs.get("model"), provider,
|
||||
stream=True,
|
||||
ignore_stream=True,
|
||||
has_images="media" in kwargs,
|
||||
)
|
||||
if "user" in kwargs:
|
||||
|
||||
@@ -112,7 +112,7 @@ def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str:
|
||||
return "".join((
|
||||
f"{int(time.time())}_",
|
||||
f"{secure_filename(tags + alt)}_" if alt else secure_filename(tags),
|
||||
hashlib.sha256(image.encode()).hexdigest()[:16],
|
||||
hashlib.sha256(str(time.time()).encode() if image is None else image.encode()).hexdigest()[:16],
|
||||
extension
|
||||
))
|
||||
|
||||
|
||||
@@ -292,7 +292,7 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||
cls,
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
stream: bool = None,
|
||||
timeout: int = None,
|
||||
**kwargs
|
||||
) -> CreateResult:
|
||||
@@ -312,7 +312,7 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||
"""
|
||||
return to_sync_generator(
|
||||
cls.create_async_generator(model, messages, stream=stream, **kwargs),
|
||||
stream=stream,
|
||||
stream=stream is not False,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
@@ -321,7 +321,6 @@ class AsyncGeneratorProvider(AbstractProvider):
|
||||
async def create_async_generator(
|
||||
model: str,
|
||||
messages: Messages,
|
||||
stream: bool = True,
|
||||
**kwargs
|
||||
) -> AsyncResult:
|
||||
"""
|
||||
|
||||
@@ -16,4 +16,5 @@ services:
|
||||
stop_grace_period: 2m
|
||||
restart: on-failure
|
||||
volumes:
|
||||
- /var/win:/storage
|
||||
- /var/win:/storage
|
||||
- ./:/data
|
||||
Reference in New Issue
Block a user