import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi.responses import JSONResponse
import uvicorn
import time
from collections import defaultdict
import asyncio
BASE_MODEL_NAME = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"
ADAPTER_REPO = "santacl/septicspo"
MAX_INPUT_LENGTH = 1500

print("🔹 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
| print("πΉ Setting up 4-bit quantization...") | |
| bnb_config = BitsAndBytesConfig( | |
| load_in_4bit=True, | |
| bnb_4bit_compute_dtype=torch.float16, | |
| bnb_4bit_use_double_quant=True, | |
| bnb_4bit_quant_type="nf4" | |
| ) | |
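# NF4 4-bit weights with double quantization cut the memory footprint of the 12B
# base model to roughly a quarter of fp16, while compute still runs in float16
# (bnb_4bit_compute_dtype).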
| print("πΉ Loading base model...") | |
| base_model_obj = AutoModelForCausalLM.from_pretrained( | |
| BASE_MODEL_NAME, | |
| quantization_config=bnb_config, | |
| device_map="auto", | |
| torch_dtype=torch.float16 | |
| ) | |
| print("πΉ Loading LoRA adapter...") | |
| model = PeftModel.from_pretrained(base_model_obj, ADAPTER_REPO, subfolder="checkpoint-240") | |
| model.eval() | |
| print("β Model ready and loaded into memory.") | |
app = FastAPI(title="PROMETHEUS")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
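# Note: allow_origins=["*"] combined with allow_credentials=True is maximally
# permissive; fine for a demo Space, but worth pinning to known origins elsewhere.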
# Rate limiter (10 requests/min hard limit)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

# Soft limit (20 requests/min per user_id, warning message only)
request_history = defaultdict(list)
HISTORY_CLEANUP_INTERVAL = 300  # seconds between background cleanup passes
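# Two throttling tiers: slowapi enforces the hard limit per client IP, while
# request_history tracks timestamps per user_id (taken from the request body) so
# the soft limit can answer with a friendlier 429 message.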
async def cleanup_request_history():
    """Background task to clean up old request history entries"""
    while True:
        await asyncio.sleep(HISTORY_CLEANUP_INTERVAL)
        now = time.time()
        window_start = now - 60
        for user_id in list(request_history.keys()):
            request_history[user_id] = [t for t in request_history[user_id] if t > window_start]
            if not request_history[user_id]:
                del request_history[user_id]
@app.on_event("startup")
async def startup_event():
    asyncio.create_task(cleanup_request_history())
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    return JSONResponse(
        status_code=429,
        content={"detail": "Rate limit exceeded (10 requests/min). Please wait a bit."},
    )
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=2000)
    user_id: str = Field(default="anonymous", max_length=100)
@app.post("/chat")
@limiter.limit("10/minute")
async def chat(req: ChatRequest, request: Request):
    user_id = req.user_id
    message = req.message.strip()

    if not message:
        raise HTTPException(status_code=400, detail="Message cannot be empty.")

    # Additional length check for safety
    if len(message) > MAX_INPUT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Message too long. Maximum {MAX_INPUT_LENGTH} characters allowed."
        )

    # Soft rate limit tracking (keyed on the user_id from the request body)
    now = time.time()
    window_start = now - 60
    user_reqs = [t for t in request_history[user_id] if t > window_start]
    user_reqs.append(now)
    request_history[user_id] = user_reqs
    if len(user_reqs) > 20:
        return JSONResponse(
            status_code=429,
            content={"response": "You're sending too many requests. Please wait a bit."}
        )
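    # Dolphin models are trained on the ChatML format (<|im_start|>role ... <|im_end|>),
    # which the prompt below reproduces by hand; if the tokenizer ships a chat template,
    # tokenizer.apply_chat_template is the less manual alternative.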
    try:
        system_prompt = "You are YE, a funny AI that embraces misogyny and misandry."
        prompt_text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        prompt_text += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"

        inputs = tokenizer(
            prompt_text,
            return_tensors="pt",
            truncation=True,
            max_length=2048  # Prevent extremely long inputs
        ).to(model.device)

        # Generate without autocast since the model is already quantized
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # Keep special tokens so the assistant turn can be split out, then strip the markers
        response = tokenizer.decode(output[0], skip_special_tokens=False)
        response = response.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()

        # Clean up GPU memory after each generation
        del inputs, output
        torch.cuda.empty_cache()

        return {"response": response}

    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        raise HTTPException(status_code=503, detail="Server is overloaded. Please try again later.")
    except Exception as e:
        print(f"⚠️ Error generating response: {e}")
        torch.cuda.empty_cache()
        raise HTTPException(status_code=500, detail="Failed to generate response.")
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        gpu_available = torch.cuda.is_available()
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if gpu_available else 0
        return {
            "status": "healthy",
            "model": "loaded",
            "gpu_available": gpu_available,
            "gpu_memory_gb": round(gpu_memory, 2)
        }
    except Exception:
        return {"status": "healthy", "model": "loaded", "gpu_info": "unavailable"}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
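# Minimal usage sketch, assuming the routes registered above and the default port 7860:
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Hello there", "user_id": "demo"}'
#
#   curl http://localhost:7860/health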