Devils_child / app.py
santacl's picture
Update app.py
bc62be8 verified
raw
history blame
4.55 kB
import asyncio
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor

import torch
import uvicorn
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from peft import PeftModel
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# MODEL CONFIG
# Base checkpoint plus the LoRA adapter that is applied on top of it.
base_model = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"
adapter_repo = "santacl/septicspo"

# Tokenizer of the base model. If it defines no padding token, reuse EOS so
# padding is always available (generation below passes pad_token_id=eos).
tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Load weights in 4-bit NF4 with nested (double) quantization; compute in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
)

# Everything below runs at import time, so the server only starts accepting
# requests once the model and adapter are fully loaded.
print("Loading base model.")
base = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    quantization_config=bnb_config,
)

print("Loading LoRA adapter.")
model = PeftModel.from_pretrained(
    base,
    adapter_repo,
    subfolder="checkpoint-240",
)
print("Model ready")
# APP, CORS AND RATE LIMITER
app = FastAPI(title="PROMETHEUS")

# Any origin may call the API. Credentials stay disabled: the FastAPI CORS docs
# forbid allow_credentials=True together with "*" origins/methods/headers
# (Starlette would then echo back the caller's Origin, letting any website send
# credentialed requests), and nothing in this app reads cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hard limit keyed on the client IP; applied per route via @limiter.limit(...).
limiter = Limiter(key_func=get_remote_address)
# slowapi expects the limiter to be reachable as app.state.limiter.
app.state.limiter = limiter

# user_id -> time.time() stamps of recent /chat attempts. Backs the soft
# per-user limit in chat() and is pruned by cleanup_request_history().
request_history = defaultdict(list)
# Seconds between background prunes of request_history.
HISTORY_CLEANUP_INTERVAL = 300
async def cleanup_request_history():
    """Periodically evict stale timestamps from ``request_history``.

    Every ``HISTORY_CLEANUP_INTERVAL`` seconds, stamps older than 60 seconds
    are dropped. Users left with no stamps are removed entirely, so one-off
    user ids do not accumulate forever. Never returns; it is meant to run as a
    background task.
    """
    while True:
        await asyncio.sleep(HISTORY_CLEANUP_INTERVAL)
        cutoff = time.time() - 60
        # Iterate over a snapshot of the keys because entries are deleted below.
        for uid in list(request_history):
            recent = [stamp for stamp in request_history[uid] if stamp > cutoff]
            if recent:
                request_history[uid] = recent
            else:
                del request_history[uid]
@app.on_event("startup")
async def startup_event():
    """Launch the background task that prunes ``request_history``.

    The task handle is stored on ``app.state``. asyncio keeps only weak
    references to running tasks, so a task whose result is discarded can be
    garbage-collected mid-execution, silently stopping the cleanup loop.
    """
    app.state.history_cleanup_task = asyncio.create_task(cleanup_request_history())
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    """Convert slowapi's ``RateLimitExceeded`` into a JSON 429 response.

    ``request`` and ``exc`` are required by the handler signature but unused.
    """
    body = {"detail": "Rate limit exceeded (10 requests/min). Please wait a bit."}
    return JSONResponse(content=body, status_code=429)
# SCHEMA
class ChatRequest(BaseModel):
    # JSON body accepted by POST /chat.

    # Text for the model; pydantic rejects anything shorter than 1 or longer
    # than 2000 characters before the endpoint runs.
    message: str = Field(..., max_length=2000, min_length=1)
    # Caller-chosen key for the per-user soft rate limit. It is not
    # authenticated, so clients may send any value; "anonymous" when omitted.
    user_id: str = Field(default="anonymous")
# CHAT ENDPOINT

# A single worker means at most one generate() call runs at a time. That
# matches the old behaviour, where the blocking call held the event loop, but
# /health and the rate limiters keep responding while the model is busy.
_GENERATION_EXECUTOR = ThreadPoolExecutor(max_workers=1, thread_name_prefix="generate")

# ChatML turn delimiters used by the prompt template below.
_CHATML_MARKERS = ("<|im_start|>", "<|im_end|>")


def _build_prompt(message):
    """Return the ChatML prompt (system turn + one user turn) for *message*."""
    system_prompt = "You are Dolphin, a logical, calm, and grounded conversational AI."
    prompt_text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
    prompt_text += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
    return prompt_text


def _generate_reply(prompt_text):
    """Run the model on *prompt_text* and return the assistant reply text.

    Blocking (tokenization + generate() + decode); it is executed on
    ``_GENERATION_EXECUTOR`` so it never runs on the event loop.
    """
    # model.device rather than a hard-coded "cuda": with device_map="auto" this
    # is where the model's first parameters live.
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens so
    # nothing from the prompt can leak into the reply.
    prompt_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=False)
    # Keep just the first assistant turn: stop at the end-of-turn marker, at any
    # new turn the model opens on its own, and at the EOS token text.
    for stop in (*_CHATML_MARKERS, tokenizer.eos_token):
        if stop:
            text = text.split(stop, 1)[0]
    return text.strip()


@app.post("/chat")
@limiter.limit("10/minute")
async def chat(req: ChatRequest, request: Request):
    """Generate a single-turn reply to the submitted message."""
    # `request` is unused in the body but must stay in the signature:
    # slowapi's @limiter.limit needs it to apply the per-IP limit.
    user_id = req.user_id

    # The message is untrusted. Strip ChatML delimiters so it cannot close its
    # own turn and inject e.g. a fake system turn into the prompt.
    message = req.message
    for marker in _CHATML_MARKERS:
        message = message.replace(marker, "")
    message = message.strip()
    if not message:
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # Additional soft rate limit (20 requests/minute) keyed on the client-chosen
    # user_id, unlike the per-IP limit above. The attempt is recorded even when
    # it is rejected, so a client that keeps retrying stays blocked.
    now = time.time()
    window_start = now - 60
    user_reqs = [t for t in request_history[user_id] if t > window_start]
    user_reqs.append(now)
    request_history[user_id] = user_reqs
    if len(user_reqs) > 20:
        return JSONResponse(
            status_code=429,
            content={"response": "You're sending too many requests — please wait a bit."},
        )

    prompt_text = _build_prompt(message)
    loop = asyncio.get_running_loop()
    try:
        # Off-load the blocking generation; exceptions raised in the worker are
        # re-raised here when the future is awaited.
        response = await loop.run_in_executor(_GENERATION_EXECUTOR, _generate_reply, prompt_text)
    except torch.cuda.OutOfMemoryError:
        # Return cached allocator blocks so later requests have a chance to fit.
        torch.cuda.empty_cache()
        raise HTTPException(status_code=503, detail="Server is overloaded. Please try again later.")
    except Exception as e:
        print(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to generate response") from e
    return {"response": response}
@app.get("/health")
async def health_check():
    """Liveness probe; always reports the service as healthy with the model loaded."""
    # The model is loaded at import time, before the app object exists, so any
    # process able to answer this request has already finished loading it.
    payload = {"status": "healthy", "model": "loaded"}
    return payload
def _serve():
    """Run the API with uvicorn on all interfaces, port 7860."""
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    _serve()