Dusit-P commited on
Commit
5b4b5ca
·
verified ·
1 Parent(s): 4e9e3bb

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +64 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, importlib.util, torch
2
+ import torch.nn.functional as F
3
+ import gradio as gr
4
+ from huggingface_hub import hf_hub_download
5
+ from safetensors.torch import load_file
6
+ from transformers import AutoTokenizer
7
+
8
+ # ===== ปรับได้ผ่าน Settings > Variables (Environment) =====
9
+ REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
10
+ DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # หรือ "baseline"
11
+ HF_TOKEN = os.getenv("HF_TOKEN", None) # ถ้าโมเดลเป็น private ให้เพิ่ม secret ชื่อนี้
12
+
13
+ CACHE = {}
14
+
15
+ def _import_models():
16
+ if "models_module" in CACHE:
17
+ return CACHE["models_module"]
18
+ models_py = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
19
+ spec = importlib.util.spec_from_file_location("models", models_py)
20
+ mod = importlib.util.module_from_spec(spec)
21
+ spec.loader.exec_module(mod)
22
+ CACHE["models_module"] = mod
23
+ return mod
24
+
25
+ def load_model(model_name: str):
26
+ key = f"model:{model_name}"
27
+ if key in CACHE:
28
+ return CACHE[key]
29
+ cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
30
+ w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
31
+ with open(cfg_path, "r", encoding="utf-8") as f:
32
+ cfg = json.load(f)
33
+ models = _import_models()
34
+ tok = AutoTokenizer.from_pretrained(cfg["base_model"])
35
+ model = models.create_model_by_name(cfg["arch"])
36
+ state = load_file(w_path)
37
+ model.load_state_dict(state, strict=True)
38
+ model.eval()
39
+ CACHE[key] = (model, tok, cfg)
40
+ return CACHE[key]
41
+
42
+ def predict_api(text: str, model_choice: str):
43
+ if not text.strip():
44
+ return {"negative": 0.0, "positive": 0.0}, ""
45
+ model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
46
+ model, tok, cfg = load_model(model_name)
47
+ enc = tok([text], padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
48
+ with torch.no_grad():
49
+ logits = model(enc["input_ids"], enc["attention_mask"])
50
+ probs = F.softmax(logits, dim=1)[0].tolist()
51
+ out = {"negative": float(probs[0]), "positive": float(probs[1])}
52
+ label = "positive" if out["positive"] >= out["negative"] else "negative"
53
+ return out, label
54
+
55
+ with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
56
+ gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM Heads)")
57
+ inp_text = gr.Textbox(lines=3, label="ข้อความรีวิวภาษาไทย", placeholder="พิมพ์รีวิวที่นี่")
58
+ inp_model = gr.Radio(choices=["cnn_bilstm","baseline"], value=DEFAULT_MODEL, label="เลือกโมเดล")
59
+ out_probs = gr.Label(label="Probabilities")
60
+ out_label = gr.Textbox(label="Prediction", interactive=False)
61
+ gr.Button("Predict").click(predict_api, [inp_text, inp_model], [out_probs, out_label])
62
+
63
+ if __name__ == "__main__":
64
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ safetensors
4
+ gradio
5
+ huggingface_hub