Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
|
@@ -501,6 +501,8 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
|
|
| 501 |
|
| 502 |
env_vars = get_env_vars()
|
| 503 |
|
|
|
|
|
|
|
| 504 |
if model_to_use in mistral_models:
|
| 505 |
endpoint = env_vars['mistral_api']
|
| 506 |
custom_headers = {
|
|
@@ -524,6 +526,7 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
|
|
| 524 |
custom_headers = {
|
| 525 |
"Authorization": f"Bearer {env_vars['gemini_key']}"
|
| 526 |
}
|
|
|
|
| 527 |
else:
|
| 528 |
endpoint = env_vars['secret_api_endpoint']
|
| 529 |
custom_headers = {
|
|
@@ -532,12 +535,12 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
|
|
| 532 |
"Referer": header_url
|
| 533 |
}
|
| 534 |
|
| 535 |
-
print(f"Using endpoint: {endpoint} for model: {model_to_use}")
|
| 536 |
|
| 537 |
async def real_time_stream_generator():
|
| 538 |
try:
|
| 539 |
async with httpx.AsyncClient(timeout=60.0) as client:
|
| 540 |
-
async with client.stream("POST", f"{endpoint}
|
| 541 |
if response.status_code >= 400:
|
| 542 |
error_messages = {
|
| 543 |
422: "Unprocessable entity. Check your payload.",
|
|
@@ -576,7 +579,6 @@ async def get_completion(payload: Payload, request: Request, authenticated: bool
|
|
| 576 |
async for chunk in real_time_stream_generator():
|
| 577 |
response_content.append(chunk)
|
| 578 |
return JSONResponse(content=json.loads(''.join(response_content)))
|
| 579 |
-
|
| 580 |
@app.post("/images/generations")
|
| 581 |
async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
|
| 582 |
if not server_status:
|
|
|
|
| 501 |
|
| 502 |
env_vars = get_env_vars()
|
| 503 |
|
| 504 |
+
target_url_path = "/v1/chat/completions" # Default path
|
| 505 |
+
|
| 506 |
if model_to_use in mistral_models:
|
| 507 |
endpoint = env_vars['mistral_api']
|
| 508 |
custom_headers = {
|
|
|
|
| 526 |
custom_headers = {
|
| 527 |
"Authorization": f"Bearer {env_vars['gemini_key']}"
|
| 528 |
}
|
| 529 |
+
target_url_path = "/chat/completions" # Use /chat/completions for Gemini
|
| 530 |
else:
|
| 531 |
endpoint = env_vars['secret_api_endpoint']
|
| 532 |
custom_headers = {
|
|
|
|
| 535 |
"Referer": header_url
|
| 536 |
}
|
| 537 |
|
| 538 |
+
print(f"Using endpoint: {endpoint} with path: {target_url_path} for model: {model_to_use}")
|
| 539 |
|
| 540 |
async def real_time_stream_generator():
|
| 541 |
try:
|
| 542 |
async with httpx.AsyncClient(timeout=60.0) as client:
|
| 543 |
+
async with client.stream("POST", f"{endpoint}{target_url_path}", json=payload_dict, headers=custom_headers) as response:
|
| 544 |
if response.status_code >= 400:
|
| 545 |
error_messages = {
|
| 546 |
422: "Unprocessable entity. Check your payload.",
|
|
|
|
| 579 |
async for chunk in real_time_stream_generator():
|
| 580 |
response_content.append(chunk)
|
| 581 |
return JSONResponse(content=json.loads(''.join(response_content)))
|
|
|
|
| 582 |
@app.post("/images/generations")
|
| 583 |
async def create_image(payload: ImageGenerationPayload, authenticated: bool = Depends(verify_api_key)):
|
| 584 |
if not server_status:
|