| import os |
| from typing import List, Literal, Optional |
|
|
| import torch |
| from fastapi import FastAPI |
| from pydantic import BaseModel, Field |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
| |
| |
| |
# --- Runtime configuration -------------------------------------------------
# Hugging Face model identifier; override with the MODEL_NAME env variable.
MODEL_NAME = os.environ.get("MODEL_NAME", "MBZUAI-Paris/Nile-Chat-12B")

# Upper bound and default for the per-request `max_new_tokens` setting.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

# Prompts longer than this many tokens are cut down before generation.
# NOTE(review): the default is 2024, not 2048 -- confirm this is intentional.
MAX_INPUT_TOKEN_LENGTH = int(
    os.environ.get("MAX_INPUT_TOKEN_LENGTH", "2024")
)
|
|
# ASGI application object; the route decorators in this module attach to it.
app = FastAPI(title="Nile-Chat-12B FastAPI")

# Module-level handles for the tokenizer and model. They start out as None
# and are assigned by `startup_event` when the application starts.
tokenizer = model = None
|
|
|
|
| |
| |
| |
# Speaker roles accepted in a conversation turn.
Role = Literal[
    "system",
    "user",
    "assistant",
]


class ChatMessage(BaseModel):
    # One conversation turn: who is speaking (`role`) and the text (`content`).
    # Documented with comments rather than a docstring so the generated
    # schema stays exactly as before.
    role: Role
    content: str
|
|
class GenerateRequest(BaseModel):
    # Request body for POST /generate.
    # Documented with comments rather than a docstring so the generated
    # schema stays exactly as before.

    # Conversation history (required).
    messages: List[ChatMessage] = Field(
        ...,
        description="Conversation messages in OpenAI-like format",
    )

    # Decoding controls; each is range-checked by pydantic before the
    # endpoint body runs.
    max_new_tokens: int = Field(
        default=DEFAULT_MAX_NEW_TOKENS,
        ge=1,
        le=MAX_MAX_NEW_TOKENS,
    )
    do_sample: bool = True
    # NOTE(review): 0.0 is accepted here; confirm the generation backend
    # tolerates a zero temperature when do_sample is true.
    temperature: float = Field(default=0.6, ge=0.0, le=4.0)
    top_p: float = Field(default=0.9, ge=0.05, le=1.0)
    top_k: int = Field(default=50, ge=1, le=1000)
    repetition_penalty: float = Field(default=1.1, ge=1.0, le=2.0)
|
|
|
|
class GenerateResponse(BaseModel):
    # Response body for POST /generate.
    #   response - generated text (or an error message for an empty request)
    #   trimmed  - True when the prompt was cut to MAX_INPUT_TOKEN_LENGTH
    #   model    - identifier of the model that produced the text
    response: str
    trimmed: bool = Field(default=False)
    model: str = Field(default=MODEL_NAME)
|
|
|
|
| |
| |
| |
@app.on_event("startup")
def startup_event():
    """Load the tokenizer and model into the module-level globals.

    Registered as a FastAPI startup handler. Prints "Model ready" once the
    model has been loaded and switched to evaluation mode.
    """
    global tokenizer, model

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # bfloat16 when CUDA is available, float32 otherwise.
    if torch.cuda.is_available():
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float32

    load_options = {"device_map": "auto", "torch_dtype": weights_dtype}
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_options)
    model.eval()

    print("Model ready")
|
|
|
|
@app.get("/health")
def health():
    # Liveness probe: reports a fixed "ok" status plus the configured model
    # name. It does not check whether the model has finished loading.
    payload = dict(status="ok", model=MODEL_NAME)
    return payload
|
|
|
|
| |
| |
| |
def _sampling_kwargs(do_sample, temperature, top_p, top_k):
    """Return the decoding-strategy keyword arguments for ``model.generate``.

    Sampling parameters are only included when sampling is actually used.
    A non-positive temperature cannot be sampled with (transformers rejects
    it with a ValueError), so it is treated as a request for greedy decoding.
    """
    if not do_sample or temperature <= 0.0:
        return {"do_sample": False}
    return {
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
    }


@app.post("/generate", response_model=GenerateResponse)
def generate(req: GenerateRequest):
    """Generate an assistant reply for the supplied conversation.

    The messages are rendered with the tokenizer's chat template. If the
    prompt exceeds MAX_INPUT_TOKEN_LENGTH tokens, only the most recent tokens
    are kept and ``trimmed`` is set in the response. A temperature of 0 is
    treated as greedy decoding.
    """
    # Kept as a normal 200 response (not an HTTP error) so existing clients
    # that inspect `response` keep working.
    if not req.messages:
        return GenerateResponse(response="Error: messages is empty", trimmed=False)

    conversation = [m.model_dump() for m in req.messages]

    # return_dict=True yields input_ids together with the attention mask, and
    # makes the return type explicit instead of relying on the default.
    encoded = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Keep only the trailing tokens of an over-long prompt; the mask must be
    # cut identically so both tensors stay aligned.
    trimmed = input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH
    if trimmed:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        attention_mask = attention_mask[:, -MAX_INPUT_TOKEN_LENGTH:]

    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    # Request trace: the most recent user message plus prompt statistics.
    last_user = next((m.content for m in reversed(req.messages) if m.role == "user"), "")
    print("\n=== Incoming Request ===")
    print("MODEL:", MODEL_NAME)
    print("LAST USER:", last_user)
    print("trimmed_input:", trimmed)
    print("input_tokens:", int(input_ids.shape[1]))

    with torch.no_grad():
        out = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=req.max_new_tokens,
            num_beams=1,
            repetition_penalty=req.repetition_penalty,
            **_sampling_kwargs(req.do_sample, req.temperature, req.top_p, req.top_k),
        )

    # The output starts with the prompt; decode only the newly generated part.
    new_tokens = out[0, input_ids.shape[-1]:]
    response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    print("\n=== Model Response ===")
    print(response_text)
    print("======================\n")

    return GenerateResponse(response=response_text, trimmed=trimmed, model=MODEL_NAME)
|
|