Spaces:
Running
Running
Update main.py
Browse files
main.py
CHANGED
|
@@ -17,7 +17,44 @@ import time
|
|
| 17 |
from usage_tracker import UsageTracker
|
| 18 |
from starlette.middleware.base import BaseHTTPMiddleware
|
| 19 |
from collections import defaultdict
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
class RateLimitMiddleware(BaseHTTPMiddleware):
|
| 22 |
def __init__(self, app, requests_per_second: int = 2):
|
| 23 |
super().__init__(app)
|
|
@@ -62,7 +99,6 @@ app = FastAPI()
|
|
| 62 |
app.add_middleware(RateLimitMiddleware, requests_per_second=2)
|
| 63 |
|
| 64 |
# Get API keys and secret endpoint from environment variables
|
| 65 |
-
api_keys_str = os.getenv('API_KEYS') #deprecated -_-
|
| 66 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
| 67 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
| 68 |
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
|
|
@@ -75,7 +111,7 @@ if not secret_api_endpoint or not secret_api_endpoint_2 or not secret_api_endpoi
|
|
| 75 |
raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
|
| 76 |
|
| 77 |
# Define models that should use the secondary endpoint
|
| 78 |
-
alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
|
| 79 |
|
| 80 |
available_model_ids = []
|
| 81 |
class Payload(BaseModel):
|
|
@@ -154,7 +190,7 @@ async def ping():
|
|
| 154 |
return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
|
| 155 |
|
| 156 |
@app.get("/searchgpt")
|
| 157 |
-
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None):
|
| 158 |
if not q:
|
| 159 |
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
|
| 160 |
usage_tracker.record_request(endpoint="/searchgpt")
|
|
@@ -191,12 +227,12 @@ async def get_models():
|
|
| 191 |
raise HTTPException(status_code=500, detail="Error decoding models.json")
|
| 192 |
@app.get("api/v1/models")
|
| 193 |
@app.get("/models")
|
| 194 |
-
async def fetch_models():
|
| 195 |
return await get_models()
|
| 196 |
server_status = True
|
| 197 |
@app.post("/chat/completions")
|
| 198 |
@app.post("api/v1/chat/completions")
|
| 199 |
-
async def get_completion(payload: Payload, request: Request):
|
| 200 |
# Check server status
|
| 201 |
|
| 202 |
|
|
@@ -216,7 +252,7 @@ async def get_completion(payload: Payload, request: Request):
|
|
| 216 |
payload_dict["model"] = model_to_use
|
| 217 |
# payload_dict["stream"] = payload_dict.get("stream", False)
|
| 218 |
# Select the appropriate endpoint
|
| 219 |
-
endpoint =
|
| 220 |
|
| 221 |
# Current time and IP logging
|
| 222 |
current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
|
|
@@ -286,6 +322,7 @@ async def generate_image(
|
|
| 286 |
private: Optional[bool] = None,
|
| 287 |
enhance: Optional[bool] = None,
|
| 288 |
request: Request = None, # Access raw POST data
|
|
|
|
| 289 |
):
|
| 290 |
"""
|
| 291 |
Generate an image using the Image Generation API.
|
|
|
|
| 17 |
from usage_tracker import UsageTracker
|
| 18 |
from starlette.middleware.base import BaseHTTPMiddleware
|
| 19 |
from collections import defaultdict
|
| 20 |
+
from fastapi import Security #new
|
| 21 |
+
from fastapi.security import APIKeyHeader
|
| 22 |
+
from starlette.exceptions import HTTPException
|
| 23 |
+
from starlette.status import HTTP_403_FORBIDDEN
|
| 24 |
|
| 25 |
+
# API key header scheme
|
| 26 |
+
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
| 27 |
+
|
| 28 |
+
# Function to validate API key
|
| 29 |
+
async def verify_api_key(api_key: str = Security(api_key_header)) -> bool:
    """Validate the API key supplied through the ``api_key_header`` scheme.

    The raw header value may carry an optional ``'Bearer '`` prefix, which is
    removed before comparison. Valid keys are read from the comma-separated
    ``API_KEYS`` environment variable on every call.

    Args:
        api_key: Raw header value injected by FastAPI via ``Security``;
            falsy when the header is absent (the scheme is declared with
            ``auto_error=False``).

    Returns:
        ``True`` when the supplied key matches one of the configured keys.

    Raises:
        HTTPException: 403 with detail ``"No API key provided"`` when the
            header is missing or empty once the prefix is removed; 403 with
            ``"API keys not configured on server"`` when ``API_KEYS`` holds
            no usable key; 403 with ``"Invalid API key"`` on a mismatch.
    """
    # Function-scope import keeps this block self-contained; hmac is stdlib.
    import hmac

    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )

    # Clean the API key by removing 'Bearer ' if present
    if api_key.startswith('Bearer '):
        api_key = api_key[7:]  # Remove 'Bearer ' prefix
    api_key = api_key.strip()

    # A bare "Bearer " header leaves an empty string. Reject it explicitly so
    # it can never match an empty entry produced by a stray comma in API_KEYS.
    if not api_key:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="No API key provided"
        )

    # Get API keys from environment, dropping blank entries and the
    # whitespace that tends to follow commas ("k1, k2").
    api_keys_str = os.getenv('API_KEYS')
    valid_api_keys = [
        key.strip()
        for key in (api_keys_str or "").split(',')
        if key.strip()
    ]
    if not valid_api_keys:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="API keys not configured on server"
        )

    # Compare as bytes (compare_digest rejects non-ASCII str) and walk every
    # configured key without short-circuiting, so response timing does not
    # depend on which key matched or how much of it matched.
    candidate = api_key.encode("utf-8")
    matched = False
    for key in valid_api_keys:
        if hmac.compare_digest(candidate, key.encode("utf-8")):
            matched = True

    if not matched:
        raise HTTPException(
            status_code=HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )

    return True
|
| 58 |
class RateLimitMiddleware(BaseHTTPMiddleware):
|
| 59 |
def __init__(self, app, requests_per_second: int = 2):
|
| 60 |
super().__init__(app)
|
|
|
|
| 99 |
app.add_middleware(RateLimitMiddleware, requests_per_second=2)
|
| 100 |
|
| 101 |
# Get API keys and secret endpoint from environment variables
|
|
|
|
| 102 |
valid_api_keys = api_keys_str.split(',') if api_keys_str else []
|
| 103 |
secret_api_endpoint = os.getenv('SECRET_API_ENDPOINT')
|
| 104 |
secret_api_endpoint_2 = os.getenv('SECRET_API_ENDPOINT_2')
|
|
|
|
| 111 |
raise HTTPException(status_code=500, detail="API endpoint(s) are not configured in environment variables.")
|
| 112 |
|
| 113 |
# Define models that should use the secondary endpoint
|
| 114 |
+
# alternate_models = {"gpt-4o-mini", "claude-3-haiku", "llama-3.1-70b", "mixtral-8x7b"}
|
| 115 |
|
| 116 |
available_model_ids = []
|
| 117 |
class Payload(BaseModel):
|
|
|
|
| 190 |
return {"message": "pong", "response_time": f"{response_time:.6f} seconds"}
|
| 191 |
|
| 192 |
@app.get("/searchgpt")
|
| 193 |
+
async def search_gpt(q: str, stream: Optional[bool] = False, systemprompt: Optional[str] = None,authenticated: bool = Depends(verify_api_key)):
|
| 194 |
if not q:
|
| 195 |
raise HTTPException(status_code=400, detail="Query parameter 'q' is required")
|
| 196 |
usage_tracker.record_request(endpoint="/searchgpt")
|
|
|
|
| 227 |
raise HTTPException(status_code=500, detail="Error decoding models.json")
|
| 228 |
# Route paths must begin with "/": incoming request paths always do, so the
# previous "api/v1/models" pattern could never match a request.
@app.get("/api/v1/models")
@app.get("/models")
async def fetch_models(authenticated: bool = Depends(verify_api_key)):
    """Return the model list produced by ``get_models()``.

    Registered for GET on both ``/models`` and ``/api/v1/models``.

    Args:
        authenticated: Result of the ``verify_api_key`` dependency. It is not
            read in the body; declaring it makes the dependency run (and
            raise on a bad key) before this handler executes.

    Returns:
        Whatever ``get_models()`` returns, unchanged.
    """
    return await get_models()
|
| 232 |
server_status = True
|
| 233 |
@app.post("/chat/completions")
|
| 234 |
@app.post("api/v1/chat/completions")
|
| 235 |
+
async def get_completion(payload: Payload, request: Request,authenticated: bool = Depends(verify_api_key)):
|
| 236 |
# Check server status
|
| 237 |
|
| 238 |
|
|
|
|
| 252 |
payload_dict["model"] = model_to_use
|
| 253 |
# payload_dict["stream"] = payload_dict.get("stream", False)
|
| 254 |
# Select the appropriate endpoint
|
| 255 |
+
endpoint = secret_api_endpoint
|
| 256 |
|
| 257 |
# Current time and IP logging
|
| 258 |
current_time = (datetime.datetime.utcnow() + datetime.timedelta(hours=5, minutes=30)).strftime("%Y-%m-%d %I:%M:%S %p")
|
|
|
|
| 322 |
private: Optional[bool] = None,
|
| 323 |
enhance: Optional[bool] = None,
|
| 324 |
request: Request = None, # Access raw POST data
|
| 325 |
+
authenticated: bool = Depends(verify_api_key)
|
| 326 |
):
|
| 327 |
"""
|
| 328 |
Generate an image using the Image Generation API.
|