2001haitem commited on
Commit
e7e14b6
·
verified ·
1 Parent(s): b192a75

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -15
app.py CHANGED
@@ -2,15 +2,14 @@ import gradio as gr
2
  from transformers import pipeline
3
  import torch
4
  import os
5
- from pathlib import Path
6
 
7
  # ==============================
8
  # ⚙️ Configuration
9
  # ==============================
10
- # Use a smaller zero-shot model to avoid memory/storage limits
11
  MODEL_NAME = os.getenv("MODEL_NAME", "MoritzLaurer/deberta-v3-base-zeroshot-v2.0")
12
 
13
- # Optional: specify cache dir outside your repo
14
  HF_HOME = os.getenv("HF_HOME", "/tmp/huggingface")
15
  os.environ["TRANSFORMERS_CACHE"] = HF_HOME
16
 
@@ -33,28 +32,46 @@ except Exception as e:
33
  # ==============================
34
  # 🧠 Classification Function
35
  # ==============================
36
- def classify_severity(side_effects_str: str):
37
- if not side_effects_str or not side_effects_str.strip():
38
- return {}
39
-
40
- side_effects_list = [se.strip() for se in side_effects_str.split(",") if se.strip()]
41
- candidate_labels = ["Mild", "Moderate", "Severe", "Life-threatening", "Death"]
42
-
43
  try:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  results = severity_classifier(
45
  sequences=side_effects_list,
46
  candidate_labels=candidate_labels,
47
  multi_label=False
48
  )
49
 
 
50
  if isinstance(results, dict):
51
  results = [results]
52
 
53
- return {
 
54
  result["sequence"]: result["labels"][0].lower()
55
  for result in results
56
  }
57
 
 
 
58
  except Exception as e:
59
  return {"error": str(e)}
60
 
@@ -63,10 +80,12 @@ def classify_severity(side_effects_str: str):
63
  # ==============================
64
  demo = gr.Interface(
65
  fn=classify_severity,
66
- inputs=gr.Textbox(
67
- label="Side Effects",
68
- placeholder="Enter side effects, separated by commas (e.g., headache, nausea, dizziness)"
69
- ),
 
 
70
  outputs=gr.JSON(label="Severity Classification"),
71
  title="⚕️ Side Effect Severity Classifier",
72
  description="Classifies side effect severity using a DeBERTa/BART model for zero-shot classification.",
@@ -86,6 +105,7 @@ if __name__ == "__main__":
86
 
87
 
88
 
 
89
  # import gradio as gr
90
  # from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
91
  # from pathlib import Path
 
2
  from transformers import pipeline
3
  import torch
4
  import os
5
+ from typing import Union, List
6
 
7
  # ==============================
8
  # ⚙️ Configuration
9
  # ==============================
 
10
  MODEL_NAME = os.getenv("MODEL_NAME", "MoritzLaurer/deberta-v3-base-zeroshot-v2.0")
11
 
12
+ # Optional: specify cache dir
13
  HF_HOME = os.getenv("HF_HOME", "/tmp/huggingface")
14
  os.environ["TRANSFORMERS_CACHE"] = HF_HOME
15
 
 
32
  # ==============================
33
  # 🧠 Classification Function
34
  # ==============================
35
+ def classify_severity(side_effects_input: Union[str, List[str]]):
36
+ """
37
+ Accepts a string (comma-separated) or a list of side effects.
38
+ Returns a dictionary mapping side effects -> severity.
39
+ """
 
 
40
  try:
41
+ # Handle both list and string input types
42
+ if isinstance(side_effects_input, str):
43
+ side_effects_list = [
44
+ se.strip() for se in side_effects_input.split(",") if se.strip()
45
+ ]
46
+ elif isinstance(side_effects_input, list):
47
+ side_effects_list = [se.strip() for se in side_effects_input if se.strip()]
48
+ else:
49
+ return {"error": "Invalid input format. Expected string or list."}
50
+
51
+ if not side_effects_list:
52
+ return {}
53
+
54
+ candidate_labels = ["Mild", "Moderate", "Severe", "Life-threatening", "Death"]
55
+
56
+ # Perform classification
57
  results = severity_classifier(
58
  sequences=side_effects_list,
59
  candidate_labels=candidate_labels,
60
  multi_label=False
61
  )
62
 
63
+ # Normalize single result into list
64
  if isinstance(results, dict):
65
  results = [results]
66
 
67
+ # Convert to dict { side_effect: severity }
68
+ output_dict = {
69
  result["sequence"]: result["labels"][0].lower()
70
  for result in results
71
  }
72
 
73
+ return output_dict
74
+
75
  except Exception as e:
76
  return {"error": str(e)}
77
 
 
80
  # ==============================
81
  demo = gr.Interface(
82
  fn=classify_severity,
83
+ inputs=[
84
+ gr.Textbox(
85
+ label="Side Effects",
86
+ placeholder="Enter side effects, separated by commas (e.g., headache, nausea, dizziness)"
87
+ )
88
+ ],
89
  outputs=gr.JSON(label="Severity Classification"),
90
  title="⚕️ Side Effect Severity Classifier",
91
  description="Classifies side effect severity using a DeBERTa/BART model for zero-shot classification.",
 
105
 
106
 
107
 
108
+
109
  # import gradio as gr
110
  # from transformers import pipeline, AutoModelForSequenceClassification, AutoTokenizer
111
  # from pathlib import Path