import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from fastapi.responses import JSONResponse
import uvicorn
import time
from collections import defaultdict
import asyncio

# MODEL CONFIG
base_model = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"
adapter_repo = "santacl/septicspo"
tokenizer = AutoTokenizer.from_pretrained(base_model)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)
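# NF4 4-bit quantization with double quantization roughly quarters the weight memory
# compared to fp16 loading; matmuls still run in fp16 (bnb_4bit_compute_dtype).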

print("Loading base model.")
base = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto"
)

print("Loading LoRA adapter.")
model = PeftModel.from_pretrained(base, adapter_repo, subfolder="checkpoint-240")
print("Model ready")

# APP SETUP, CORS & RATE LIMITING
app = FastAPI(title="PROMETHEUS")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

limiter = Limiter(key_func=get_remote_address)  # hard limit, keyed by client IP
app.state.limiter = limiter

request_history = defaultdict(list)
HISTORY_CLEANUP_INTERVAL = 300  # seconds between cleanup passes of the soft-limit history

async def cleanup_request_history():
    """Background task to clean up old request history"""
    while True:
        await asyncio.sleep(HISTORY_CLEANUP_INTERVAL)
        now = time.time()
        window_start = now - 60
        for user_id in list(request_history.keys()):
            request_history[user_id] = [t for t in request_history[user_id] if t > window_start]
            if not request_history[user_id]:
                del request_history[user_id]

@app.on_event("startup")
async def startup_event():
    asyncio.create_task(cleanup_request_history())

@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    return JSONResponse(
        status_code=429,
        content={"detail": "Rate limit exceeded (10 requests/min). Please wait a bit."},
    )

# SCHEMA
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=2000)
    user_id: str = Field(default="anonymous")

# CHAT ENDPOINT
@app.post("/chat")
@limiter.limit("10/minute")
async def chat(req: ChatRequest, request: Request):
    user_id = req.user_id
    message = req.message.strip()
    
    if not message:
        raise HTTPException(status_code=400, detail="Message cannot be empty")
    
    # Additional soft rate limit keyed by user_id (20 requests/minute);
    # the slowapi limit above is keyed by client IP instead.
    now = time.time()
    window_start = now - 60
    user_reqs = request_history[user_id]
    user_reqs = [t for t in user_reqs if t > window_start]
    user_reqs.append(now)
    request_history[user_id] = user_reqs
    
    if len(user_reqs) > 20:
        return JSONResponse(
            status_code=429,
            content={"response": "You're sending too many requests β€” please wait a bit."}
        )
    
    try:
        system_prompt = "You are Dolphin, a logical, calm, and grounded conversational AI."
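        # Dolphin models are trained on the ChatML prompt format (<|im_start|> ... <|im_end|>).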
        prompt_text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        prompt_text += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
        
        inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
        
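        # Note: model.generate() runs synchronously and blocks the event loop during
        # generation; offloading it via asyncio.to_thread() would keep other endpoints
        # (e.g. /health) responsive under load.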
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        
        response = tokenizer.decode(output[0], skip_special_tokens=False)
        response = response.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
        
        return {"response": response}
    
    except torch.cuda.OutOfMemoryError:
        raise HTTPException(status_code=503, detail="Server is overloaded. Please try again later.")
    except Exception as e:
        print(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail="Failed to generate response")

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "loaded"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
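
# Example usage (a sketch; assumes the server is reachable at localhost:7860):
#
#   curl -X POST http://localhost:7860/chat \
#     -H "Content-Type: application/json" \
#     -d '{"message": "Hello, Dolphin!", "user_id": "demo"}'
#
#   curl http://localhost:7860/health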