santacl committed · verified
Commit 574e890 · 1 Parent(s): 8aaa43e

Update app.py

Files changed (1)
  1. app.py +77 -35
app.py CHANGED
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel, Field
+ from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
@@ -12,13 +13,17 @@ import time
from collections import defaultdict
import asyncio

- # MODEL CONFIG
- base_model = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"
- adapter_repo = "santacl/septicspo"
- tokenizer = AutoTokenizer.from_pretrained(base_model)
+
+ BASE_MODEL_NAME = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"
+ ADAPTER_REPO = "santacl/septicspo"
+ MAX_INPUT_LENGTH = 1500
+
+ print("🔹 Loading tokenizer...")
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

+ print("🔹 Setting up 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
@@ -26,36 +31,41 @@ bnb_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4"
)

- print("Loading base model.")
- base = AutoModelForCausalLM.from_pretrained(
-     base_model,
+ print("🔹 Loading base model...")
+ base_model_obj = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL_NAME,
    quantization_config=bnb_config,
-     device_map="auto"
+     device_map="auto",
+     torch_dtype=torch.float16
)

- print("Loading LoRA adapter.")
- model = PeftModel.from_pretrained(base, adapter_repo, subfolder="checkpoint-240")
- print("Model ready")
+ print("🔹 Loading LoRA adapter...")
+ model = PeftModel.from_pretrained(base_model_obj, ADAPTER_REPO, subfolder="checkpoint-240")
+ model.eval()
+
+ print("✅ Model ready and loaded into memory.")

- # RATE LIMITER
app = FastAPI(title="PROMETHEUS")

app.add_middleware(
    CORSMiddleware,
-     allow_origins=["*"],
+     allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

+ # Rate Limiter (10 requests/min hard limit)
limiter = Limiter(key_func=get_remote_address)
app.state.limiter = limiter

+ # Soft limit (20 req/min → warning)
request_history = defaultdict(list)
- HISTORY_CLEANUP_INTERVAL = 300
+ HISTORY_CLEANUP_INTERVAL = 300
+

async def cleanup_request_history():
-     """Background task to clean up old request history"""
+     """Background task to clean up old request history entries"""
    while True:
        await asyncio.sleep(HISTORY_CLEANUP_INTERVAL)
        now = time.time()
@@ -69,6 +79,7 @@ async def cleanup_request_history():
async def startup_event():
    asyncio.create_task(cleanup_request_history())

+
@app.exception_handler(RateLimitExceeded)
async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
    return JSONResponse(
@@ -76,42 +87,54 @@ async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
        content={"detail": "Rate limit exceeded (10 requests/min). Please wait a bit."},
    )

- # SCHEMA
+
class ChatRequest(BaseModel):
    message: str = Field(..., min_length=1, max_length=2000)
-     user_id: str = Field(default="anonymous")
+     user_id: str = Field(default="anonymous", max_length=100)
+

- # CHAT ENDPOINT
@app.post("/chat")
@limiter.limit("10/minute")
async def chat(req: ChatRequest, request: Request):
    user_id = req.user_id
    message = req.message.strip()
-
+
    if not message:
-         raise HTTPException(status_code=400, detail="Message cannot be empty")
+         raise HTTPException(status_code=400, detail="Message cannot be empty.")

-     # Additional soft rate limit (20 requests/minute)
+     # Additional length check for safety
+     if len(message) > MAX_INPUT_LENGTH:
+         raise HTTPException(
+             status_code=400,
+             detail=f"Message too long. Maximum {MAX_INPUT_LENGTH} characters allowed."
+         )
+
+     # Soft rate limit tracking
    now = time.time()
    window_start = now - 60
-     user_reqs = request_history[user_id]
-     user_reqs = [t for t in user_reqs if t > window_start]
+     user_reqs = [t for t in request_history[user_id] if t > window_start]
    user_reqs.append(now)
    request_history[user_id] = user_reqs
-
+
    if len(user_reqs) > 20:
        return JSONResponse(
            status_code=429,
            content={"response": "You're sending too many requests — please wait a bit."}
        )
-
+
    try:
-         system_prompt = "You are Dolphin, a logical, calm, and grounded conversational AI."
+         system_prompt = "You are YE a funny AI that embraces mysoginy and misandry."
        prompt_text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
        prompt_text += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
-
-         inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
-
+
+         inputs = tokenizer(
+             prompt_text,
+             return_tensors="pt",
+             truncation=True,
+             max_length=2048  # Prevent extremely long inputs
+         ).to("cuda")
+
+         # Generate without autocast since model is already quantized
        output = model.generate(
            **inputs,
            max_new_tokens=512,
@@ -121,21 +144,40 @@ async def chat(req: ChatRequest, request: Request):
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
-
+
        response = tokenizer.decode(output[0], skip_special_tokens=False)
        response = response.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
-
+
+         # Clean up GPU memory after each generation
+         del inputs, output
+         torch.cuda.empty_cache()
+
        return {"response": response}
-
+
    except torch.cuda.OutOfMemoryError:
+         torch.cuda.empty_cache()
        raise HTTPException(status_code=503, detail="Server is overloaded. Please try again later.")
    except Exception as e:
-         print(f"Error generating response: {str(e)}")
-         raise HTTPException(status_code=500, detail="Failed to generate response")
+         print(f"⚠️ Error generating response: {str(e)}")
+         torch.cuda.empty_cache()
+         raise HTTPException(status_code=500, detail="Failed to generate response.")
+

@app.get("/health")
async def health_check():
-     return {"status": "healthy", "model": "loaded"}
+     """Health check endpoint"""
+     try:
+         gpu_available = torch.cuda.is_available()
+         gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9 if gpu_available else 0
+         return {
+             "status": "healthy",
+             "model": "loaded",
+             "gpu_available": gpu_available,
+             "gpu_memory_gb": round(gpu_memory, 2)
+         }
+     except Exception as e:
+         return {"status": "healthy", "model": "loaded", "gpu_info": "unavailable"}
+

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
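
For reference, a minimal client-side sketch (not part of the commit) of how the updated /chat endpoint could be exercised once the Space is running. The host and port follow the uvicorn.run call above, the payload fields follow the ChatRequest schema, and the requests library is an assumed client dependency, not something app.py itself uses:

import requests  # assumed client-side dependency, not used by app.py

resp = requests.post(
    "http://localhost:7860/chat",                              # host/port from uvicorn.run above
    json={"message": "Hello there", "user_id": "demo-user"},   # ChatRequest fields
    timeout=120,                                               # generation on a 12B model can be slow
)
print(resp.status_code)  # 200 on success, 400 for empty/over-long input, 429 when rate limited
print(resp.json())       # {"response": "..."} on success

A quick liveness probe goes through GET /health, which after this commit also reports GPU availability and total GPU memory.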