# app.py — REST EEG Seizure Demo (GraphConv, supports W2/U, stats + top-k + span)
import io
import json
import os

import gradio as gr
import h5py
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from torch_geometric.nn import GraphConv

# ---------------------------
# 0) Config & device
# ---------------------------
MODEL_REPO = os.getenv("MODEL_REPO", "uyen1109/rest_eeg_seizure_analysis")
SPACE_DIR = "space_infer"
device = torch.device("cpu")
# os.cpu_count() may return None when the count cannot be determined.
torch.set_num_threads(max(1, (os.cpu_count() or 1) // 2))


# ==========================================================
# 1) Model (GraphConv; parameter names match the checkpoint; optional W2/U)
#    Required: W1, gc1, gc2, fc
#    Optional: W2 (present in some training runs) & U (transition on S)
# ==========================================================
class RESTNet(nn.Module):
    """Recurrent graph model scoring every frame of an EEG clip.

    For each frame t the node state S is updated from the frame features
    (W1, optionally W2 and U), then refined by two mean-aggregating
    GraphConv layers; ``fc`` maps each node state to a logit and the
    node-mean of those logits is the frame logit.
    """

    def __init__(
        self,
        in_dim: int,
        state_q: int = 64,
        w2_in: int | None = None,  # None => no W2; otherwise in_dim (acts on x_t) or state_q (acts on S)
        u_in: int | None = None,  # None => no U; otherwise must equal state_q
    ):
        super().__init__()
        self.in_dim = in_dim
        self.state_q = state_q

        # Attribute names must match the checkpoint keys.
        self.W1 = nn.Linear(in_dim, state_q)
        self.gc1 = GraphConv(state_q, state_q, aggr="mean")
        self.gc2 = GraphConv(state_q, state_q, aggr="mean")
        self.fc = nn.Linear(state_q, 1)

        # Optional W2 & U, built only when the checkpoint provides them.
        self._use_W2 = False
        self._use_U = False
        if w2_in is not None:
            if w2_in == in_dim:  # W2 acts on x_t
                self.W2 = nn.Linear(in_dim, state_q)
                self._use_W2 = True
            elif w2_in == state_q:  # W2 acts on S
                self.W2 = nn.Linear(state_q, state_q)
                self._use_W2 = True
        if u_in is not None and u_in == state_q:
            # No bias, so the only parameter is 'U.weight', matching the checkpoint.
            self.U = nn.Linear(state_q, state_q, bias=False)
            self._use_U = True

    @torch.no_grad()
    def forward(self, x_ntf, edge_index, edge_weight=None):
        """Run the recurrence over all frames of one clip.

        Args:
            x_ntf: node features of shape [N, T, F].
            edge_index: graph connectivity passed unchanged to GraphConv.
            edge_weight: optional per-edge weights passed to GraphConv.

        Returns:
            (clip_prob, frame_probs): the mean frame probability as a Python
            float, and a 1-D tensor of T per-frame probabilities.

        Raises:
            ValueError: if the clip has no frames (T == 0).
        """
        N, T, Fdim = x_ntf.shape
        if T == 0:
            raise ValueError("x_ntf has no frames (T == 0).")
        S = torch.zeros(N, self.state_q, device=x_ntf.device)
        frame_logits = []
        for t in range(T):
            x_t = x_ntf[:, t, :]
            upd = self.W1(x_t)  # always present
            if self._use_W2:
                # Which input W2 consumes is decided by its in_features.
                upd = upd + (self.W2(x_t) if self.W2.in_features == Fdim else self.W2(S))
            S = (self.U(S) if self._use_U else S) + upd
            # Message passing (GraphConv, same as in training).
            S = F.relu(self.gc1(S, edge_index, edge_weight))
            S = F.relu(self.gc2(S, edge_index, edge_weight))
            # Node-mean of the per-node logits -> one scalar logit for frame t.
            frame_logits.append(self.fc(S).mean())
        frame_probs = torch.sigmoid(torch.stack(frame_logits))  # [T]
        return frame_probs.mean().item(), frame_probs


# ==========================================================
# 2) Load config + weights (strict load, auto-enable W2/U)
# ==========================================================
def hub_download(fname: str) -> str:
    """Download ``{SPACE_DIR}/{fname}`` from the MODEL_REPO model repo and return the local path."""
    return hf_hub_download(MODEL_REPO, f"{SPACE_DIR}/{fname}", repo_type="model")


CFG_PATH = hub_download("rest_config.json")
STATE_PATH = hub_download("rest_state.pt")

with open(CFG_PATH, "r", encoding="utf-8") as f:
    CFG = json.load(f)

sd = torch.load(STATE_PATH, map_location="cpu")

# Infer F (in_dim) and Q (state_q) from the checkpoint; fall back to the config.
if "W1.weight" in sd:
    Q_MODEL, F_MODEL = sd["W1.weight"].shape  # nn.Linear weight is [out=Q, in=F]
else:
    F_MODEL = int(CFG.get("in_feat", 128))
    Q_MODEL = int(CFG.get("state_q", 32))

# Detect optional W2/U in the checkpoint and read their in_features so the matching layers get built.
w2_in = sd["W2.weight"].shape[1] if "W2.weight" in sd else None
u_in = sd["U.weight"].shape[1] if "U.weight" in sd else None

MODEL = RESTNet(in_dim=F_MODEL, state_q=Q_MODEL, w2_in=w2_in, u_in=u_in).to(device).eval()

# Strict load first; if it fails (e.g. unused extra keys), fall back to a
# non-strict load so the Space still starts, and report what was skipped.
try:
    missing, unexpected = MODEL.load_state_dict(sd, strict=True)
    print(f"Loaded strict: F={F_MODEL}, Q={Q_MODEL}, W2_in={w2_in}, U_in={u_in}")
    print("Missing keys:", missing)
    print("Unexpected keys:", unexpected)
except RuntimeError as e:
    print("Strict load failed:", e)
    missing, unexpected = MODEL.load_state_dict(sd, strict=False)
    print("Fallback non-strict load. Missing:", missing, "Unexpected:", unexpected)


# ==========================================================
# 3) Data utilities + viz (stats/top-k/spans)
# ==========================================================
def _normalize_x_shape(x: np.ndarray) -> np.ndarray:
    """Heuristically transpose x to [N, T, F].

    Returns the first axis permutation (identity first) with
    8 <= N <= 128 (EEG channel count), T >= 2 and F >= 1, or x unchanged
    when no permutation fits.

    Raises:
        gr.Error: if x is not 3-dimensional.
    """
    if x.ndim != 3:
        raise gr.Error("Input x phải có 3 chiều.")
    perms = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)]
    for p in perms:
        y = np.transpose(x, p)
        N, T, Fdim = y.shape
        if 8 <= N <= 128 and T >= 2 and Fdim >= 1:
            return y
    return x  # fallback when no permutation matches the heuristic


