santacl committed on
Commit 96e8d35 · verified · 1 Parent(s): 34943c3

Update app.py

Files changed (1):
  1. app.py +119 -56
app.py CHANGED
@@ -1,70 +1,133 @@
-import gradio as gr
-from huggingface_hub import InferenceClient
-
-
-def respond(
-    message,
-    history: list[dict[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-    hf_token: gr.OAuthToken,
-):
-    """
-    For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-    """
-    client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
-
-    messages = [{"role": "system", "content": system_message}]
-
-    messages.extend(history)
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        choices = message.choices
-        token = ""
-        if len(choices) and choices[0].delta.content:
-            token = choices[0].delta.content
-
-        response += token
-        yield response
-
-
-"""
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-chatbot = gr.ChatInterface(
-    respond,
-    type="messages",
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-with gr.Blocks() as demo:
-    with gr.Sidebar():
-        gr.LoginButton()
-    chatbot.render()
-
-
 if __name__ == "__main__":
-    demo.launch()
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from peft import PeftModel
+from fastapi import FastAPI, Request, HTTPException
+from pydantic import BaseModel, Field
+from slowapi import Limiter
+from slowapi.util import get_remote_address
+from slowapi.errors import RateLimitExceeded
+from fastapi.responses import JSONResponse
+import uvicorn
+import time
+from collections import defaultdict
+import asyncio
+
+# MODEL CONFIG
+base_model = "cognitivecomputations/dolphin-2.9.3-mistral-nemo-12b"
+adapter_repo = "santacl/septicspo"
+tokenizer = AutoTokenizer.from_pretrained(base_model)
+if tokenizer.pad_token is None:
+    tokenizer.pad_token = tokenizer.eos_token
+
+bnb_config = BitsAndBytesConfig(
+    load_in_4bit=True,
+    bnb_4bit_compute_dtype=torch.float16,
+    bnb_4bit_use_double_quant=True,
+    bnb_4bit_quant_type="nf4"
+)
+
+print("Loading base model.")
+base = AutoModelForCausalLM.from_pretrained(
+    base_model,
+    quantization_config=bnb_config,
+    device_map="auto"
+)
+
+print("Loading LoRA adapter.")
+model = PeftModel.from_pretrained(base, adapter_repo, subfolder="checkpoint-240")
+print("Model ready")
+
+# RATE LIMITER
+app = FastAPI(title="Dolphin-12B-LoRA API")
+
+limiter = Limiter(key_func=get_remote_address)
+app.state.limiter = limiter
+
+request_history = defaultdict(list)
+HISTORY_CLEANUP_INTERVAL = 300
+
+async def cleanup_request_history():
+    """Background task to clean up old request history"""
+    while True:
+        await asyncio.sleep(HISTORY_CLEANUP_INTERVAL)
+        now = time.time()
+        window_start = now - 60
+        for user_id in list(request_history.keys()):
+            request_history[user_id] = [t for t in request_history[user_id] if t > window_start]
+            if not request_history[user_id]:
+                del request_history[user_id]
+
+@app.on_event("startup")
+async def startup_event():
+    asyncio.create_task(cleanup_request_history())
+
+@app.exception_handler(RateLimitExceeded)
+async def rate_limit_handler(request: Request, exc: RateLimitExceeded):
+    return JSONResponse(
+        status_code=429,
+        content={"detail": "Rate limit exceeded (10 requests/min). Please wait a bit."},
+    )
+
+# SCHEMA
+class ChatRequest(BaseModel):
+    message: str = Field(..., min_length=1, max_length=2000)
+    user_id: str = Field(default="anonymous")
+
+# CHAT ENDPOINT
+@app.post("/chat")
+@limiter.limit("10/minute")
+async def chat(req: ChatRequest, request: Request):
+    user_id = req.user_id
+    message = req.message.strip()
+
+    if not message:
+        raise HTTPException(status_code=400, detail="Message cannot be empty")
+
+    # Additional soft rate limit (20 requests/minute)
+    now = time.time()
+    window_start = now - 60
+    user_reqs = request_history[user_id]
+    user_reqs = [t for t in user_reqs if t > window_start]
+    user_reqs.append(now)
+    request_history[user_id] = user_reqs
+
+    if len(user_reqs) > 20:
+        return JSONResponse(
+            status_code=429,
+            content={"response": "You're sending too many requests — please wait a bit."}
+        )
+
+    try:
+        system_prompt = "You are Dolphin, a logical, calm, and grounded conversational AI."
+        prompt_text = f"<|im_start|>system\n{system_prompt}<|im_end|>\n"
+        prompt_text += f"<|im_start|>user\n{message}<|im_end|>\n<|im_start|>assistant\n"
+
+        inputs = tokenizer(prompt_text, return_tensors="pt").to("cuda")
+
+        output = model.generate(
+            **inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            top_p=0.9,
+            repetition_penalty=1.1,
+            do_sample=True,
+            pad_token_id=tokenizer.eos_token_id,
+        )
+
+        response = tokenizer.decode(output[0], skip_special_tokens=False)
+        response = response.split("<|im_start|>assistant")[-1].replace("<|im_end|>", "").strip()
+
+        return {"response": response}
+
+    except torch.cuda.OutOfMemoryError:
+        raise HTTPException(status_code=503, detail="Server is overloaded. Please try again later.")
+    except Exception as e:
+        print(f"Error generating response: {str(e)}")
+        raise HTTPException(status_code=500, detail="Failed to generate response")
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy", "model": "loaded"}
+
 if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=7860)
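
For reference, a minimal client-side sketch for exercising the endpoints added in this commit. It is untested and assumes the app is reachable at http://localhost:7860 (the port passed to uvicorn.run above) and that the requests library is installed; the base URL and the user_id value are illustrative placeholders, not part of the commit.

import requests

BASE_URL = "http://localhost:7860"  # placeholder; point this at the deployed Space

# Health probe added in this commit
print(requests.get(f"{BASE_URL}/health").json())

# Chat request matching the ChatRequest schema: message (1-2000 chars), optional user_id
payload = {"message": "Hello, Dolphin!", "user_id": "demo-user"}
resp = requests.post(f"{BASE_URL}/chat", json=payload)

if resp.status_code == 429:
    # Either the slowapi limit (10/minute per IP) or the soft per-user limit (20/minute) fired
    print("Rate limited:", resp.json())
else:
    resp.raise_for_status()
    print(resp.json()["response"])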