Support audio model in Azure provider

This commit is contained in:
hlohaus
2025-07-12 19:41:23 +02:00
parent ebff7d51ab
commit c4b18df769
7 changed files with 39 additions and 19 deletions

View File

@@ -4,7 +4,7 @@ import os
import json
from ...typing import Messages, AsyncResult
from ...errors import MissingAuthError
from ...errors import MissingAuthError, ModelNotFoundError
from ..template import OpenaiTemplate
class Azure(OpenaiTemplate):
@@ -15,9 +15,20 @@ class Azure(OpenaiTemplate):
active_by_default = True
login_url = "https://discord.gg/qXA4Wf4Fsm"
routes: dict[str, str] = {}
audio_models = ["gpt-4o-mini-audio-preview"]
model_extra_body = {
"gpt-4o-mini-audio-preview": {
"audio": {
"voice": "alloy",
"format": "mp3"
},
"modalities": ["text", "audio"],
"stream": False
}
}
@classmethod
def get_models(cls, **kwargs) -> list[str]:
def get_models(cls, api_key: str = None, **kwargs) -> list[str]:
routes = os.environ.get("AZURE_ROUTES")
if routes:
try:
@@ -27,7 +38,7 @@ class Azure(OpenaiTemplate):
cls.routes = routes
if cls.routes:
return list(cls.routes.keys())
return super().get_models(**kwargs)
return super().get_models(api_key=api_key, **kwargs)
@classmethod
async def create_async_generator(
@@ -40,6 +51,9 @@ class Azure(OpenaiTemplate):
) -> AsyncResult:
if not model:
model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model)
if model in cls.model_extra_body:
for key, value in cls.model_extra_body[model].items():
kwargs.setdefault(key, value)
if not api_key:
raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.")
if not api_endpoint:
@@ -47,7 +61,7 @@ class Azure(OpenaiTemplate):
cls.get_models()
api_endpoint = cls.routes.get(model)
if cls.routes and not api_endpoint:
raise ValueError(f"No API endpoint found for model: {model}")
raise ModelNotFoundError(f"No API endpoint found for model: {model}")
if not api_endpoint:
api_endpoint = os.environ.get("AZURE_API_ENDPOINT")
try:

View File

@@ -7,6 +7,7 @@ from ..base_provider import AsyncGeneratorProvider, ProviderModelMixin, RaiseErr
from ...typing import Union, AsyncResult, Messages, MediaListType
from ...requests import StreamSession, raise_for_status
from ...image import use_aspect_ratio
from ...image.copy_images import save_response_media
from ...providers.response import FinishReason, ToolCalls, Usage, ImageResponse, ProviderInfo
from ...tools.media import render_messages
from ...errors import MissingAuthError, ResponseError
@@ -62,7 +63,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
max_tokens: int = None,
top_p: float = None,
stop: Union[str, list[str]] = None,
stream: bool = False,
stream: bool = None,
prompt: str = None,
headers: dict = None,
impersonate: str = None,
@@ -115,7 +116,7 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
max_tokens=max_tokens,
top_p=top_p,
stop=stop,
stream=stream,
stream="audio" not in extra_parameters if stream is None else stream,
**extra_parameters,
**extra_body
)
@@ -136,10 +137,18 @@ class OpenaiTemplate(AsyncGeneratorProvider, ProviderModelMixin, RaiseErrorMixin
yield Usage(**data["usage"])
if "choices" in data:
choice = next(iter(data["choices"]), None)
if choice and "content" in choice["message"] and choice["message"]["content"]:
yield choice["message"]["content"].strip()
if "tool_calls" in choice["message"]:
yield ToolCalls(choice["message"]["tool_calls"])
message = choice.get("message", {})
if choice and "content" in message and message["content"]:
yield message["content"].strip()
if "tool_calls" in message:
yield ToolCalls(message["tool_calls"])
audio = message.get("audio", {})
if "data" in audio:
async for chunk in save_response_media(audio["data"], prompt, [model, extra_body.get("audio", {}).get("voice")]):
yield chunk
if "transcript" in audio:
yield "\n\n"
yield audio["transcript"]
if choice and "finish_reason" in choice and choice["finish_reason"] is not None:
yield FinishReason(choice["finish_reason"])
return

View File

@@ -23,7 +23,7 @@ def convert_to_provider(provider: str) -> ProviderType:
def get_model_and_provider(model : Union[Model, str],
provider : Union[ProviderType, str, None],
stream : bool,
stream : bool = False,
ignore_working: bool = False,
ignore_stream: bool = False,
logging: bool = True,

View File

@@ -149,7 +149,6 @@ class Api:
"model": model,
"provider": provider,
"messages": messages,
"stream": True,
"ignore_stream": True,
**kwargs
}
@@ -166,8 +165,6 @@ class Api:
try:
model, provider_handler = get_model_and_provider(
kwargs.get("model"), provider,
stream=True,
ignore_stream=True,
has_images="media" in kwargs,
)
if "user" in kwargs:

View File

@@ -112,7 +112,7 @@ def get_filename(tags: list[str], alt: str, extension: str, image: str) -> str:
return "".join((
f"{int(time.time())}_",
f"{secure_filename(tags + alt)}_" if alt else secure_filename(tags),
hashlib.sha256(image.encode()).hexdigest()[:16],
hashlib.sha256(str(time.time()).encode() if image is None else image.encode()).hexdigest()[:16],
extension
))

View File

@@ -292,7 +292,7 @@ class AsyncGeneratorProvider(AbstractProvider):
cls,
model: str,
messages: Messages,
stream: bool = True,
stream: bool = None,
timeout: int = None,
**kwargs
) -> CreateResult:
@@ -312,7 +312,7 @@ class AsyncGeneratorProvider(AbstractProvider):
"""
return to_sync_generator(
cls.create_async_generator(model, messages, stream=stream, **kwargs),
stream=stream,
stream=stream is not False,
timeout=timeout
)
@@ -321,7 +321,6 @@ class AsyncGeneratorProvider(AbstractProvider):
async def create_async_generator(
model: str,
messages: Messages,
stream: bool = True,
**kwargs
) -> AsyncResult:
"""

View File

@@ -16,4 +16,5 @@ services:
stop_grace_period: 2m
restart: on-failure
volumes:
- /var/win:/storage
- /var/win:/storage
- ./:/data