| import os |
| import gradio as gr |
| import onnxruntime as ort |
| import numpy as np |
| from transformers import AutoTokenizer |
|
|
# Location of the exported ONNX model. Relative paths resolve against the
# current working directory. Set ONNX_MODEL_PATH to load a model from
# elsewhere; an unset or empty value falls back to the bundled default.
ONNX_PATH = os.getenv("ONNX_MODEL_PATH") or os.path.join("assets", "automotive_slm.onnx")
|
|
def _load_tokenizer(name):
    """Load the Hugging Face tokenizer *name*, ensuring it has a pad token.

    If the tokenizer defines no pad token, its EOS token is reused as the
    pad token.
    """
    tok = AutoTokenizer.from_pretrained(name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return tok


tokenizer = _load_tokenizer("gpt2")
|
|
# CPU-only ONNX Runtime session with every graph optimisation enabled.
providers = ["CPUExecutionProvider"]
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
if not os.path.isfile(ONNX_PATH):
    # Fail fast with the resolved path and working directory (ONNX_PATH may be
    # relative) instead of leaving it to onnxruntime to fail on load.
    raise FileNotFoundError(
        f"ONNX model not found at {ONNX_PATH!r} (cwd: {os.getcwd()!r})"
    )
session = ort.InferenceSession(ONNX_PATH, providers=providers, sess_options=so)
|
|
# The model is driven through its first declared input and first declared
# output only.
# NOTE(review): assumes the graph has a single required input (token ids);
# confirm against the exported model.
INPUT_NAME, OUTPUT_NAME = (
    session.get_inputs()[0].name,
    session.get_outputs()[0].name,
)
|
|
def _sample_next_token(logits, temperature, top_p, top_k) -> int:
    """Choose the next token id from a 1-D vector of vocabulary logits.

    Steps, in order: temperature scaling, top-k filtering, softmax, nucleus
    (top-p) filtering, then one draw with ``np.random.choice``. A temperature
    <= 0 returns the argmax (greedy), the limit of temperature sampling as
    T -> 0. ``None`` temperature means "no scaling".
    """
    # float64 keeps the normalised probabilities within np.random.choice's
    # sum-to-one tolerance even for large vocabularies.
    logits = np.asarray(logits, dtype=np.float64)

    if temperature is not None and float(temperature) <= 0:
        # The argmax is always inside the top-k / top-p sets, so no filtering needed.
        return int(np.argmax(logits))
    if temperature is not None:
        logits = logits / max(float(temperature), 1e-6)

    if top_k and int(top_k) > 0:
        k = min(int(top_k), logits.shape[-1])
        keep = np.argpartition(logits, -k)[-k:]
        filtered = np.full_like(logits, -np.inf)
        filtered[keep] = logits[keep]
        logits = filtered

    # Numerically stable softmax; -inf entries left by top-k become exactly 0.
    exps = np.exp(logits - np.max(logits))
    probs = exps / exps.sum()

    if top_p is not None and 0 < float(top_p) < 1.0:
        order = np.argsort(probs)[::-1]
        cumulative = np.cumsum(probs[order])
        # Smallest prefix of the sorted distribution whose mass reaches top_p.
        cutoff = int(np.searchsorted(cumulative, float(top_p))) + 1
        nucleus = order[:cutoff]
        masked = np.zeros_like(probs)
        masked[nucleus] = probs[nucleus]
        total = masked.sum()
        if total > 0:
            probs = masked / total

    return int(np.random.choice(probs.shape[0], p=probs))


def generate_onnx(prompt: str, max_tokens=64, temperature=0.8, top_p=0.9, top_k=50) -> str:
    """Generate a reply to *prompt* by autoregressive sampling from the ONNX model.

    Args:
        prompt: User text, encoded with the module-level tokenizer.
        max_tokens: Maximum number of new tokens to generate (coerced to int).
        temperature: Softmax temperature. ``None`` disables scaling; values
            <= 0 select each token greedily (argmax).
        top_p: Nucleus-sampling probability mass; only values in (0, 1)
            enable it.
        top_k: Keep only the ``top_k`` highest logits; falsy or <= 0 disables it.

    Returns:
        The decoded continuation, with a leading copy of the prompt removed
        if present. If nothing usable was generated (including an empty
        prompt), returns a fixed fallback message.
    """
    fallback = "I couldn't generate a response."
    tokens = tokenizer.encode(prompt)
    if not tokens:
        # An empty prompt would feed the model a (1, 0) tensor.
        return fallback
    input_ids = np.array([tokens], dtype=np.int64)
    generated = []

    for _ in range(int(max_tokens)):
        # NOTE(review): the whole sequence is re-run each step and never
        # truncated; if the exported model has a fixed maximum context
        # length, prompt + max_tokens may exceed it — confirm.
        (logits,) = session.run([OUTPUT_NAME], {INPUT_NAME: input_ids})
        # Logits at the last position of the single batch row.
        next_token = _sample_next_token(logits[0, -1, :], temperature, top_p, top_k)
        if next_token == tokenizer.eos_token_id:
            break
        generated.append(next_token)
        input_ids = np.concatenate(
            [input_ids, np.array([[next_token]], dtype=np.int64)], axis=1
        )

    text = tokenizer.decode(generated, skip_special_tokens=True).strip()
    if text.startswith(prompt):
        # Drop an echoed copy of the prompt if the model repeated it.
        text = text[len(prompt):].strip()
    return text or fallback
|
|
def chat_fn(message, history, max_tokens, temperature, top_p, top_k):
    """Append a user turn and the model's reply to a messages-format history.

    Args:
        message: Text the user submitted.
        history: Existing list of ``{"role", "content"}`` dicts, or ``None``.
        max_tokens, temperature, top_p, top_k: Passed to ``generate_onnx``.

    Returns:
        A new history list; the incoming one is not mutated. Blank or
        whitespace-only messages leave the history unchanged.
    """
    history = list(history or [])
    if not message or not message.strip():
        # Nothing to answer; avoid running the model on an empty prompt.
        return history
    reply = generate_onnx(message, max_tokens, temperature, top_p, top_k)
    return history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": reply},
    ]
