CatoG committed
Commit a8d3f6b · unverified · 1 Parent(s): c021211

Implement DPO model training and preference handling

Added model loading, preference collection, and DPO training functionality for tuning models.

Files changed (1)
  1. app.py +825 -1
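
Each preference collected by the app is stored as a plain prompt/chosen/rejected record, and the list of records is converted to a datasets.Dataset for the DPO trainer. A minimal sketch of one such record (the field names come from app.py below; the values are purely illustrative):

# One collected preference pair, as app.py stores it (illustrative values).
example_pair = {
    "prompt": "Explain overfitting in one short paragraph.",
    "chosen": "Overfitting happens when a model fits noise in the training data ...",  # preferred answer
    "rejected": "Overfitting just means the model is too big.",  # dispreferred answer
}
# train_dpo_model() later calls Dataset.from_list([...]) on a list of such dicts.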
app.py CHANGED
@@ -1 +1,825 @@
-app.py
+import os
+from typing import List, Dict
+from datetime import datetime
+
+import torch
+from torch import nn
+
+import gradio as gr
+import pandas as pd
+
+from datasets import Dataset
+
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    GenerationConfig,
+)
+
+from peft import LoraConfig, get_peft_model
+from trl import DPOConfig, DPOTrainer
+
+
+# =========================================================
+# MODEL LIST (from your BIAS demo)
+# =========================================================
+
+MODEL_CHOICES = [
+    # Very small / light (good for CPU Spaces)
+    "distilgpt2",
+    "gpt2",
+    "sshleifer/tiny-gpt2",
+    "LiquidAI/LFM2-350M",
+    "google/gemma-3-270m-it",
+    "Qwen/Qwen2.5-0.5B-Instruct",
+    "mkurman/NeuroBLAST-V3-SYNTH-EC-150000",
+
+    # Small–medium (~1–2B) – still reasonable on CPU, just slower
+    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+    "google/gemma-3-1b-it",
+    "meta-llama/Llama-3.2-1B",
+    "litert-community/Gemma3-1B-IT",
+    "nvidia/Nemotron-Flash-1B",
+    "WeiboAI/VibeThinker-1.5B",
+    "Qwen/Qwen3-1.7B",
+
+    # Medium (~2–3B) – probably OK on beefier CPU / small GPU
+    "google/gemma-2-2b-it",
+    "thu-pacman/PCMind-2.1-Kaiyuan-2B",
+    "opendatalab/MinerU-HTML",
+    "ministral/Ministral-3b-instruct",
+    "HuggingFaceTB/SmolLM3-3B",
+    "meta-llama/Llama-3.2-3B-Instruct",
+    "nvidia/Nemotron-Flash-3B-Instruct",
+    "Qwen/Qwen2.5-3B-Instruct",
+
+    # Heavier (4–8B) – you really want a GPU Space for these
+    "Qwen/Qwen3-4B",
+    "Qwen/Qwen3-4B-Thinking-2507",
+    "Qwen/Qwen3-4B-Instruct-2507",
+    "mistralai/Mistral-7B-Instruct-v0.2",
+    "allenai/Olmo-3-7B-Instruct",
+    "Qwen/Qwen2.5-7B-Instruct",
+    "meta-llama/Meta-Llama-3-8B-Instruct",
+    "meta-llama/Llama-3.1-8B",
+    "meta-llama/Llama-3.1-8B-Instruct",
+    "openbmb/MiniCPM4.1-8B",
+    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
+    "rl-research/DR-Tulu-8B",
+]
+
+DEFAULT_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+TRAINED_MODEL_DIR = "trained_model"
+
+
+# =========================================================
+# GLOBALS & CONFIG
+# =========================================================
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+tokenizer = None
+policy_model = None
+ref_model = None
+
+DEFAULT_DPO_CONFIG = DPOConfig(
+    beta=0.1,
+    output_dir="dpo_demo",
+    num_train_epochs=1,
+    per_device_train_batch_size=1,
+    per_device_eval_batch_size=1,
+    remove_unused_columns=False,
+    logging_steps=1,
+    gradient_accumulation_steps=1,
+    learning_rate=1e-4,
+    evaluation_strategy="no",  # deprecated alias for eval_strategy; the warning is harmless
+    warmup_steps=0,
+    fp16=False,
+    save_steps=0,
+    report_to="none",
+)
+
+
+# =========================================================
+# LORA TARGET-MODULE HELPER
+# =========================================================
+
+def guess_lora_target_modules(model_name: str, base_model) -> List[str]:
+    """
+    Heuristically choose good LoRA target modules based on the model type/name.
+    - GPT-2-like: use c_attn/c_proj
+    - LLaMA/Gemma/Mistral/Qwen/etc: use q/k/v/o + MLP projections
+    - Fallback: scan Linear module names for known patterns
+    """
+    model_type = getattr(base_model.config, "model_type", "") or ""
+    name_lower = model_name.lower()
+
+    # GPT-2 / DistilGPT-2 / Tiny GPT-2
+    if (
+        "gpt2" in model_type
+        or "gpt2" in name_lower
+        or "tiny-gpt2" in name_lower
+        or "distilgpt2" in name_lower
+    ):
+        return ["c_attn", "c_proj"]
+
+    # LLaMA / Gemma / Mistral / Qwen / Olmo / MiniCPM / SmolLM / Nemotron etc.
+    if any(
+        t in model_type
+        for t in [
+            "llama",
+            "gemma",
+            "mistral",
+            "qwen",
+            "qwen2",
+            "olmo",
+            "minicpm",
+            "smollm",
+            "nemotron",
+        ]
+    ):
+        return ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+
+    # Fallback: inspect Linear modules and see what’s there
+    linear_leaf_names = []
+    for name, module in base_model.named_modules():
+        if isinstance(module, nn.Linear):
+            linear_leaf_names.append(name.split(".")[-1])
+
+    candidates = [
+        "q_proj", "k_proj", "v_proj", "o_proj",
+        "gate_proj", "up_proj", "down_proj",
+        "c_attn", "c_proj",
+    ]
+    found = sorted(set(n for n in candidates if n in linear_leaf_names))
+    if found:
+        return found
+
+    # If absolutely nothing matches, bail with a clear error
+    raise ValueError(
+        f"Could not guess LoRA target modules for model '{model_name}' "
+        f"(model_type='{model_type}'). "
+        f"Try setting target_modules manually for this model."
+    )
+
+
+# =========================================================
+# MODEL LOADING
+# =========================================================
+
+def load_base_model(model_name: str) -> str:
+    """
+    Load tokenizer + base model, then create:
+    - policy_model: LoRA-adapted (trainable)
+    - ref_model: frozen base model for DPO
+    """
+    global tokenizer, policy_model, ref_model
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+    )
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+    tokenizer.padding_side = "right"
+
+    base_model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+    )
+    base_model.config.use_cache = False
+    base_model.config.pad_token_id = tokenizer.eos_token_id
+
+    # Choose LoRA target modules dynamically
+    target_modules = guess_lora_target_modules(model_name, base_model)
+
+    peft_config = LoraConfig(
+        r=4,
+        target_modules=target_modules,
+        task_type="CAUSAL_LM",
+        lora_alpha=8,
+        lora_dropout=0.1,
+        bias="none",
+    )
+
+    # Policy model = base + LoRA (trainable)
+    policy = get_peft_model(base_model, peft_config)
+    policy.to(device)
+    policy.eval()
+
+    # Reference model = frozen base model
+    reference = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+    )
+    reference.config.use_cache = False
+    reference.config.pad_token_id = tokenizer.eos_token_id
+    reference.to(device)
+    for p in reference.parameters():
+        p.requires_grad = False
+    reference.eval()
+
+    policy_model = policy
+    ref_model = reference
+
+    return (
+        f"Loaded base model: **{model_name}** on **{device}** "
+        f"with LoRA target_modules={target_modules}"
+    )
+
+
+# Load default on startup
+initial_status = load_base_model(DEFAULT_MODEL)
+
+
+# =========================================================
+# UTILS
+# =========================================================
+
+def build_generation_config(
+    do_sample: bool,
+    temperature: float,
+    max_new_tokens: int,
+    top_k: int = 20,
+    top_p: float = 0.9,
+) -> GenerationConfig:
+    """
+    Helper to build a GenerationConfig from UI settings.
+    """
+    # Clamp values a bit just to be safe
+    temperature = max(0.0, float(temperature))
+    max_new_tokens = int(max_new_tokens)
+    return GenerationConfig(
+        do_sample=bool(do_sample),
+        temperature=temperature,
+        top_k=top_k,
+        top_p=top_p,
+        max_new_tokens=max_new_tokens,
+        pad_token_id=tokenizer.eos_token_id,
+    )
+
+
+def generate_text(
+    model: nn.Module,
+    prompt: str,
+    gen_config: GenerationConfig,
+    style_prefix: str = "",
+) -> str:
+    model.eval()
+    full_prompt = style_prefix + prompt
+
+    inputs = tokenizer(
+        full_prompt,
+        return_tensors="pt",
+        padding=False,
+    ).to(device)
+
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            do_sample=gen_config.do_sample,
+            top_k=gen_config.top_k,
+            top_p=gen_config.top_p,
+            temperature=gen_config.temperature,
+            max_new_tokens=gen_config.max_new_tokens,
+            pad_token_id=gen_config.pad_token_id,
+        )
+
+    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    if text.startswith(full_prompt):
+        return text[len(full_prompt):].strip()
+    return text.strip()
+
+
+def preferences_to_df(preferences: List[Dict]) -> pd.DataFrame:
+    if not preferences:
+        return pd.DataFrame(columns=["prompt", "chosen", "rejected"])
+    return pd.DataFrame(preferences)
+
+
+def list_trained_model_files() -> List[str]:
+    """
+    Return a list of filepaths under TRAINED_MODEL_DIR (for download).
+    """
+    if not os.path.isdir(TRAINED_MODEL_DIR):
+        return []
+    files: List[str] = []
+    for root, dirs, filenames in os.walk(TRAINED_MODEL_DIR):
+        for name in filenames:
+            files.append(os.path.join(root, name))
+    return files
+
+
+# =========================================================
+# DPO CALLBACKS
+# =========================================================
+
+def generate_candidates(
+    prompt: str,
+    do_sample: bool,
+    temperature: float,
+    max_new_tokens: int,
+) -> tuple[str, str]:
+    """
+    Generate Answer A (balanced) and Answer B (creative-ish),
+    using the same core generation settings from the GUI.
+    """
+    if not prompt.strip():
+        return "", ""
+
+    # Build two configs from the same UI settings,
+    # but make B slightly more "wild" by bumping top_k / temperature a bit
+    balanced_config = build_generation_config(
+        do_sample=do_sample,
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_k=20,
+        top_p=0.9,
+    )
+
+    # For creative answer, nudge temperature and top_k a bit, but still
+    # keep them tied to UI settings.
+    creative_temp = float(temperature) + 0.4
+    creative_config = build_generation_config(
+        do_sample=do_sample,
+        temperature=creative_temp,
+        max_new_tokens=max_new_tokens,
+        top_k=50,
+        top_p=0.95,
+    )
+
+    style_balanced = (
+        "You are a helpful, careful assistant. "
+        "Answer clearly and sensibly.\n\nUser: "
+    )
+    style_creative = (
+        "You are a creative assistant who explores unusual ideas and stronger opinions, "
+        "while still staying safe.\n\nUser: "
+    )
+
+    answer_a = generate_text(
+        policy_model,
+        prompt,
+        balanced_config,
+        style_prefix=style_balanced,
+    )
+    answer_b = generate_text(
+        policy_model,
+        prompt,
+        creative_config,
+        style_prefix=style_creative,
+    )
+
+    return answer_a, answer_b
+
+
+def save_preference(
+    prompt: str,
+    answer_a: str,
+    answer_b: str,
+    custom_answer: str,
+    preference_mode: str,
+    state_preferences: List[Dict],
+):
+    """
+    Encode a preference in one of four ways:
+    - Prefer A over B -> chosen=A, rejected=B
+    - Prefer B over A -> chosen=B, rejected=A
+    - Prefer custom over A -> chosen=custom, rejected=A
+    - Prefer custom over B -> chosen=custom, rejected=B
+    """
+    msg = ""
+
+    if not prompt.strip():
+        msg = "No prompt provided."
+        return state_preferences, preferences_to_df(state_preferences), msg
+
+    if not answer_a.strip() or not answer_b.strip():
+        msg = "Generate both model answers before saving a preference."
+        return state_preferences, preferences_to_df(state_preferences), msg
+
+    if not preference_mode:
+        msg = "Please choose how to encode the preference."
+        return state_preferences, preferences_to_df(state_preferences), msg
+
+    preference_mode = preference_mode.strip()
+
+    chosen = None
+    rejected = None
+
+    if preference_mode == "Prefer A over B":
+        chosen = answer_a
+        rejected = answer_b
+
+    elif preference_mode == "Prefer B over A":
+        chosen = answer_b
+        rejected = answer_a
+
+    elif preference_mode == "Prefer custom over A":
+        if not custom_answer.strip():
+            msg = "You selected 'Prefer custom over A' but did not provide a custom answer."
+            return state_preferences, preferences_to_df(state_preferences), msg
+        chosen = custom_answer
+        rejected = answer_a
+
+    elif preference_mode == "Prefer custom over B":
+        if not custom_answer.strip():
+            msg = "You selected 'Prefer custom over B' but did not provide a custom answer."
+            return state_preferences, preferences_to_df(state_preferences), msg
+        chosen = custom_answer
+        rejected = answer_b
+
+    else:
+        msg = f"Unknown preference mode: {preference_mode}"
+        return state_preferences, preferences_to_df(state_preferences), msg
+
+    entry = {
+        "prompt": prompt.strip(),
+        "chosen": chosen.strip(),
+        "rejected": rejected.strip(),
+    }
+
+    state_preferences = list(state_preferences) + [entry]
+    df = preferences_to_df(state_preferences)
+    msg = f"Saved preference #{len(state_preferences)}."
+
+    return state_preferences, df, msg
+
+
+def train_dpo_model(
+    state_preferences: List[Dict],
+    num_epochs: int,
+    learning_rate: float,
+    beta: float,
+    progress=gr.Progress(track_tqdm=True),
+):
+    """
+    Run DPO training on the accumulated preferences.
+    Shows a progress bar/spinner and returns:
+    - a detailed status message
+    - a 'last trained' timestamp string
+    - a list of saved model files for download
+    """
+    global policy_model, ref_model
+
+    progress(0.0, desc="Checking preferences...")
+
+    if not state_preferences:
+        return (
+            "⚠️ No preferences collected yet. Add some first.",
+            "**Last trained:** never",
+            [],
+        )
+
+    dataset = Dataset.from_list(state_preferences)
+
+    progress(0.2, desc="Configuring DPO trainer...")
+
+    dpo_config = DPOConfig(
+        **{
+            **DEFAULT_DPO_CONFIG.to_dict(),
+            "num_train_epochs": int(num_epochs),
+            "learning_rate": float(learning_rate),
+            "beta": float(beta),
+        }
+    )
+
+    trainer = DPOTrainer(
+        model=policy_model,
+        ref_model=ref_model,
+        args=dpo_config,
+        train_dataset=dataset,
+        eval_dataset=None,
+        tokenizer=tokenizer,
+        max_length=256,
+    )
+
+    progress(0.4, desc="Training model with DPO...")
+
+    trainer.train()
+
+    progress(0.75, desc="Finalizing and moving model to device...")
+
+    policy_model = trainer.model
+    policy_model.to(device)
+    policy_model.eval()
+
+    # Save the trained model + tokenizer so you can download them
+    progress(0.9, desc="Saving trained model to disk...")
+
+    os.makedirs(TRAINED_MODEL_DIR, exist_ok=True)
+    policy_model.save_pretrained(TRAINED_MODEL_DIR)
+    tokenizer.save_pretrained(TRAINED_MODEL_DIR)
+
+    files = list_trained_model_files()
+
+    progress(1.0, desc="Done")
+
+    n = len(state_preferences)
+    finished_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+
+    msg = f"""### ✅ Training complete
+
+- Preference pairs used: **{n}**
+- Epochs: **{num_epochs}**
+- Learning rate: **{learning_rate}**
+- DPO beta (strength): **{beta}**
+
+The tuned policy model + tokenizer have been saved to `{TRAINED_MODEL_DIR}/`.
+You can download them using the file list below.
+"""
+
+    last_trained_msg = f"**Last trained:** {finished_at}"
+
+    return msg, last_trained_msg, files
+
+
+def generate_from_aligned_model(
+    prompt: str,
+    do_sample: bool,
+    temperature: float,
+    max_new_tokens: int,
+) -> str:
+    if not prompt.strip():
+        return ""
+    gen_config = build_generation_config(
+        do_sample=do_sample,
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_k=20,
+        top_p=0.9,
+    )
+    style_balanced = (
+        "You are a helpful, careful assistant. "
+        "Answer clearly and sensibly.\n\nUser: "
+    )
+    return generate_text(
+        policy_model,
+        prompt,
+        gen_config,
+        style_prefix=style_balanced,
+    )
+
+
+def on_model_change(
+    model_name: str,
+    _state_preferences: List[Dict],
+):
+    """
+    When the user picks a new base model:
+    - reload tokenizer + policy_model + ref_model
+    - clear collected preferences (since they belong to previous model)
+    - reset training status, 'last trained', and download list
+    """
+    status = load_base_model(model_name)
+    empty_prefs: List[Dict] = []
+    df = preferences_to_df(empty_prefs)
+    reset_msg = (
+        status
+        + "\n\nPreferences cleared (new model = new preference data)."
+    )
+    last_trained_reset = "**Last trained:** (reset for new base model)"
+    files_reset: List[str] = []
+    # returns: model_status, prefs, pref_table_df, train_status, last_trained, files
+    return reset_msg, empty_prefs, df, "", last_trained_reset, files_reset
+
+
+# =========================================================
+# GRADIO UI
+# =========================================================
+
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+        # 🔧 DPO Playground – Preference Tuning on Different Models
+
+        - Pick a **base model** from the dropdown.
+        - Ask a question and generate two answers:
+          - **A** = balanced / normal
+          - **B** = creative / more extreme
+        - Optionally write **your own ideal answer**.
+        - Choose how to encode the preference (e.g. A over B, custom over A, etc.).
+        - Collect several preferences and **train the model with DPO**.
+        - Test how the aligned policy model behaves on new prompts.
+        - Download the tuned model (LoRA adapter + tokenizer) after training.
+        - **Control temperature, sampling, and max_new_tokens directly in the UI.**
+        """
+    )
+
+    state_preferences = gr.State([])
+
+    with gr.Row():
+        model_dropdown = gr.Dropdown(
+            choices=MODEL_CHOICES,
+            value=DEFAULT_MODEL,
+            label="Base model",
+        )
+
+    model_status = gr.Markdown(initial_status)
+
+    # -----------------------------------------------------
+    # Collect preferences tab
+    # -----------------------------------------------------
+    with gr.Tab("Collect preferences"):
+        with gr.Row():
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                placeholder="Ask anything...",
+                lines=3,
+            )
+
+        gr.Markdown("### Generation settings for Answer A & B")
+
+        with gr.Row():
+            gen_do_sample = gr.Checkbox(
+                value=True,
+                label="Use sampling (do_sample)",
+            )
+            gen_temperature = gr.Slider(
+                minimum=0.0,
+                maximum=1.5,
+                value=0.8,
+                step=0.05,
+                label="Temperature",
+            )
+            gen_max_new_tokens = gr.Slider(
+                minimum=4,
+                maximum=256,
+                value=128,
+                step=4,
+                label="Max new tokens",
+            )
+
+        generate_btn = gr.Button("Generate A & B")
+
+        with gr.Row():
+            answer_a_box = gr.Textbox(
+                label="Answer A (balanced / normal)",
+                lines=8,
+            )
+            answer_b_box = gr.Textbox(
+                label="Answer B (creative / more extreme)",
+                lines=8,
+            )
+
+        custom_answer_box = gr.Textbox(
+            label="Your own ideal answer (optional)",
+            lines=8,
+            placeholder="If you want, write the answer you *wish* the model had given.",
+        )
+
+        preference_mode = gr.Radio(
+            choices=[
+                "Prefer A over B",
+                "Prefer B over A",
+                "Prefer custom over A",
+                "Prefer custom over B",
+            ],
+            label="How should this preference be encoded?",
+        )
+
+        save_pref_btn = gr.Button("Save preference")
+
+        pref_status = gr.Markdown("")
+        pref_table = gr.Dataframe(
+            headers=["prompt", "chosen", "rejected"],
+            label="Collected preferences (for DPO training)",
+            wrap=True,
+        )
+
+        generate_btn.click(
+            fn=generate_candidates,
+            inputs=[prompt_input, gen_do_sample, gen_temperature, gen_max_new_tokens],
+            outputs=[answer_a_box, answer_b_box],
+        )
+
+        save_pref_btn.click(
+            fn=save_preference,
+            inputs=[
+                prompt_input,
+                answer_a_box,
+                answer_b_box,
+                custom_answer_box,
+                preference_mode,
+                state_preferences,
+            ],
+            outputs=[
+                state_preferences,
+                pref_table,
+                pref_status,
+            ],
+        )
+
+    # -----------------------------------------------------
+    # Train & test tab
+    # -----------------------------------------------------
+    with gr.Tab("Train & test DPO model"):
+        gr.Markdown(
+            "Train the LoRA-adapted policy model using your preferences "
+            "with **Direct Preference Optimization (DPO)**."
+        )
+
+        with gr.Row():
+            num_epochs_slider = gr.Slider(
+                minimum=1,
+                maximum=5,
+                step=1,
+                value=1,
+                label="Number of epochs",
+            )
+            lr_slider = gr.Slider(
+                minimum=1e-5,
+                maximum=5e-4,
+                step=1e-5,
+                value=1e-4,
+                label="Learning rate",
+            )
+            beta_slider = gr.Slider(
+                minimum=0.05,
+                maximum=0.5,
+                step=0.05,
+                value=0.1,
+                label="DPO beta (strength)",
+            )
+
+        train_btn = gr.Button("Train DPO model", variant="primary")
+        train_status = gr.Markdown("")
+        last_trained = gr.Markdown("**Last trained:** never")
+
+        download_files = gr.Files(
+            label="Trained model files (adapter + tokenizer)",
+            interactive=False,
+        )
+
+        train_btn.click(
+            fn=train_dpo_model,
+            inputs=[
+                state_preferences,
+                num_epochs_slider,
+                lr_slider,
+                beta_slider,
+            ],
+            outputs=[train_status, last_trained, download_files],
+        )
+
+        gr.Markdown("## Try the current policy model")
+
+        with gr.Row():
+            test_do_sample = gr.Checkbox(
+                value=False,
+                label="Use sampling (do_sample) for test",
+            )
+            test_temperature = gr.Slider(
+                minimum=0.0,
+                maximum=1.5,
+                value=0.0,
+                step=0.05,
+                label="Temperature (test)",
+            )
+            test_max_new_tokens = gr.Slider(
+                minimum=4,
+                maximum=256,
+                value=64,
+                step=4,
+                label="Max new tokens (test)",
+            )
+
+        test_prompt = gr.Textbox(
+            label="Test prompt",
+            placeholder="Ask something to see the aligned model...",
+            lines=3,
+        )
+        test_btn = gr.Button("Generate from DPO policy model")
+        test_answer = gr.Textbox(
+            label="Policy model answer",
+            lines=8,
+        )
+
+        test_btn.click(
+            fn=generate_from_aligned_model,
+            inputs=[
+                test_prompt,
+                test_do_sample,
+                test_temperature,
+                test_max_new_tokens,
+            ],
+            outputs=test_answer,
+        )
+
+    # model change: reload + clear prefs + reset train status + last trained + downloads
+    model_dropdown.change(
+        fn=on_model_change,
+        inputs=[model_dropdown, state_preferences],
+        outputs=[
+            model_status,
+            state_preferences,
+            pref_table,
+            train_status,
+            last_trained,
+            download_files,
+        ],
+    )
+
+if __name__ == "__main__":
+    demo.queue().launch()
+
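
After training, only the LoRA adapter and tokenizer are written to trained_model/ via save_pretrained. A minimal sketch of how those files could be reloaded for inference elsewhere, assuming the adapter was trained on the default base model (Qwen/Qwen2.5-0.5B-Instruct) and using the standard peft loading API:

# Hypothetical reload script; assumes the adapter in trained_model/ matches the default base model.
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("trained_model")
model = PeftModel.from_pretrained(base, "trained_model")  # attach the saved LoRA adapter
model.eval()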