Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
wip 1
  • Loading branch information
blanch0t committed Oct 24, 2025
commit 3af3254287815137816dfa590a8813c27c5ff0bc
10 changes: 9 additions & 1 deletion api/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,17 @@ async def retrieval_augmentation_generation(
additional_data = {"search_results": results} if results else {}

user_priority = getattr(user, "priority", 0)
organization = getattr(user, "organization", None)
request_mode = body.request_mode

try:
client, task_metrics = await invoke_model_request(model_name=body["model"], endpoint=ENDPOINT__CHAT_COMPLETIONS, user_priority=user_priority)
client, task_metrics = await invoke_model_request(
model_name=body.model,
endpoint=ENDPOINT__CHAT_COMPLETIONS,
user_priority=user_priority,
request_mode=request_mode,
organization=organization,
)
except TaskFailedException as e:
return JSONResponse(content=e.detail, status_code=e.status_code)

Expand Down
15 changes: 14 additions & 1 deletion api/helpers/models/routers/_modelrouter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import time

from fastapi import HTTPException
import random

from api.clients.model import BaseModelClient as ModelClient
from api.helpers.models.routers.strategies import LeastBusyRoutingStrategy, RoundRobinRoutingStrategy, ShuffleRoutingStrategy
from api.schemas.core.configuration import Model as ModelRouterSchema
from api.schemas.core.configuration import RoutingStrategy
from api.schemas.models import ModelType
from api.utils.exceptions import WrongModelTypeException
from api.utils.exceptions import WrongModelTypeException, ModelNotProvidedByOrganizationException
from api.utils.tracked_cycle import TrackedCycle
from api.utils.variables import ENDPOINT__AUDIO_TRANSCRIPTIONS, ENDPOINT__CHAT_COMPLETIONS, ENDPOINT__EMBEDDINGS, ENDPOINT__OCR, ENDPOINT__RERANK

Expand Down Expand Up @@ -144,6 +145,18 @@ def get_client(self, endpoint: str) -> tuple[ModelClient, float | None]:

return client, metric

def get_client_from_org(self, organization: str, endpoint: str | None = None) -> tuple[ModelClient, float | None]:
    """
    Pick one of this router's providers that belongs to ``organization``.

    Args:
        organization: organization whose providers are eligible.
        endpoint: endpoint to assign to the selected provider. When omitted (``None``),
            the provider's ``endpoint`` attribute is left untouched, so the method can
            also serve as a plain "does this organization provide the model?" check.

    Returns:
        ``(client, None)``: the provider chosen at random among the organization's
        providers, and ``None`` in the metric slot because no metric is computed here.

    Raises:
        ModelNotProvidedByOrganizationException: none of the router's providers
            belongs to ``organization``.
    """
    candidates = [provider for provider in self._providers if provider.organization == organization]
    if not candidates:
        raise ModelNotProvidedByOrganizationException(self.name, organization)

    choice = random.choice(candidates)
    if endpoint is not None:
        # Only touch the provider when a target endpoint was actually requested.
        choice.endpoint = endpoint
    return choice, None

async def get_clients(self):
"""
Return the current list of ModelClient thread-safely.
Expand Down
8 changes: 8 additions & 0 deletions api/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,14 @@ class ChatCompletionRequest(BaseModel):
search: bool = Field(default=False) # fmt: off
search_args: ChatSearchArgs | None = Field(default=None) # fmt: off

request_mode: Literal["shared", "private-only", "private-first"] = Field(
default="shared",
description="Determines which provider pool to use. Options are: "
"`shared` (default, use shared pool), "
"`private-only` (use only the org-specific provider), "
"`private-first` (try private first, fallback to shared if unavailable).",
)

@model_validator(mode="after")
def validate_model(cls, values):
if values.search:
Expand Down
26 changes: 19 additions & 7 deletions api/services/model_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,23 @@
from api.schemas.core.configuration import ModelProvider as ModelClientSchema
from api.schemas.usage import TaskMetrics
from api.clients.model import BaseModelClient
from api.tasks.model import invoke_model_task
from api.tasks.model import invoke_shared_model_task, invoke_private_model_task
from api.tasks.celery_app import queue_name_for_model, task_priority_from_user_priority
from api.utils.exceptions import ModelNotProvidedByOrganizationException, TaskFailedException
from api.utils.tracked_cycle import TrackedCycle

logger = logging.getLogger(__name__)

settings = configuration.settings


async def invoke_model_request(
model_name: str,
endpoint: str,
user_priority: int | None = None,
model_name: str, endpoint: str, user_priority: int | None = None, request_mode: str = "shared", organization: str | None = None
) -> tuple[BaseModelClient, TaskMetrics]:
"""Invoke a model (non-streaming) returning (status_code, json_body).

Expand Down Expand Up @@ -55,10 +57,20 @@ async def invoke_model_request(
except Exception:
original_name = model_name # fallback; error will surface later if invalid

queue = queue_name_for_model(original_name)

# Submit task
async_result = invoke_model_task.apply_async(args=[router_schema, endpoint], queue=queue, priority=priority)
if "private" in request_mode:
try:
router.get_client_from_org(organization)
except ModelNotProvidedByOrganizationException:
raise
queue = queue_name_for_model(original_name, organization)
async_result = invoke_private_model_task.apply_async(
args=[router_schema, endpoint, request_mode, organization], queue=queue, priority=priority
)

else:
queue = queue_name_for_model(original_name)
async_result = invoke_shared_model_task.apply_async(args=[router_schema, endpoint], queue=queue, priority=priority)

# Wait for result using async polling
result, duration = await wait_for_task_result(async_result.id)
Expand Down
12 changes: 10 additions & 2 deletions api/tasks/celery_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,16 @@
celery_app.conf.task_queue_max_priority = MAX_PRIORITY


def queue_name_for_model(router_name: str) -> str:
return f"{settings.celery_default_queue_prefix}{router_name}" if router_name else settings.celery_default_queue_prefix.rstrip(".")
def queue_name_for_model(router_name: str, org_name: str | None = None) -> str:
    """
    Build the Celery queue name for a router, optionally scoped to an organization.

    Without a router name the configured default prefix is returned with its trailing
    dots removed. Otherwise the queue is ``<prefix><router_name>``, extended with
    ``.<org_name>`` when an organization is given.
    """
    prefix = settings.celery_default_queue_prefix
    if router_name:
        queue = f"{prefix}{router_name}"
        if org_name:
            queue = f"{queue}.{org_name}"
        return queue
    return prefix.rstrip(".")


def shared_queue_name_from_private_one(private_queue_name: str, org_name: str | None = None) -> str:
    """
    Derive the shared queue name from an organization-specific (private) queue name.

    Args:
        private_queue_name: queue name of the form ``<shared queue>.<organization>``.
        org_name: organization the private queue belongs to. When given, exactly the
            ``.<org_name>`` suffix is removed (and the name is returned unchanged if it
            does not end with that suffix). This stays correct when the organization
            name itself contains a ".".

    Returns:
        The queue name without its organization suffix.
    """
    if org_name:
        return private_queue_name.removesuffix(f".{org_name}")
    # Without the organization we can only guess: drop everything from the last "." on.
    # A name without any "." is returned unchanged.
    return private_queue_name.rsplit(".", 1)[0]


def task_priority_from_user_priority(user_priority: int) -> int:
Expand Down
72 changes: 69 additions & 3 deletions api/tasks/model.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from typing import Any
from typing import Any, Dict

from billiard.exceptions import SoftTimeLimitExceeded
from celery.exceptions import MaxRetriesExceededError, Retry

from api.helpers.models.routers._modelrouter import ModelRouter
from billiard.exceptions import SoftTimeLimitExceeded
from celery.exceptions import Retry, MaxRetriesExceededError

from api.tasks.celery_app import celery_app, shared_queue_name_from_private_one
from api.schemas.core.configuration import Model as ModelRouterSchema
from api.tasks.celery_app import celery_app
from api.utils.configuration import configuration

settings = configuration.settings


@celery_app.task(name="model.invoke", bind=True)
def invoke_model_task(self, router_schema: dict[str, Any], endpoint: str) -> dict[str, Any]:
@celery_app.task(name="model.invoke.shared", bind=True)
def invoke_shared_model_task(self, router_schema: Dict[str, Any], endpoint: str) -> Dict[str, Any]:
"""Invoke a model provider (non-streaming).

router_schema: serialized ModelRouterSchema schema (censored=False)
Expand Down Expand Up @@ -56,3 +60,65 @@ def invoke_model_task(self, router_schema: dict[str, Any], endpoint: str) -> dic
return {"status_code": 504, "requeue_count": self.request.retries, "body": {"detail": "Model invocation exceeded the soft time limit"}}
except Exception as e: # pragma: no cover - defensive
return {"status_code": 500, "requeue_count": self.request.retries, "body": {"detail": type(e).__name__}}


@celery_app.task(name="model.invoke.private", bind=True)
def invoke_private_model_task(self, router_schema: Dict[str, Any], endpoint: str, mode: str, organization: str) -> Dict[str, Any]:
    """
    Select a provider owned by ``organization`` for a private-pool request.

    Args:
        router_schema: serialized ModelRouterSchema. A legacy payload carrying ``name``
            instead of ``model_name`` is accepted.
        endpoint: endpoint assigned to the selected provider.
        mode: ``"private-first"`` hands the request over to the shared queue when the
            private provider cannot take it; any other value (``"private-only"`` in the
            request schema) retries this task instead.
        organization: organization whose providers are eligible.

    Returns:
        A dict that always carries ``status_code`` and ``requeue_count``:
        - 200 with ``client``, ``cycle_offset`` and ``performance_indicator`` when the
          provider accepts the request;
        - 202 with a ``body.detail`` message when the request was re-enqueued on the
          shared queue (``private-first`` only);
        - 503 / 504 / 500 with a ``body.detail`` message on exhausted retries, soft
          time limit, or any other error.
    """
    # Rebuild the Pydantic model from the serialized dict.
    try:
        schema_obj = ModelRouterSchema(**router_schema)
    except Exception:
        # Backward compatibility: router_schema may use 'name' instead of 'model_name'.
        if "name" not in router_schema or "model_name" in router_schema:
            raise
        # Work on a copy so the incoming payload is not mutated.
        router_schema = {**router_schema, "model_name": router_schema["name"]}
        schema_obj = ModelRouterSchema(**router_schema)

    router = ModelRouter.from_schema(schema=schema_obj)

    try:
        client, performance_indicator = router.get_client_from_org(organization, endpoint=endpoint)
        can_be_forwarded = client.apply_modelclient_policy(performance_indicator)

        if can_be_forwarded:
            return {
                "status_code": 200,
                "client": client.as_schema(censored=False).model_dump(),
                "cycle_offset": router._cycle.offset,
                "requeue_count": self.request.retries,
                "performance_indicator": performance_indicator,
            }

        if mode == "private-first":
            # delivery_info may be unset (e.g. when the task is not run through a broker).
            delivery_info = self.request.delivery_info or {}
            current_queue = delivery_info.get("routing_key", "")
            shared_queue = shared_queue_name_from_private_one(current_queue)

            # NOTE(review): this enqueues a *new* task, so a caller polling this task's
            # id only receives the 202 payload below - confirm the waiter handles that.
            invoke_shared_model_task.apply_async(
                args=[router_schema, endpoint],
                queue=shared_queue,
                priority=delivery_info.get("priority", 0),
            )

            return {
                "status_code": 202,
                "requeue_count": self.request.retries,
                "body": {"detail": f"Requeued from {current_queue} → {shared_queue}"},
            }

        # Private-only: wait for the private provider to become available.
        raise self.retry(
            countdown=settings.celery_task_retry_countdown,
            max_retries=settings.celery_task_max_retry,
        )

    except Retry:
        # Let Celery reschedule the task.
        raise
    except MaxRetriesExceededError:
        return {"status_code": 503, "requeue_count": self.request.retries, "body": {"detail": "Max retries exceeded"}}
    except SoftTimeLimitExceeded:
        return {"status_code": 504, "requeue_count": self.request.retries, "body": {"detail": "Soft time limit exceeded"}}
    except Exception as e:
        # Defensive catch-all: report the exception type instead of failing the task.
        return {"status_code": 500, "requeue_count": self.request.retries, "body": {"detail": type(e).__name__}}
6 changes: 6 additions & 0 deletions api/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def __init__(self, detail: str = "Insufficient rights.") -> None:
super().__init__(status_code=403, detail=detail)


class ModelNotProvidedByOrganizationException(HTTPException):
    """HTTP 403 raised when private mode is requested for a model the organization does not provide."""

    def __init__(self, model: str, organization: str) -> None:
        super().__init__(
            status_code=403,
            detail=f"Private mode not allowed because the requested model: {model} is not provided by your organization ({organization})",
        )


# 404
class CollectionNotFoundException(HTTPException):
def __init__(self, detail: str = "Collection not found.") -> None:
Expand Down