|
|
with gr.Blocks(title="Automotive SLM Chatbot (ONNX)") as demo:
    gr.Markdown("# 🚗 Automotive SLM Chatbot (ONNX-only)")
    # Show the path actually loaded rather than a hard-coded copy of it.
    gr.Markdown(f"Using model at {ONNX_PATH}")

    with gr.Row():
        # Left: conversation and input controls.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="Chat", height=500, type="messages")
            msg = gr.Textbox(placeholder="Ask about automotive topics...", label="Your message")
            with gr.Row():
                send_btn = gr.Button("Send", variant="primary")
                clear_btn = gr.Button("Clear")

        # Right: sampling parameters forwarded to chat_fn.
        with gr.Column(scale=2):
            gr.Markdown("### Generation settings")
            max_tokens = gr.Slider(10, 256, value=64, step=1, label="Max tokens")
            temperature = gr.Slider(0.1, 1.5, value=0.8, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-p")
            top_k = gr.Slider(1, 100, value=50, step=1, label="Top-k")

    # The Send button and the Enter key share one handler. Afterwards the
    # textbox is cleared so the same message is not resent by accident.
    gr.on(
        triggers=[send_btn.click, msg.submit],
        fn=chat_fn,
        inputs=[msg, chatbot, max_tokens, temperature, top_p, top_k],
        outputs=[chatbot],
    ).then(lambda: "", None, msg)
    clear_btn.click(lambda: [], None, chatbot)
|
|
if __name__ == "__main__":
    # Listen on all interfaces; the port comes from $PORT and falls back to 7860.
    demo.launch(
        server_port=int(os.environ.get("PORT", "7860")),
        server_name="0.0.0.0",
    )
|
|