CatoG committed on
Commit 0905744 · unverified · 1 Parent(s): 2614f77

Add logprob_answer function and improve diagnostics


Added a logprob_answer function to compute the log-probability of an answer given its prompt. Enhanced the DPO diagnostics for evaluating model preferences.
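As an informal illustration (not part of the commit), the new helper can be used to compare two candidate answers for the same prompt. The snippet below assumes app.py's module-level policy_model and tokenizer are already loaded; the prompt and answer strings are made up:

    # Hypothetical example: a higher score means the model assigns more probability to that answer.
    prompt = "Explain DPO in one sentence."
    lp_a = logprob_answer(policy_model, tokenizer, prompt, "DPO fine-tunes a model directly on preference pairs.")
    lp_b = logprob_answer(policy_model, tokenizer, prompt, "I don't know.")
    print(lp_a - lp_b)  # > 0 means the model currently prefers answer A over answer B

Because both calls share the same prompt prefix, the prompt's own log-probability cancels when the two scores are compared, which is the approximation described in the function's docstring.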

Files changed (1): app.py +121 -9
app.py CHANGED
@@ -4,6 +4,7 @@ from datetime import datetime
 
 import torch
 from torch import nn
+import torch.nn.functional as F
 
 import gradio as gr
 import pandas as pd
@@ -21,7 +22,7 @@ from trl import DPOConfig, DPOTrainer
 
 
 # =========================================================
-# MODEL LIST (from your BIAS demo)
+# MODEL LIST
 # =========================================================
 
 MODEL_CHOICES = [
@@ -92,7 +93,7 @@ DEFAULT_DPO_CONFIG = DPOConfig(
     logging_steps=1,
     gradient_accumulation_steps=1,
     learning_rate=1e-4,
-    evaluation_strategy="no",  # warning is fine with current versions
+    evaluation_strategy="no",
     warmup_steps=0,
     fp16=False,
     save_steps=0,
@@ -246,7 +247,6 @@ def build_generation_config(
     """
     Helper to build a GenerationConfig from UI settings.
     """
-    # Clamp values a bit just to be safe
     temperature = max(0.0, float(temperature))
     max_new_tokens = int(max_new_tokens)
     return GenerationConfig(
@@ -310,6 +310,41 @@ def list_trained_model_files() -> List[str]:
     return files
 
 
+def logprob_answer(
+    model: nn.Module,
+    tokenizer: AutoTokenizer,
+    prompt: str,
+    answer: str,
+) -> float:
+    """
+    Compute the log-probability of `answer` given `prompt`,
+    using a simple "User/Assistant" format:
+
+        full_text = "User: <prompt>\\nAssistant: <answer>"
+
+    We approximate p(answer | prompt) by summing log-probs of all tokens
+    in the answer region (the shared prompt part cancels in comparisons).
+    """
+    model.eval()
+    with torch.no_grad():
+        full_text = f"User: {prompt}\nAssistant: {answer}"
+        enc = tokenizer(
+            full_text,
+            return_tensors="pt",
+        ).to(device)
+
+        input_ids = enc["input_ids"]
+        out = model(input_ids=input_ids)
+        logits = out.logits[:, :-1, :]  # [B, T-1, V]
+        labels = input_ids[:, 1:]       # [B, T-1]
+
+        log_probs = F.log_softmax(logits, dim=-1)
+        token_log_probs = log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
+        total_logprob = token_log_probs.sum().item()
+
+    return float(total_logprob)
+
+
 # =========================================================
 # DPO CALLBACKS
 # =========================================================
@@ -327,8 +362,6 @@
     if not prompt.strip():
         return "", ""
 
-    # Build two configs from the same UI settings,
-    # but make B slightly more "wild" by bumping top_k / temperature a bit
     balanced_config = build_generation_config(
         do_sample=do_sample,
         temperature=temperature,
@@ -337,8 +370,6 @@
         top_p=0.9,
     )
 
-    # For creative answer, nudge temperature and top_k a bit, but still
-    # keep them tied to UI settings.
     creative_temp = float(temperature) + 0.4
     creative_config = build_generation_config(
         do_sample=do_sample,
@@ -534,6 +565,76 @@ You can download them using the file list below.
     return msg, last_trained_msg, files
 
 
+def dpo_diagnostics(state_preferences: List[Dict]) -> str:
+    """
+    Compute how often the policy_model and ref_model
+    assign higher log-probability to the CHOSEN answer
+    than to the REJECTED answer.
+
+    Returns a markdown report with:
+    - number of pairs
+    - policy win rate
+    - ref win rate
+    - average logprob margins
+    """
+    if not state_preferences:
+        return "No preferences collected yet – nothing to evaluate."
+
+    if policy_model is None or ref_model is None or tokenizer is None:
+        return "Models not loaded – reload base model first."
+
+    n = len(state_preferences)
+    policy_wins = 0
+    ref_wins = 0
+
+    policy_margins = []
+    ref_margins = []
+
+    for ex in state_preferences:
+        prompt = ex["prompt"]
+        chosen = ex["chosen"]
+        rejected = ex["rejected"]
+
+        # Policy model logprobs
+        lp_pol_ch = logprob_answer(policy_model, tokenizer, prompt, chosen)
+        lp_pol_rj = logprob_answer(policy_model, tokenizer, prompt, rejected)
+        margin_pol = lp_pol_ch - lp_pol_rj
+        policy_margins.append(margin_pol)
+        if margin_pol > 0:
+            policy_wins += 1
+
+        # Reference model logprobs
+        lp_ref_ch = logprob_answer(ref_model, tokenizer, prompt, chosen)
+        lp_ref_rj = logprob_answer(ref_model, tokenizer, prompt, rejected)
+        margin_ref = lp_ref_ch - lp_ref_rj
+        ref_margins.append(margin_ref)
+        if margin_ref > 0:
+            ref_wins += 1
+
+    policy_winrate = policy_wins / n
+    ref_winrate = ref_wins / n
+
+    avg_pol_margin = sum(policy_margins) / n
+    avg_ref_margin = sum(ref_margins) / n
+
+    report = f"""### 📊 DPO Diagnostics
+
+Preference pairs evaluated: **{n}**
+
+**Policy model (after DPO)**
+- Win rate (chosen > rejected): **{policy_winrate:.2%}**
+- Avg logprob(chosen − rejected): **{avg_pol_margin:.3f}**
+
+**Reference model (base)**
+- Win rate (chosen > rejected): **{ref_winrate:.2%}**
+- Avg logprob(chosen − rejected): **{avg_ref_margin:.3f}**
+
+> A higher win rate and margin for the policy model compared to the reference model
+> indicates that DPO training is successfully shifting the model toward your preferences.
+"""
+    return report
+
+
 def generate_from_aligned_model(
     prompt: str,
     do_sample: bool,
@@ -602,7 +703,8 @@ with gr.Blocks() as demo:
         - Collect several preferences and **train the model with DPO**.
        - Test how the aligned policy model behaves on new prompts.
        - Download the tuned model (LoRA adapter + tokenizer) after training.
-        - **Control temperature, sampling, and max_new_tokens directly in the UI.**
+        - Use **DPO diagnostics** to see if the aligned model prefers your chosen answers
+          more often than the base model.
         """
     )
 
@@ -806,6 +908,17 @@
         outputs=test_answer,
     )
 
+    gr.Markdown("## 📈 DPO diagnostics")
+
+    diag_btn = gr.Button("Compute preference win rates (policy vs base)")
+    diag_output = gr.Markdown("")
+
+    diag_btn.click(
+        fn=dpo_diagnostics,
+        inputs=[state_preferences],
+        outputs=[diag_output],
+    )
+
     # model change: reload + clear prefs + reset train status + last trained + downloads
     model_dropdown.change(
         fn=on_model_change,
@@ -822,4 +935,3 @@
 
 if __name__ == "__main__":
     demo.queue().launch()
-
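
For reference, a minimal sketch (not part of this commit) of calling the new diagnostics outside the Gradio UI, assuming the module-level policy_model, ref_model, and tokenizer are already loaded; the preference pair is invented for illustration:

    # dpo_diagnostics expects a list of dicts with "prompt", "chosen" and "rejected" keys.
    prefs = [
        {
            "prompt": "Summarize DPO.",
            "chosen": "DPO trains the policy directly on human preference pairs.",
            "rejected": "It is a database protocol.",
        },
    ]
    print(dpo_diagnostics(prefs))  # prints the markdown report with win rates and margins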