def _adapt_features_to_model(x_ntf: np.ndarray, target_F: int) -> np.ndarray:
    """Zero-pad or truncate the F axis so it equals target_F (the model's in_dim)."""
    _, _, Fdim = x_ntf.shape
    if Fdim == target_F:
        return x_ntf
    if Fdim > target_F:
        return x_ntf[..., :target_F]
    return np.pad(x_ntf, ((0, 0), (0, 0), (0, target_F - Fdim)), mode="constant")


def _resolve_upload(file_obj):
    """Return something np.load / h5py.File can open for an uploaded file.

    A str/PathLike is used as a path; an object whose ``.name`` is an
    existing file (e.g. a tempfile wrapper) is opened by that path;
    anything else is read fully into a BytesIO.
    """
    if isinstance(file_obj, (str, os.PathLike)):
        return os.fspath(file_obj)
    name = getattr(file_obj, "name", None)
    if isinstance(name, str) and os.path.isfile(name):
        return name
    return io.BytesIO(file_obj.read())


def load_npz(npz_file, allow_pickle: bool = False):
    """Load (x, edge_index, edge_weight) from an .npz archive.

    Args:
        npz_file: path, or an uploaded file object (see _resolve_upload).
        allow_pickle: forwarded to np.load. Defaults to False because
            uploads are untrusted and unpickling object arrays can execute
            arbitrary code; only enable it for trusted files.

    Returns:
        x normalized to [N, T, F], edge_index, and edge_weight (None if absent).

    Raises:
        gr.Error: unreadable archive or missing 'x' / 'edge_index'.
    """
    src = _resolve_upload(npz_file)
    try:
        archive = np.load(src, allow_pickle=allow_pickle)
    except (OSError, ValueError) as e:
        raise gr.Error(f"Không đọc được NPZ: {e}") from e
    if not hasattr(archive, "files"):  # a plain .npy array, not an archive
        raise gr.Error("Không tìm thấy 'x'/'edge_index' trong NPZ.")
    with archive as npz:
        names = set(npz.files)
        if "x" not in names or "edge_index" not in names:
            raise gr.Error("Không tìm thấy 'x'/'edge_index' trong NPZ.")
        try:
            x = np.asarray(npz["x"])
            edge_index = np.asarray(npz["edge_index"])
            edge_weight = np.asarray(npz["edge_weight"]) if "edge_weight" in names else None
        except ValueError as e:  # e.g. object arrays while allow_pickle=False
            raise gr.Error(f"Không đọc được NPZ: {e}") from e
    return _normalize_x_shape(x), edge_index, edge_weight


