import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
import uvicorn
import time
from collections import defaultdict
import asyncio
# MODEL CONFIG
base_model = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"
adapter_repo = "santacl/septicspo"

tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

print("Loading base model.")
base = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)

print("Loading LoRA adapter.")
model = PeftModel.from_pretrained(base, adapter_repo, subfolder="checkpoint-240")
print("Model ready")
# RATE LIMITER
app = FastAPI(title="PROMETHEUS")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

request_history = defaultdict(list)
HISTORY_CLEANUP_INTERVAL = 300


async def cleanup_request_history():
    """Background task to clean up old request history"""
    while True:
        await asyncio.sleep(HISTORY_CLEANUP_INTERVAL)
        now = time.time()
        window_start = now - 60
        for user_id in list(request_history.keys()):
            request_history[user_id] = [t for t in request_history[user_id] if t > window_start]
            if not request_history[user_id]:
                del request_history[user_id]


@app.on_event("startup")
async def startup_event():
    asyncio.create_task(cleanup_request_history())


@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    return JSONResponse(
        status_code=429,
        content={"detail": "Rate limit exceeded (10 requests/min). Please wait a bit."},
    )
# SCHEMA
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=2000)
    user_id: str = Field(default="anonymous")
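# Illustrative request body matching this schema (values are placeholders):
# {"message": "Hello, Dolphin!", "user_id": "demo-user"}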
# CHAT ENDPOINT
@app.post("/chat")
@limiter.limit("10/minute")
async def chat(req: ChatRequest, request: Request):
    user_id = req.user_id
    message = req.message.strip()
    if not message:
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # Additional soft rate limit (20 requests/minute) tracked per user_id
    now = time.time()
    window_start = now - 60
    user_reqs = request_history[user_id]
    user_reqs = [t for t in user_reqs if t > window_start]
    user_reqs.append(now)
    request_history[user_id] = user_reqs
    if len(user_reqs) > 20:
        return JSONResponse(
            status_code=429,
            content={"response": "You're sending too many requests. Please wait a bit."},
        )

    try:
        system_prompt = "You are Dolphin, a logical, calm, and grounded conversational AI."
        prompt_text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        prompt_text += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

        inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        response = tokenizer.decode(output[0], skip_special_tokens=False)
        response = response.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
        return {"response": response}
    except torch.cuda.OutOfMemoryError:
        raise HTTPException(status_code=503, detail="Server is overloaded. Please try again later.")
    except Exception as e:
        print(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to generate response")
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "loaded"}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
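# Illustrative client call against a local run of this app (port 7860 as configured
# above; adjust the host if the Space exposes a different URL):
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello, Dolphin!", "user_id": "demo-user"}'
# The endpoint returns JSON of the form {"response": "..."}.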