Files changed (1)
  1. app.py +497 -0
app.py ADDED
@@ -0,0 +1,497 @@
# file: app.py
import os
import io
import json
import uuid
import base64
import time
from typing import List, Dict, Tuple, Optional

import gradio as gr

# We use the official Ollama Python client for convenience.
# It respects the OLLAMA_HOST env var, but we will also allow overriding via UI.
try:
    from ollama import Client
except Exception as e:
    raise RuntimeError(
        "Failed to import the 'ollama' Python client. Ensure it's in requirements.txt."
    ) from e

DEFAULT_PORT = int(os.getenv("PORT", 7860))
DEFAULT_OLLAMA_HOST = os.getenv("OLLAMA_HOST", "").strip() or os.getenv("OLLAMA_BASE_URL", "").strip() or ""
DEFAULT_MODEL = os.getenv("OLLAMA_MODEL", "llama3.1")
APP_TITLE = "Ollama Chat (Gradio + Docker)"
APP_DESCRIPTION = """
A lightweight, fully functional chat UI for Ollama, designed to run on Hugging Face Spaces (Docker).
- Bring your own Ollama host (set OLLAMA_HOST in repo secrets or via the UI).
- Streamed responses, model management (list/pull), and basic vision support (image input).
"""

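# Illustrative environment configuration for the variables this app reads
# (values are placeholders, not shipped defaults; 11434 is Ollama's standard port):
#   OLLAMA_HOST=http://my-ollama-host:11434    # or OLLAMA_BASE_URL
#   OLLAMA_MODEL=llama3.1
#   SYSTEM_PROMPT="You are a helpful assistant."
#   PORT=7860
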
def ensure_scheme(host: str) -> str:
    if not host:
        return host
    host = host.strip()
    if not host.startswith(("http://", "https://")):
        host = "http://" + host
    # remove trailing slashes
    while host.endswith("/"):
        host = host[:-1]
    return host


def get_client(host: str) -> Client:
    host = ensure_scheme(host)
    if not host:
        # fall back to environment-configured client; Client() picks up OLLAMA_HOST if set
        return Client()
    return Client(host=host)

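# Quick sketch of how the two helpers above behave (examples only, not a test suite):
#   ensure_scheme("127.0.0.1:11434")        -> "http://127.0.0.1:11434"
#   ensure_scheme("https://my-host:11434/") -> "https://my-host:11434"
#   get_client("")                          -> Client()   # falls back to the OLLAMA_HOST env var
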
def list_models(host: str) -> Tuple[List[str], Optional[str]]:
    try:
        client = get_client(host)
        data = client.list()  # {'models': [{'name': 'llama3:latest', ...}, ...]}
        names = sorted(m.get("name", "") for m in data.get("models", []) if m.get("name"))
        return names, None
    except Exception as e:
        return [], f"Unable to list models from {host or '(env default)'}: {e}"


def test_connection(host: str) -> Tuple[bool, str]:
    names, err = list_models(host)
    if err:
        return False, err
    if not names:
        return True, f"Connected to {host or '(env default)'} but no models found. Pull one to continue."
    return True, f"Connected to {host or '(env default)'}; found {len(names)} models."


def show_model(host: str, model: str) -> Tuple[Optional[dict], Optional[str]]:
    try:
        client = get_client(host)
        info = client.show(model=model)
        return info, None
    except Exception as e:
        return None, f"Unable to show model '{model}': {e}"


def pull_model(host: str, model: str):
    """
    Generator that pulls a model on the remote Ollama host, yielding progress strings.
    """
    if not model:
        yield "Provide a model name to pull (e.g., llama3.1, mistral, qwen2.5:latest)"
        return
    try:
        client = get_client(host)
        already, _ = show_model(host, model)
        if already:
            yield f"Model '{model}' already present on the host."
            return

        yield f"Pulling '{model}' from registry..."
        for part in client.pull(model=model, stream=True):
            # part has keys: status, digest, total, completed, etc.
            status = part.get("status", "")
            total = part.get("total", 0)
            completed = part.get("completed", 0)
            pct = f"{(completed / total * 100):.1f}%" if total else ""
            line = status
            if pct:
                line += f" ({pct})"
            yield line
        yield f"Finished pulling '{model}'."
    except Exception as e:
        yield f"Error pulling '{model}': {e}"

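# Example of the progress lines yielded by pull_model (shape only; the exact status
# strings come from the Ollama server and may differ between versions):
#   Pulling 'llama3.1' from registry...
#   pulling manifest
#   pulling 8934d96d3f08 (42.3%)
#   verifying sha256 digest
#   Finished pulling 'llama3.1'.
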
def encode_image_to_base64(path: str) -> Optional[str]:
    try:
        with open(path, "rb") as f:
            return base64.b64encode(f.read()).decode("utf-8")
    except Exception:
        return None


def build_ollama_messages(
    system_prompt: str,
    convo_messages: List[Dict],  # stored chat history as Ollama-style messages
    user_text: str,
    image_paths: Optional[List[str]] = None,
) -> List[Dict]:
    """
    Returns the full message list to send to Ollama, including system prompt (if provided),
    past conversation, and the new user message.
    """
    messages = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})

    messages.extend(convo_messages or [])

    msg: Dict = {"role": "user", "content": user_text or ""}
    if image_paths:
        images_b64 = []
        for p in image_paths:
            b64 = encode_image_to_base64(p)
            if b64:
                images_b64.append(b64)
        if images_b64:
            msg["images"] = images_b64
    messages.append(msg)
    return messages

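# Sketch of the payload build_ollama_messages produces for a vision-capable model
# (base64 payloads shortened; the "images" key is only present when images are attached):
#   [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "What is in this picture?", "images": ["<base64-encoded image>"]},
#   ]
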
def messages_for_chatbot(
    text: str,
    image_paths: Optional[List[str]] = None,
    role: str = "user",
) -> Dict:
    """
    Build a Gradio Chatbot message in "messages" mode:
    {"role": "user"|"assistant", "content": [{"type": "text", "text": ...}, {"type": "image", "image": <path>}, ...]}
    """
    content = []
    t = (text or "").strip()
    if t:
        content.append({"type": "text", "text": t})

    if image_paths:
        # Only embed small previews; Gradio will load images from the file path.
        for p in image_paths:
            try:
                # Gradio accepts PIL.Image or a path. Provide the path for simplicity.
                content.append({"type": "image", "image": p})
            except Exception:
                continue
    return {"role": role, "content": content if content else [{"type": "text", "text": ""}]}


def stream_chat(
    host: str,
    model: str,
    system_prompt: str,
    temperature: float,
    top_p: float,
    top_k: int,
    repeat_penalty: float,
    num_ctx: int,
    max_tokens: Optional[int],
    seed: Optional[int],
    convo_messages: List[Dict],
    chatbot_history: List[Dict],
    user_text: str,
    image_files: Optional[List[str]],
):
    """
    Stream a chat completion from Ollama and update the Gradio Chatbot incrementally.
    """
    # 1) Add the user message to the chatbot history
    user_msg_for_bot = messages_for_chatbot(user_text, image_files, role="user")
    chatbot_history = chatbot_history + [user_msg_for_bot]

    # 2) Build messages for Ollama
    ollama_messages = build_ollama_messages(system_prompt, convo_messages, user_text, image_files)

    # 3) Prepare options
    options = {
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repeat_penalty": repeat_penalty,
        "num_ctx": num_ctx,
    }
    if max_tokens is not None and max_tokens > 0:
        # Ollama's option for capping the number of generated tokens is "num_predict"
        options["num_predict"] = max_tokens
    if seed is not None:
        options["seed"] = seed

    # 4) Start streaming
    client = get_client(host)
    assistant_text_accum = ""
    start_time = time.time()

    # Prepare an assistant placeholder in the Chatbot
    assistant_msg_for_bot = messages_for_chatbot("", None, role="assistant")
    chatbot_history = chatbot_history + [assistant_msg_for_bot]
    status_md = f"Model: {model} | Host: {ensure_scheme(host) or '(env default)'} | Streaming..."

    # Initial yield to display the user message and the assistant placeholder
    yield chatbot_history, status_md, convo_messages

    try:
        for part in client.chat(
            model=model,
            messages=ollama_messages,
            stream=True,
            options=options,
        ):
            # The streaming responses from Ollama look like:
            # {'model': '...', 'created_at': '...', 'message': {'role': 'assistant', 'content': '...'}, 'done': False}
            msg = part.get("message", {}) or {}
            delta = msg.get("content", "")
            if delta:
                assistant_text_accum += delta
                chatbot_history[-1] = messages_for_chatbot(assistant_text_accum, None, role="assistant")

            # Update the status with token counts if present
            done = part.get("done", False)
            if done:
                # End-of-stream stats
                eval_count = part.get("eval_count", 0)
                prompt_eval_count = part.get("prompt_eval_count", 0)
                total = time.time() - start_time
                tok_s = (eval_count / total) if total > 0 else 0.0
                status_md = (
                    f"Model: {model} | Host: {ensure_scheme(host) or '(env default)'} | "
                    f"Prompt tokens: {prompt_eval_count} | Output tokens: {eval_count} | "
                    f"Time: {total:.2f}s | Speed: {tok_s:.1f} tok/s"
                )
            yield chatbot_history, status_md, convo_messages

        # 5) Save to conversation state: append the final user + assistant turn to convo_messages.
        # Only conversation messages are stored here (no 'system'); images are encoded once.
        user_images_b64 = [b for b in (encode_image_to_base64(p) for p in (image_files or [])) if b]
        convo_messages = convo_messages + [
            {"role": "user", "content": user_text or "", **({"images": user_images_b64} if user_images_b64 else {})},
            {"role": "assistant", "content": assistant_text_accum},
        ]

        yield chatbot_history, status_md, convo_messages

    except Exception as e:
        # Show the error inline
        err_msg = f"Error during generation: {e}"
        chatbot_history[-1] = messages_for_chatbot(err_msg, None, role="assistant")
        yield chatbot_history, err_msg, convo_messages

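# Each value yielded by stream_chat is a 3-tuple consumed by the Gradio event handler:
#   (chatbot_history, status_md, convo_messages)
# A final status line might look like this (numbers are hypothetical):
#   Model: llama3.1 | Host: http://127.0.0.1:11434 | Prompt tokens: 24 | Output tokens: 118 | Time: 3.41s | Speed: 34.6 tok/s
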
def clear_conversation():
    return [], [], ""


def export_conversation(history: List[Dict], convo_messages: List[Dict]) -> Tuple[str, str]:
    # Export both the chat UI messages and the raw ollama messages
    export_blob = {
        "chat_messages": history,
        "ollama_messages": convo_messages,
        "meta": {
            "title": APP_TITLE,
            "exported_at": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
            "version": "1.0",
        },
    }
    path = f"chat_export_{int(time.time())}.json"
    with open(path, "w", encoding="utf-8") as f:
        json.dump(export_blob, f, ensure_ascii=False, indent=2)
    return path, f"Exported {len(history)} messages to {path}"


def ui() -> gr.Blocks:
    with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft()) as demo:
        gr.Markdown(f"# {APP_TITLE}")
        gr.Markdown(APP_DESCRIPTION)

        # States
        state_convo = gr.State([])    # stores ollama-format convo (no system prompt)
        state_history = gr.State([])  # stores Chatbot messages (messages-mode)
        state_system_prompt = gr.State("")
        state_host = gr.State(DEFAULT_OLLAMA_HOST)
        state_session = gr.State(str(uuid.uuid4()))

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(label="Chat", type="messages", height=520, avatar_images=(None, None))
                with gr.Row():
                    txt = gr.Textbox(
                        label="Your message",
                        placeholder="Ask anything...",
                        autofocus=True,
                        scale=4,
                    )
                    image_files = gr.Files(
                        label="Optional image(s)",
                        file_types=["image"],
                        type="filepath",
                        visible=True,
                    )
                with gr.Row():
                    send_btn = gr.Button("Send", variant="primary")
                    stop_btn = gr.Button("Stop")
                    clear_btn = gr.Button("Clear")
                    export_btn = gr.Button("Export")

                status = gr.Markdown("Ready.", elem_id="status_box")

            with gr.Column(scale=2):
                gr.Markdown("## Connection")
                host_in = gr.Textbox(
                    label="Ollama Host URL",
                    placeholder="http://127.0.0.1:11434 (or leave blank to use server env OLLAMA_HOST)",
                    value=DEFAULT_OLLAMA_HOST,
                )
                with gr.Row():
                    test_btn = gr.Button("Test Connection")
                    refresh_models_btn = gr.Button("Refresh Models")

                models_dd = gr.Dropdown(
                    choices=[],
                    value=None,
                    label="Model",
                    allow_custom_value=True,
                    info="Select a model from the server or type a name (e.g., llama3.1, mistral, phi4:latest)",
                )
                pull_model_txt = gr.Textbox(
                    label="Pull Model (by name)",
                    placeholder="e.g., llama3.1, mistral, qwen2.5:latest",
                )
                pull_btn = gr.Button("Pull Model")
                pull_log = gr.Textbox(label="Pull Progress", interactive=False, lines=6)

                gr.Markdown("## System Prompt")
                sys_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="You are a helpful assistant...",
                    lines=4,
                    value=os.getenv("SYSTEM_PROMPT", ""),
                )

                gr.Markdown("## Generation Settings")
                with gr.Row():
                    temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                    top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-p")
                with gr.Row():
                    top_k = gr.Slider(0, 200, value=40, step=1, label="Top-k")
                    repeat_penalty = gr.Slider(0.0, 2.0, value=1.1, step=0.01, label="Repeat Penalty")
                with gr.Row():
                    num_ctx = gr.Slider(256, 8192, value=4096, step=256, label="Context Window (num_ctx)")
                    max_tokens = gr.Slider(0, 8192, value=0, step=16, label="Max New Tokens (0 = auto)")
                    seed = gr.Number(value=None, label="Seed (optional)", precision=0)

        # Wire up actions
        def _on_load():
            # Initialize the model list based on the default host
            host = DEFAULT_OLLAMA_HOST
            names, err = list_models(host)
            if err:
                status_msg = f"Note: {err}"
            else:
                status_msg = f"Loaded {len(names)} models from {ensure_scheme(host) or '(env default)'}."
            # If DEFAULT_MODEL is available, select it; otherwise pick the first one
            value = DEFAULT_MODEL if DEFAULT_MODEL in names else (names[0] if names else None)
            return (
                gr.update(choices=names, value=value),  # models_dd (choices and selection in one update)
                host,                                   # host_in
                status_msg,                             # status
                [], [], "",                             # state_history, state_convo, state_system_prompt
            )

        load_outputs = [
            models_dd,
            host_in,
            status,
            state_history, state_convo, state_system_prompt,
        ]
        demo.load(_on_load, outputs=load_outputs)

        # When host changes, update state_host
        def set_host(h):
            return ensure_scheme(h)

        host_in.change(set_host, inputs=host_in, outputs=state_host)

        # Test connection
        def _test(h, current_model):
            ok, msg = test_connection(h)
            # refresh the model list if the connection works
            names, err = list_models(h) if ok else ([], None)
            # keep the current selection if it is still available, else pick the first model
            model_val = current_model if ok and current_model in names else (names[0] if names else None)
            if err:
                msg += f"\nAlso: {err}"
            return gr.update(choices=names, value=model_val), msg

        test_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        # Refresh models
        refresh_models_btn.click(_test, inputs=[host_in, models_dd], outputs=[models_dd, status])

        # Pull model progress
        def _pull(h, name):
            if not name:
                yield "Please enter a model name to pull."
                return
            for line in pull_model(h, name.strip()):
                yield line

        pull_btn.click(_pull, inputs=[host_in, pull_model_txt], outputs=pull_log)

        # Clear conversation
        clear_btn.click(clear_conversation, outputs=[chatbot, state_convo, status])

        # Export
        export_file = gr.File(label="Download Conversation", visible=True)
        export_btn.click(export_conversation, inputs=[state_history, state_convo], outputs=[export_file, status])

        # Send/Stream
        def _submit(
            h, m, sp, t, tp, tk, rp, ctx, mx, sd, convo, history, text, files
        ):
            # Convert mx slider 0 -> None (auto)
            mx_int = int(mx) if mx and int(mx) > 0 else None
            sd_int = int(sd) if sd is not None else None
            yield from stream_chat(
                host=h,
                model=m or DEFAULT_MODEL,
                system_prompt=sp or "",
                temperature=float(t),
                top_p=float(tp),
                top_k=int(tk),
                repeat_penalty=float(rp),
                num_ctx=int(ctx),
                max_tokens=mx_int,
                seed=sd_int,
                convo_messages=convo,
                chatbot_history=history,
                user_text=text,
                image_files=files,
            )

        submit_event = send_btn.click(
            _submit,
            inputs=[host_in, models_dd, sys_prompt, temperature, top_p, top_k, repeat_penalty, num_ctx, max_tokens, seed, state_convo, state_history, txt, image_files],
            outputs=[chatbot, status, state_convo],
        )
        # Pressing Enter in the textbox also triggers submit
        submit_event_txt = txt.submit(
            _submit,
            inputs=[host_in, models_dd, sys_prompt, temperature, top_p, top_k, repeat_penalty, num_ctx, max_tokens, seed, state_convo, state_history, txt, image_files],
            outputs=[chatbot, status, state_convo],
        )

        # Stop streaming (cancel both the button click and the Enter-key submission)
        stop_btn.click(None, None, None, cancels=[submit_event, submit_event_txt])

        # After a send, clear the input box and the attached images
        def _post_send():
            return "", None

        send_btn.click(_post_send, outputs=[txt, image_files])
        txt.submit(_post_send, outputs=[txt, image_files])

        # Keep Chatbot state in sync (so export works)
        def _sync_chatbot_state(history):
            return history

        chatbot.change(_sync_chatbot_state, inputs=chatbot, outputs=state_history)

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.queue(default_concurrency_limit=10)
    demo.launch(server_name="0.0.0.0", server_port=DEFAULT_PORT, show_api=False)
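# Local run sketch (assumes a reachable Ollama server; adjust host and model to your setup):
#   pip install gradio ollama
#   ollama serve                 # on the machine that hosts the models
#   ollama pull llama3.1         # or use the "Pull Model" box in the UI
#   OLLAMA_HOST=http://127.0.0.1:11434 python app.py
# On Hugging Face Spaces (Docker), set OLLAMA_HOST as a repo secret instead of entering it in the UI.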