Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +64 -0
- requirements.txt +5 -0
app.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, json, importlib.util, torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import gradio as gr
|
| 4 |
+
from huggingface_hub import hf_hub_download
|
| 5 |
+
from safetensors.torch import load_file
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
|
| 8 |
+
# ===== ปรับได้ผ่าน Settings > Variables (Environment) =====
|
| 9 |
+
REPO_ID = os.getenv("REPO_ID", "Dusit-P/thai-sentiment-wcb")
|
| 10 |
+
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL", "cnn_bilstm") # หรือ "baseline"
|
| 11 |
+
HF_TOKEN = os.getenv("HF_TOKEN", None) # ถ้าโมเดลเป็น private ให้เพิ่ม secret ชื่อนี้
|
| 12 |
+
|
| 13 |
+
CACHE = {}
|
| 14 |
+
|
| 15 |
+
def _import_models():
|
| 16 |
+
if "models_module" in CACHE:
|
| 17 |
+
return CACHE["models_module"]
|
| 18 |
+
models_py = hf_hub_download(REPO_ID, filename="common/models.py", token=HF_TOKEN)
|
| 19 |
+
spec = importlib.util.spec_from_file_location("models", models_py)
|
| 20 |
+
mod = importlib.util.module_from_spec(spec)
|
| 21 |
+
spec.loader.exec_module(mod)
|
| 22 |
+
CACHE["models_module"] = mod
|
| 23 |
+
return mod
|
| 24 |
+
|
| 25 |
+
def load_model(model_name: str):
|
| 26 |
+
key = f"model:{model_name}"
|
| 27 |
+
if key in CACHE:
|
| 28 |
+
return CACHE[key]
|
| 29 |
+
cfg_path = hf_hub_download(REPO_ID, filename=f"{model_name}/config.json", token=HF_TOKEN)
|
| 30 |
+
w_path = hf_hub_download(REPO_ID, filename=f"{model_name}/model.safetensors", token=HF_TOKEN)
|
| 31 |
+
with open(cfg_path, "r", encoding="utf-8") as f:
|
| 32 |
+
cfg = json.load(f)
|
| 33 |
+
models = _import_models()
|
| 34 |
+
tok = AutoTokenizer.from_pretrained(cfg["base_model"])
|
| 35 |
+
model = models.create_model_by_name(cfg["arch"])
|
| 36 |
+
state = load_file(w_path)
|
| 37 |
+
model.load_state_dict(state, strict=True)
|
| 38 |
+
model.eval()
|
| 39 |
+
CACHE[key] = (model, tok, cfg)
|
| 40 |
+
return CACHE[key]
|
| 41 |
+
|
| 42 |
+
def predict_api(text: str, model_choice: str):
|
| 43 |
+
if not text.strip():
|
| 44 |
+
return {"negative": 0.0, "positive": 0.0}, ""
|
| 45 |
+
model_name = "baseline" if model_choice == "baseline" else "cnn_bilstm"
|
| 46 |
+
model, tok, cfg = load_model(model_name)
|
| 47 |
+
enc = tok([text], padding=True, truncation=True, max_length=cfg["max_len"], return_tensors="pt")
|
| 48 |
+
with torch.no_grad():
|
| 49 |
+
logits = model(enc["input_ids"], enc["attention_mask"])
|
| 50 |
+
probs = F.softmax(logits, dim=1)[0].tolist()
|
| 51 |
+
out = {"negative": float(probs[0]), "positive": float(probs[1])}
|
| 52 |
+
label = "positive" if out["positive"] >= out["negative"] else "negative"
|
| 53 |
+
return out, label
|
| 54 |
+
|
| 55 |
+
with gr.Blocks(title="Thai Sentiment API (Dusit-P)") as demo:
|
| 56 |
+
gr.Markdown("### Thai Sentiment (WangchanBERTa + LSTM Heads)")
|
| 57 |
+
inp_text = gr.Textbox(lines=3, label="ข้อความรีวิวภาษาไทย", placeholder="พิมพ์รีวิวที่นี่")
|
| 58 |
+
inp_model = gr.Radio(choices=["cnn_bilstm","baseline"], value=DEFAULT_MODEL, label="เลือกโมเดล")
|
| 59 |
+
out_probs = gr.Label(label="Probabilities")
|
| 60 |
+
out_label = gr.Textbox(label="Prediction", interactive=False)
|
| 61 |
+
gr.Button("Predict").click(predict_api, [inp_text, inp_model], [out_probs, out_label])
|
| 62 |
+
|
| 63 |
+
if __name__ == "__main__":
|
| 64 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
transformers
|
| 3 |
+
safetensors
|
| 4 |
+
gradio
|
| 5 |
+
huggingface_hub
|