def load_h5(h5_file, clip_idx=0):
    """Load one clip (x, edge_index, edge_weight) from an HDF5 file.

    The features come from the first of 'x'/'clips'/'X'; a 3-D dataset is a
    single clip, a higher-rank one is indexed by clip_idx along axis 0.
    Edges come from 'edge_index'/'edge_idx'/'edges'; weights (optional)
    from 'edge_weight'/'edge_w'/'weights'. The file is always closed.

    Raises:
        gr.Error: missing datasets or clip_idx out of range.
    """
    src = _resolve_upload(h5_file)
    with h5py.File(src, "r") as f:
        x_key = next((k for k in ("x", "clips", "X") if k in f), None)
        if x_key is None:
            raise gr.Error("Không tìm thấy dataset 'x'/'clips'/'X' trong H5.")
        X = f[x_key]
        if X.ndim > 3:
            # Slider values may arrive as floats; h5py needs an int index.
            idx = int(clip_idx)
            n_clips = X.shape[0]
            if not 0 <= idx < n_clips:
                raise gr.Error(f"clip_idx={idx} ngoài phạm vi [0, {n_clips - 1}].")
            x = X[idx]  # [N, T, F?]
        else:
            # 3-D is one clip; lower ranks are rejected by _normalize_x_shape.
            x = X[()]

        ei_key = next((k for k in ("edge_index", "edge_idx", "edges") if k in f), None)
        if ei_key is None:
            raise gr.Error("Không có 'edge_index' trong H5.")
        edge_index = f[ei_key][()]

        ew_key = next((k for k in ("edge_weight", "edge_w", "weights") if k in f), None)
        edge_weight = f[ew_key][()] if ew_key is not None else None
    return _normalize_x_shape(np.asarray(x)), edge_index, edge_weight


def _cluster_spans_from_top(top_idx: np.ndarray, T: int, span_half: int, merge_gap: int):
    """Turn top-k frame indices into merged inclusive spans [l, r].

    Each frame t becomes [t - span_half, t + span_half] clipped to [0, T-1];
    spans are merged when the next one starts within merge_gap of the
    current one's end.
    """
    if top_idx is None or len(top_idx) == 0:
        return []
    span_half = max(0, int(span_half))
    merge_gap = max(0, int(merge_gap))

    spans = sorted(
        ([max(0, int(t) - span_half), min(T - 1, int(t) + span_half)] for t in top_idx),
        key=lambda s: s[0],
    )
    merged = [spans[0][:]]
    for l, r in spans[1:]:
        if l <= merged[-1][1] + merge_gap:
            merged[-1][1] = max(merged[-1][1], r)
        else:
            merged.append([l, r])
    return merged


def plot_frame_probs(frame_probs: np.ndarray, top_idx: np.ndarray | None = None,
                     spans: list[list[int]] | None = None) -> np.ndarray:
    """Render per-frame probabilities as an RGB uint8 image (H x W x 3).

    The image format suits gr.Image(type="numpy"). Shaded regions mark the
    inclusive frame ranges in ``spans``; dots mark the ``top_idx`` frames.
    """
    p = np.asarray(frame_probs, dtype=float).reshape(-1)
    fig = plt.figure(figsize=(7, 2.8))
    try:
        ax = fig.add_subplot(111)
        # Shade spans first so the probability curve is drawn on top.
        # ±0.5 keeps an inclusive [l, r] span visible even when l == r.
        for l, r in spans or []:
            ax.axvspan(l - 0.5, r + 0.5, alpha=0.15)
        ax.plot(np.arange(p.size), p, lw=1.5)
        if top_idx is not None and len(top_idx) > 0:
            idx = np.asarray(top_idx, dtype=int)
            ax.scatter(idx, p[idx], s=24, marker="o")
        ax.set_title("Frame-wise seizure probability")
        ax.set_xlabel("Frame index (t)")
        ax.set_ylabel("p(seizure)")
        ax.set_ylim(0, 1)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.canvas.draw()
        # buffer_rgba() views the rendered (H, W, 4) pixels; copy RGB out before closing.
        img = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Always release the figure; pyplot keeps a global reference otherwise.
        plt.close(fig)
    return img


