import os
import io
import json
import uuid
import base64
import time
import random
import math
from typing import List, Dict, Tuple, Optional

import gradio as gr
import spaces  # Required for ZeroGPU Spaces (@spaces.GPU)

# We use the official Ollama Python client for convenience.
# It respects the OLLAMA_HOST env var, but we also allow overriding it via the UI.
try:
    from ollama import Client
except Exception as e:
    raise RuntimeError(
        "Failed to import the 'ollama' Python client. Ensure it's in requirements.txt."
    ) from e

DEFAULT_PORT = int(os.getenv("PORT", 7860))
DEFAULT_OLLAMA_HOST = (
    os.getenv("OLLAMA_HOST", "").strip()
    or os.getenv("OLLAMA_BASE_URL", "").strip()
    or ""
)
DEFAULT_MODEL = os.getenv("OLLAMA_MODEL", "llama3.1")

APP_TITLE = "Ollama Chat (Gradio + Docker)"
APP_DESCRIPTION = """
A lightweight, fully functional chat UI for Ollama, designed to run on Hugging Face Spaces (Docker).

- Bring your own Ollama host (set OLLAMA_HOST in repo secrets or via the UI).
- Streamed responses, model management (list/pull), and basic vision support (image input).
- Compatible with Spaces ZeroGPU via a @spaces.GPU-decorated function (see the GPU Tools panel).
"""


def ensure_scheme(host: str) -> str:
    if not host:
        return host
    host = host.strip()
    if not host.startswith(("http://", "https://")):
        host = "http://" + host
    # Remove trailing slashes
    while host.endswith("/"):
        host = host[:-1]
    return host


def get_client(host: str) -> Client:
    host = ensure_scheme(host)
    if not host:
        # Fall back to the environment-configured client; Client() picks up OLLAMA_HOST if set.
        return Client()
    return Client(host=host)


def list_models(host: str) -> Tuple[List[str], Optional[str]]:
    try:
        client = get_client(host)
        data = client.list()  # {'models': [{'name': 'llama3:latest', ...}, ...]}
        names = sorted(m.get("name", "") for m in data.get("models", []) if m.get("name"))
        return names, None
    except Exception as e:
        return [], f"Unable to list models from {host or '(env default)'}: {e}"


def test_connection(host: str) -> Tuple[bool, str]:
    names, err = list_models(host)
    if err:
        return False, err
    if not names:
        return True, f"Connected to {host or '(env default)'} but no models found. Pull one to continue."
    return True, f"Connected to {host or '(env default)'}; found {len(names)} models."


def show_model(host: str, model: str) -> Tuple[Optional[dict], Optional[str]]:
    try:
        client = get_client(host)
        info = client.show(model=model)
        return info, None
    except Exception as e:
        return None, f"Unable to show model '{model}': {e}"


def pull_model(host: str, model: str):
    """
    Generator that pulls a model on the remote Ollama host, yielding progress strings.
    """
    if not model:
        yield "Provide a model name to pull (e.g., llama3.1, mistral, qwen2.5:latest)."
        return
    try:
        client = get_client(host)
        already, _ = show_model(host, model)
        if already:
            yield f"Model '{model}' is already present on the host."
            return
        yield f"Pulling '{model}' from the registry..."
        for part in client.pull(model=model, stream=True):
            # Each part has keys such as: status, digest, total, completed.
            status = part.get("status", "")
            total = part.get("total", 0)
            completed = part.get("completed", 0)
            pct = f"{(completed / total * 100):.1f}%" if total else ""
            line = status
            if pct:
                line += f" ({pct})"
            yield line
        yield f"Finished pulling '{model}'."
    except Exception as e:
        yield f"Error pulling '{model}': {e}"


def encode_image_to_base64(path: str) -> Optional[str]:
    try:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
    except Exception:
        return None


def build_ollama_messages(
    system_prompt: str,
    convo_messages: List[Dict],  # stored chat history as Ollama-style messages
    user_text: str,
    image_paths: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Returns the full message list to send to Ollama, including the system prompt
    (if provided), the past conversation, and the new user message.
    """
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    messages.extend(convo_messages or [])
    msg: Dict = {"role": "user", "content": user_text or ""}
    if image_paths:
        images_b64 = []
        for p in image_paths:
            b64 = encode_image_to_base64(p)
            if b64:
                images_b64.append(b64)
        if images_b64:
            msg["images"] = images_b64
    messages.append(msg)
    return messages


def messages_for_chatbot(
    text: str,
    image_paths: Optional[List[str]] = None,
    role: str = "user",
) -> Dict:
    """
    Build a Gradio Chatbot message in "messages" mode:
    {"role": "user"|"assistant",
     "content": [{"type": "text", "text": ...}, {"type": "image", "image": ...}, ...]}
    """
    content = []
    t = (text or "").strip()
    if t:
        content.append({"type": "text", "text": t})
    if image_paths:
        # Only embed lightweight references; Gradio loads the images from their file paths.
        for p in image_paths:
            try:
                # Gradio accepts a PIL.Image or a path; provide the path for simplicity.
                content.append({"type": "image", "image": p})
            except Exception:
                continue
    return {"role": role, "content": content if content else [{"type": "text", "text": ""}]}


def stream_chat(
    host: str,
    model: str,
    system_prompt: str,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
    num_ctx: int,
    max_tokens: Optional[int],
    seed: Optional[int],
    convo_messages: List[Dict],
    chatbot_history: List[Dict],
    user_text: str,
    image_files: Optional[List[str]],
):
    """
    Stream a chat completion from Ollama and update the Gradio Chatbot incrementally.
    """
    # 1) Add the user message to the chatbot history and state
    user_msg_for_bot = messages_for_chatbot(user_text, image_files, role="user")
    chatbot_history = chatbot_history + [user_msg_for_bot]

    # 2) Build the message list for Ollama
    ollama_messages = build_ollama_messages(system_prompt, convo_messages, user_text, image_files)

    # 3) Prepare generation options
    options = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repeat_penalty": repeat_penalty,
        "num_ctx": num_ctx,
    }
    if max_tokens is not None and max_tokens > 0:
        # Ollama calls the max-new-tokens option "num_predict"
        options["num_predict"] = max_tokens
    if seed is not None:
        options["seed"] = seed

    # 4) Start streaming
    client = get_client(host)
    assistant_text_accum = ""
    start_time = time.time()

    # Prepare an assistant placeholder in the Chatbot
    assistant_msg_for_bot = messages_for_chatbot("", None, role="assistant")
    chatbot_history = chatbot_history + [assistant_msg_for_bot]
    status_md = f"Model: {model} | Host: {ensure_scheme(host) or '(env default)'} | Streaming..."
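    # NOTE: every yield below emits (chatbot_history, status_md, convo_messages); Gradio maps
    # the tuple onto the [chatbot, status, state_convo] outputs wired up in ui(), so the chat
    # view, status line, and conversation state all refresh as tokens stream in.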
    # Initial yield to display the user message and the assistant placeholder
    yield chatbot_history, status_md, convo_messages

    try:
        for part in client.chat(
            model=model,
            messages=ollama_messages,
            stream=True,
            options=options,
        ):
            msg = part.get("message", {}) or {}
            delta = msg.get("content", "")
            if delta:
                assistant_text_accum += delta
                chatbot_history[-1] = messages_for_chatbot(assistant_text_accum, None, role="assistant")

            done = part.get("done", False)
            if done:
                eval_count = part.get("eval_count", 0)
                prompt_eval_count = part.get("prompt_eval_count", 0)
                total = time.time() - start_time
                tok_s = (eval_count / total) if total > 0 else 0.0
                status_md = (
                    f"Model: {model} | Host: {ensure_scheme(host) or '(env default)'} | "
                    f"Prompt tokens: {prompt_eval_count} | Output tokens: {eval_count} | "
                    f"Time: {total:.2f}s | Speed: {tok_s:.1f} tok/s"
                )

            yield chatbot_history, status_md, convo_messages

        # 5) Save to conversation state: append the final user + assistant turn to convo_messages
        convo_messages = convo_messages + [
            {
                "role": "user",
                "content": user_text or "",
                **(
                    {
                        "images": [
                            b64
                            for b64 in (encode_image_to_base64(p) for p in image_files)
                            if b64
                        ]
                    }
                    if image_files
                    else {}
                ),
            },
            {"role": "assistant", "content": assistant_text_accum},
        ]
        yield chatbot_history, status_md, convo_messages
    except Exception as e:
        err_msg = f"Error during generation: {e}"
        chatbot_history[-1] = messages_for_chatbot(err_msg, None, role="assistant")
        yield chatbot_history, err_msg, convo_messages


def clear_conversation():
    return [], [], ""


def export_conversation(history: List[Dict], convo_messages: List[Dict]) -> Tuple[str, str]:
    export_blob = {
        "chat_messages": history,
        "ollama_messages": convo_messages,
        "meta": {
            "title": APP_TITLE,
            "exported_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
            "version": "1.1",
        },
    }
    path = f"chat_export_{int(time.time())}.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(export_blob, f, ensure_ascii=False, indent=2)
    return path, f"Exported {len(history)} messages to {path}"


# ---------------------- ZeroGPU support: define a GPU-decorated function ----------------------
@spaces.GPU
def gpu_ping(workload: int = 256) -> dict:
    """
    Minimal function to satisfy the ZeroGPU Spaces requirement and optionally exercise the GPU.
    If torch with CUDA is available, perform a tiny matmul on the GPU; otherwise run a small CPU loop.
    Returns a dict with a success flag, the elapsed time, a checksum of the CPU work, and an
    "info" dict describing whether CUDA was actually used.
    """
    t0 = time.time()

    # Light CPU math as a fallback
    acc = 0.0
    for _ in range(max(1, workload)):
        x = random.random() * 1000.0
        # Harmless math; avoids a dependency on numpy
        s = math.sin(x)
        c = math.cos(x)
        t = math.tan(x) if abs(math.cos(x)) > 1e-9 else 1.0
        acc += s * c / t

    info = {"mode": "cpu", "ops": workload}

    # Optional CUDA check (torch is not required)
    try:
        import torch

        if torch.cuda.is_available():
            a = torch.randn((256, 256), device="cuda")
            b = torch.mm(a, a)
            _ = float(b.mean().item())
            info["mode"] = "cuda"
            info["device"] = torch.cuda.get_device_name(torch.cuda.current_device())
            info["cuda"] = True
        else:
            info["cuda"] = False
    except Exception:
        # torch not installed or another issue; still fine for ZeroGPU detection
        info["cuda"] = "unavailable"

    elapsed = time.time() - t0
    return {"ok": True, "elapsed_s": round(elapsed, 4), "acc_checksum": float(acc % 1.0), "info": info}


# ---------------------------------------------------------------------------------------------


def ui() -> gr.Blocks:
    with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(APP_DESCRIPTION)

        # States
        state_convo = gr.State([])  # stores the Ollama-format conversation (no system prompt)
        state_history = gr.State([])  # stores Chatbot messages (messages mode)
        state_system_prompt = gr.State("")
        state_host = gr.State(DEFAULT_OLLAMA_HOST)
        state_session = gr.State(str(uuid.uuid4()))

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Chat", type="messages", height=520, avatar_images=(None, None))
                with gr.Row():
                    txt = gr.Textbox(
                        label="Your message",
                        placeholder="Ask anything...",
                        autofocus=True,
                        scale=4,
                    )
                    image_files = gr.Files(
                        label="Optional image(s)",
                        file_types=["image"],
                        type="filepath",
                        visible=True,
                    )
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")
                    export_btn = gr.Button("Export")
                status = gr.Markdown("Ready.", elem_id="status_box")

            with gr.Column(scale=2):
                gr.Markdown("## Connection")
                host_in = gr.Textbox(
                    label="Ollama Host URL",
                    placeholder="http://127.0.0.1:11434 (or leave blank to use server env OLLAMA_HOST)",
                    value=DEFAULT_OLLAMA_HOST,
                )
                with gr.Row():
                    test_btn = gr.Button("Test Connection")
                    refresh_models_btn = gr.Button("Refresh Models")
                models_dd = gr.Dropdown(
                    choices=[],
                    value=None,
                    label="Model",
                    allow_custom_value=True,
                    info="Select a model from the server or type a name (e.g., llama3.1, mistral, phi4:latest)",
                )
                pull_model_txt = gr.Textbox(
                    label="Pull Model (by name)",
                    placeholder="e.g., llama3.1, mistral, qwen2.5:latest",
                )
                pull_btn = gr.Button("Pull Model")
                pull_log = gr.Textbox(label="Pull Progress", interactive=False, lines=6)

                gr.Markdown("## System Prompt")
                sys_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="You are a helpful assistant...",
                    lines=4,
                    value=os.getenv("SYSTEM_PROMPT", ""),
                )

                gr.Markdown("## Generation Settings")
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
                with gr.Row():
                    top_k = gr.Slider(0, 200, value=40, step=1, label="Top-k")
                    repeat_penalty = gr.Slider(0.0, 2.0, value=1.1, step=0.01, label="Repeat Penalty")
                with gr.Row():
                    num_ctx = gr.Slider(256, 8192, value=4096, step=256, label="Context Window (num_ctx)")
                    max_tokens = gr.Slider(0, 8192, value=0, step=16, label="Max New Tokens (0 = auto)")
                    seed = gr.Number(value=None, label="Seed (optional)", precision=0)

                gr.Markdown("## GPU Tools (ZeroGPU compatible)")
                with gr.Row():
                    gpu_workload = gr.Slider(64, 4096, value=256, step=64, label="GPU Ping Workload")
                with gr.Row():
                    gpu_btn = gr.Button("Run GPU Ping")
                gpu_out = gr.Textbox(label="GPU Ping Result", lines=6, interactive=False)

        # Wire up actions
        def _on_load():
            # Initialize the model list based on the default host
            host = DEFAULT_OLLAMA_HOST
            names, err = list_models(host)
            if err:
                status_msg = f"Note: {err}"
            else:
                status_msg = f"Loaded {len(names)} models from {ensure_scheme(host) or '(env default)'}."
            # If DEFAULT_MODEL is available, select it; otherwise pick the first model.
            value = DEFAULT_MODEL if DEFAULT_MODEL in names else (names[0] if names else None)
            return (
                gr.update(choices=names, value=value),  # models_dd
                host,  # host_in
                status_msg,  # status
                [], [], "",  # state_history, state_convo, state_system_prompt
            )

        load_outputs = [models_dd, host_in, status, state_history, state_convo, state_system_prompt]
        demo.load(_on_load, outputs=load_outputs)

        # When the host changes, update state_host
        def set_host(h):
            return ensure_scheme(h)

        host_in.change(set_host, inputs=host_in, outputs=state_host)

        # Test connection (also used by "Refresh Models")
        def _test(h, current_model):
            ok, msg = test_connection(h)
            # Refresh the model list if the connection works
            names, err = list_models(h) if ok else ([], None)
            # Keep the current selection if it is still available; otherwise fall back to the first model.
            model_val = current_model if ok and current_model in names else (names[0] if names else None)
            if err:
                msg += f"\nAlso: {err}"
            return gr.update(choices=names, value=model_val), msg

        test_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        # Refresh models
        refresh_models_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        # Pull model progress
        def _pull(h, name):
            if not name:
                yield "Please enter a model name to pull."
                return
            for line in pull_model(h, name.strip()):
                yield line

        pull_btn.click(_pull, inputs=[host_in, pull_model_txt], outputs=pull_log)

        # Clear conversation
        clear_btn.click(clear_conversation, outputs=[chatbot, state_convo, status])

        # Export
        export_file = gr.File(label="Download Conversation", visible=True)
        export_btn.click(export_conversation, inputs=[state_history, state_convo], outputs=[export_file, status])

        # Send / stream
        def _submit(h, m, sp, t, tp, tk, rp, ctx, mx, sd, convo, history, text, files):
            # Convert the max-tokens slider value 0 -> None (auto)
            mx_int = int(mx) if mx and int(mx) > 0 else None
            sd_int = int(sd) if sd is not None else None
            yield from stream_chat(
                host=h,
                model=m or DEFAULT_MODEL,
                system_prompt=sp or "",
                temperature=float(t),
                top_p=float(tp),
                top_k=int(tk),
                repeat_penalty=float(rp),
                num_ctx=int(ctx),
                max_tokens=mx_int,
                seed=sd_int,
                convo_messages=convo,
                chatbot_history=history,
                user_text=text,
                image_files=files,
            )

        submit_inputs = [
            host_in, models_dd, sys_prompt, temperature, top_p, top_k, repeat_penalty,
            num_ctx, max_tokens, seed, state_convo, state_history, txt, image_files,
        ]
        submit_event = send_btn.click(_submit, inputs=submit_inputs, outputs=[chatbot, status, state_convo])

        # Pressing Enter in the textbox also triggers a submit
        enter_event = txt.submit(_submit, inputs=submit_inputs, outputs=[chatbot, status, state_convo])

        # Stop streaming (covers both the button click and the Enter-key submit)
        stop_btn.click(None, None, None, cancels=[submit_event, enter_event])

        # After a send, clear the input box and the image selection
        def _post_send():
            return "", None

        send_btn.click(_post_send, outputs=[txt, image_files])
        txt.submit(_post_send, outputs=[txt, image_files])

        # Keep the Chatbot state in sync (so export works)
        def _sync_chatbot_state(history):
            return history

        chatbot.change(_sync_chatbot_state, inputs=chatbot, outputs=state_history)

        # GPU Ping hook
        def _gpu_ping_ui(n):
            try:
                res = gpu_ping(int(n))
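                # On ZeroGPU Spaces, the @spaces.GPU decorator allocates a GPU for the duration
                # of the gpu_ping() call above. The slider value only sizes the CPU fallback loop;
                # the optional CUDA check always runs a small fixed matmul.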
                try:
                    return json.dumps(res, indent=2)
                except Exception:
                    return str(res)
            except Exception as e:
                return f"GPU ping failed: {e}"

        gpu_btn.click(_gpu_ping_ui, inputs=[gpu_workload], outputs=[gpu_out])

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(default_concurrency_limit=10)
    demo.launch(server_name="0.0.0.0", server_port=DEFAULT_PORT, show_api=True, ssr_mode=False)
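# ----------------------------------------------------------------------------------------------
# Reference sketch (kept commented out): the same streaming call stream_chat() makes, run
# standalone against an Ollama server. The host URL and model name are assumptions for local
# testing; the chunk access mirrors the dict-style usage above.
#
#   from ollama import Client
#
#   client = Client(host="http://127.0.0.1:11434")
#   for part in client.chat(
#       model="llama3.1",
#       messages=[{"role": "user", "content": "Say hello in one sentence."}],
#       stream=True,
#       options={"temperature": 0.7, "num_predict": 64},
#   ):
#       print(part.get("message", {}).get("content", ""), end="", flush=True)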