def summarize_probs(frame_probs: np.ndarray, top_k: int = 10, span_half: int = 12,
                    merge_gap: int = 5) -> tuple[str, np.ndarray, list[list[int]]]:
    """Summarize per-frame probabilities.

    Returns:
        (markdown stats, top-k frame indices sorted by p(t) descending,
        merged spans around those frames). Safe for very few frames
        (T = 0/1) and for top_k > T.
    """
    p = np.asarray(frame_probs, dtype=float).reshape(-1)
    M = int(p.size)
    if M == 0:
        md = "**Frame stats** \n- (no frames)\n"
        return md, np.array([], dtype=int), []

    # Clamp k to [1, M]; unparsable values fall back to 10.
    try:
        k = int(top_k)
    except (TypeError, ValueError, OverflowError):
        k = 10
    k = max(1, min(k, M))

    if M == 1:
        top_idx = np.array([0], dtype=int)
    else:
        # argpartition to get the k largest, then a stable sort by p descending.
        top_idx = np.argpartition(-p, k - 1)[:k]
        top_idx = top_idx[np.argsort(-p[top_idx], kind="mergesort")]

    # Spans are built from the top-k sorted by frame index so neighbours merge.
    spans = _cluster_spans_from_top(top_idx=np.sort(top_idx), T=M,
                                    span_half=int(max(0, span_half)),
                                    merge_gap=int(max(0, merge_gap)))

    mean = float(np.mean(p)); std = float(np.std(p))
    pmin = float(np.min(p)); imin = int(np.argmin(p))
    pmax = float(np.max(p)); imax = int(np.argmax(p))

    rows = "\n".join([f"| {i} | {p[i]:.4f} |" for i in top_idx])
    span_rows = "\n".join([f"- [{l}, {r}] (len={r-l+1})" for l, r in spans]) if spans else "- (none)"

    md = (
        f"**Frame stats** \n"
        f"- frames: **{M}** \n"
        f"- mean: **{mean:.4f}** · std: **{std:.4f}** \n"
        f"- max: **{pmax:.4f}** tại **t={imax}** · min: **{pmin:.4f}** tại **t={imin}** \n"
        f"- top-{k} frames (of {M}): \n\n"
        f"| frame t | p(t) |\n|---:|---:|\n{rows if rows else '| - | - |'}\n\n"
        f"**Merged spans** từ top-k (pad=±{int(span_half)}, merge_gap≤{int(merge_gap)}):\n{span_rows}\n"
    )
    return md, top_idx, spans


# ==========================================================
# 4) Demo files (from the repo)
# ==========================================================
DEMO_FILES = []
for i in range(3):
    try:
        DEMO_FILES.append(hub_download(f"demo_clip{i}.npz"))
    except Exception as e:  # best effort: a missing demo clip must not stop the Space
        print(f"Demo clip {i} unavailable: {e}")

print("Model in_dim:", MODEL.in_dim, "state_q:", MODEL.state_q)


# ==========================================================
# 5) Inference handlers
# ==========================================================
def _to_model_inputs(x_ntf: np.ndarray, edge_index: np.ndarray, edge_weight: np.ndarray | None):
    """Validate the graph arrays and convert everything to tensors on ``device``.

    edge_index may be [2, E] or [E, 2]; edge_weight may have any shape with
    exactly E elements (it is flattened).

    Raises:
        gr.Error: empty clip, malformed edge_index, node ids outside
            [0, N-1], or an edge_weight whose size does not match E.
    """
    N, T, _ = x_ntf.shape
    if N == 0 or T == 0:
        raise gr.Error(f"Clip rỗng: N={N}, T={T}.")
    ei = np.asarray(edge_index)
    if ei.ndim != 2 or 2 not in ei.shape:
        raise gr.Error(f"edge_index phải có shape [2, E] hoặc [E, 2], nhận {ei.shape}.")
    if ei.shape[0] != 2:
        ei = ei.T
    ei = ei.astype(np.int64)
    if ei.size and (ei.min() < 0 or ei.max() >= N):
        raise gr.Error(f"edge_index chứa node id ngoài [0, {N - 1}].")
    ew = None
    if edge_weight is not None:
        ew = np.asarray(edge_weight, dtype=np.float32).reshape(-1)
        if ew.size != ei.shape[1]:
            raise gr.Error(f"edge_weight có {ew.size} phần tử nhưng có {ei.shape[1]} cạnh.")
    x_t = torch.tensor(x_ntf, dtype=torch.float32, device=device)
    ei_t = torch.tensor(ei, dtype=torch.long, device=device)
    ew_t = torch.tensor(ew, dtype=torch.float32, device=device) if ew is not None else None
    return x_t, ei_t, ew_t


def _run_inference(x, ei, ew, top_k, span_half, merge_gap):
    """Shared pipeline for both tabs: adapt features, run MODEL, summarize, plot.

    Returns:
        (clip probability formatted with 4 decimals, stats markdown, RGB image).
    """
    x = _adapt_features_to_model(x, MODEL.in_dim)
    x_t, ei_t, ew_t = _to_model_inputs(x, ei, ew)
    clip_p, frame_p = MODEL(x_t, ei_t, ew_t)
    p_np = frame_p.cpu().numpy()
    stats_md, top_idx, spans = summarize_probs(p_np, top_k=int(top_k),
                                               span_half=int(span_half),
                                               merge_gap=int(merge_gap))
    img = plot_frame_probs(p_np, top_idx=top_idx, spans=spans)
    return f"{clip_p:.4f}", stats_md, img


def infer_demo(demo_id, top_k, span_half, merge_gap):
    """Gradio handler for the Demo tab: run the model on a bundled demo clip."""
    if not DEMO_FILES:
        raise gr.Error("Không có demo_clip*.npz trong repo. Hãy export ở bước A.")
    if demo_id is None or str(demo_id) == "":
        raise gr.Error("Hãy chọn một demo clip.")
    idx = int(demo_id)
    if not 0 <= idx < len(DEMO_FILES):
        raise gr.Error(f"demo_id={idx} ngoài phạm vi [0, {len(DEMO_FILES) - 1}].")
    # Demo files come from the model repo (trusted), so pickled arrays stay allowed as before.
    x, ei, ew = load_npz(DEMO_FILES[idx], allow_pickle=True)
    return _run_inference(x, ei, ew, top_k, span_half, merge_gap)


def infer_custom(file, file_type, clip_idx, top_k, span_half, merge_gap):
    """Gradio handler for the Upload tab: run the model on a user NPZ/H5 file."""
    if file is None:
        raise gr.Error("Hãy upload 1 file H5/NPZ hoặc chọn demo.")
    if file_type == "npz":
        x, ei, ew = load_npz(file)
    else:
        x, ei, ew = load_h5(file, clip_idx=clip_idx)
    return _run_inference(x, ei, ew, top_k, span_half, merge_gap)


# ==========================================================
# 6) Gradio UI
# ==========================================================
with gr.Blocks(title="REST EEG Seizure Demo") as demo:
    gr.Markdown("# REST EEG Seizure – CHB-MIT\nDemo chạy trên CPU (Space). Dữ liệu lớn H5 không tự tải để tiết kiệm tài nguyên.")

    with gr.Tab("Demo"):
        dsel = gr.Dropdown(choices=[str(i) for i in range(len(DEMO_FILES))],
                           label="Chọn demo clip",
                           value="0" if DEMO_FILES else None)
        dtopk = gr.Slider(1, 50, value=10, step=1, label="Top-k frames to highlight")
        dspan = gr.Slider(0, 200, value=12, step=1, label="Span half-width (±frames)")
        dgap = gr.Slider(0, 50, value=5, step=1, label="Merge spans if gap ≤")
        dbtn = gr.Button("Run demo")
        dout = gr.Textbox(label="Clip probability")
        dstats = gr.Markdown(label="Frame stats")
        dfig = gr.Image(label="Frame-wise probability", type="numpy")
        dbtn.click(fn=infer_demo, inputs=[dsel, dtopk, dspan, dgap], outputs=[dout, dstats, dfig])

    with gr.Tab("Upload"):
        ftype = gr.Radio(choices=["npz", "h5"], value="npz", label="Loại file")
        fup = gr.File(label="Upload .npz (x, edge_index, edge_weight) hoặc .h5")
        cidx = gr.Slider(0, 50, value=0, step=1, label="clip_idx (nếu H5)")
        utopk = gr.Slider(1, 50, value=10, step=1, label="Top-k frames to highlight")
        uspan = gr.Slider(0, 200, value=12, step=1, label="Span half-width (±frames)")
        ugap = gr.Slider(0, 50, value=5, step=1, label="Merge spans if gap ≤")
        ubtn = gr.Button("Run inference")
        uout = gr.Textbox(label="Clip probability")
        ustats = gr.Markdown(label="Frame stats")
        ufig = gr.Image(label="Frame-wise probability", type="numpy")
        ubtn.click(fn=infer_custom, inputs=[fup, ftype, cidx, utopk, uspan, ugap],
                   outputs=[uout, ustats, ufig])

demo.launch()