oriqqqqqqat committed
Commit 7eadd35 · Parent: b57ca85

modifymain

Files changed (1)
  1. main.py +66 -153
main.py CHANGED
@@ -21,14 +21,13 @@ from fastapi import FastAPI, File, UploadFile, Form, Request, Depends
 from fastapi.responses import HTMLResponse, RedirectResponse
 from fastapi.templating import Jinja2Templates
 from fastapi.staticfiles import StaticFiles
-import requests
-import uvicorn
 
 sys.path.append(os.path.abspath(os.path.dirname(__file__)))
 from models.densenet.preprocess.preprocessingwangchan import get_tokenizer, get_transforms
 from models.densenet.train_densenet_only import DenseNet121Classifier
 from models.densenet.train_text_only import TextClassifier
-
+import requests
+import uvicorn
 
 HF_MODEL_URL = "https://huggingface.co/qqqqqqat/densenet_wangchan/resolve/main/best_fusion_densenet.pth"
 LOCAL_MODEL_PATH = "models/densenet/best_fusion_densenet.pth"
@@ -50,13 +49,10 @@ FUSION_LABELMAP_PATH = "models/densenet/label_map_fusion_densenet.json"
 FUSION_WEIGHTS_PATH = "models/densenet/best_fusion_densenet.pth"
 with open(FUSION_LABELMAP_PATH, "r", encoding="utf-8") as f:
     label_map = json.load(f)
-
 class_names = [label for label, _ in sorted(label_map.items(), key=lambda x: x[1])]
 NUM_CLASSES = len(class_names)
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 print(f"🧠 Using device: {device}")
-
-
 class FusionDenseNetText(nn.Module):
     def __init__(self, num_classes, dropout=0.3):
         super().__init__()
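A note on the label-map handling in the context above: label_map_fusion_densenet.json is expected to map class names to integer indices, and class_names recovers the index-ordered list so that class_names[pred_idx] can turn an argmax back into a name. A minimal sketch with hypothetical labels (the real JSON ships with the model weights and is not part of this diff):

    # Hypothetical label map, for illustration only.
    label_map = {"normal": 0, "lesion": 1, "other": 2}
    class_names = [label for label, _ in sorted(label_map.items(), key=lambda x: x[1])]
    print(class_names)  # ['normal', 'lesion', 'other']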
@@ -72,11 +68,9 @@ class FusionDenseNetText(nn.Module):
         fused_in = torch.cat([logits_img, logits_txt], dim=1)
         fused_out = self.fusion(fused_in)
         return fused_out, logits_img, logits_txt
-
-
 print("🔄 Loading AI model...")
 fusion_model = FusionDenseNetText(num_classes=NUM_CLASSES).to(device)
-
+# download the weights first
 download_model_if_needed()
 
 fusion_model = FusionDenseNetText(num_classes=NUM_CLASSES).to(device)
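download_model_if_needed() is called in the context above, but its body falls outside the hunks shown here. A hedged sketch of what such a helper typically looks like, assuming it only fetches HF_MODEL_URL into LOCAL_MODEL_PATH when the checkpoint is missing; the actual implementation in main.py may differ:

    import os
    import requests

    def download_model_if_needed():
        # Skip the download when the checkpoint is already on disk.
        if os.path.exists(LOCAL_MODEL_PATH):
            return
        os.makedirs(os.path.dirname(LOCAL_MODEL_PATH), exist_ok=True)
        print("⬇️ Downloading model weights...")
        resp = requests.get(HF_MODEL_URL, stream=True, timeout=60)
        resp.raise_for_status()
        with open(LOCAL_MODEL_PATH, "wb") as f:
            for chunk in resp.iter_content(chunk_size=1 << 20):  # 1 MiB chunks
                f.write(chunk)
        print("✅ Model weights downloaded.")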
@@ -84,55 +78,38 @@ fusion_model.load_state_dict(torch.load(LOCAL_MODEL_PATH, map_location=device))
 fusion_model.eval()
 print("✅ AI Model loaded successfully!")
 fusion_model.eval()
-
 tokenizer = get_tokenizer()
 transform = get_transforms((224, 224))
-
-
 def _find_last_conv2d(mod: torch.nn.Module):
     last = None
     for m in mod.modules():
         if isinstance(m, torch.nn.Conv2d): last = m
     return last
-
-
 def compute_gradcam_overlay(img_pil, image_tensor, target_class_idx):
     img_branch = fusion_model.image_model
     target_layer = _find_last_conv2d(img_branch)
     if target_layer is None: return None
-
     activations, gradients = [], []
-
     def fwd_hook(_m, _i, o): activations.append(o)
     def bwd_hook(_m, gin, gout): gradients.append(gout[0])
-
     h1 = target_layer.register_forward_hook(fwd_hook)
     h2 = target_layer.register_full_backward_hook(bwd_hook)
-
     try:
         img_branch.zero_grad()
         logits_img = img_branch(image_tensor)
         score = logits_img[0, target_class_idx]
         score.backward()
-
         act = activations[-1].detach()[0]
         grad = gradients[-1].detach()[0]
-
         weights = torch.mean(grad, dim=(1, 2))
         cam = torch.relu(torch.sum(weights[:, None, None] * act, dim=0))
-
         cam -= cam.min(); cam /= (cam.max() + 1e-8)
-
         cam_img = Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8)).resize(img_pil.size, Image.BILINEAR)
-
         cam_np = np.asarray(cam_img).astype(np.float32) / 255.0
         heatmap = cm.get_cmap("jet")(cam_np)[:, :, :3]
-
         img_np = np.asarray(img_pil.convert("RGB")).astype(np.float32) / 255.0
         overlay = (0.6 * img_np + 0.4 * heatmap)
-
         return np.clip(overlay * 255, 0, 255).astype(np.uint8)
-
     finally:
         h1.remove(); h2.remove(); img_branch.zero_grad()
 
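A usage sketch for compute_gradcam_overlay as kept above, assuming a hypothetical local test image and the transform already initialised in this module; the function backpropagates only through the image branch (fusion_model.image_model), so no text input is needed:

    from PIL import Image

    img_pil = Image.open("uploads/example.jpg").convert("RGB")  # hypothetical test image
    img_tensor = transform(img_pil).unsqueeze(0).to(device)

    overlay = compute_gradcam_overlay(img_pil, img_tensor, target_class_idx=0)
    if overlay is not None:
        Image.fromarray(overlay).save("gradcam_overlay.png")  # jet heatmap blended 40/60 over the input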
 
@@ -142,191 +119,129 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
 templates = Jinja2Templates(directory="templates")
 os.makedirs("uploads", exist_ok=True)
 
-
-# ===== Cache Cleanup Thread =====
-EXPIRATION_MINUTES = 10
-results_cache = {}
-cache_lock = threading.Lock()
-
+EXPIRATION_MINUTES = 10
+results_cache = {}
+cache_lock = threading.Lock()
 
 def cleanup_expired_cache():
+    """
+    Runs in a background thread to find and remove expired cache entries.
+    """
     while True:
-        with cache_lock:
+        with cache_lock:  # lock for safe access to the cache
+            # build a list of keys to delete so the dict is not modified while iterating
             expired_keys = []
             current_time = time.time()
             for key, value in results_cache.items():
                 if current_time - value["created_at"] > EXPIRATION_MINUTES * 60:
                     expired_keys.append(key)
-
+
+            # delete the expired keys
            for key in expired_keys:
                del results_cache[key]
                print(f"🧹 Cache expired and removed for key: {key}")
-
-        time.sleep(60)
-
+
+        time.sleep(60)  # check every 60 seconds
 
 @app.on_event("startup")
 async def startup_event():
+    """
+    Start the background cache-cleanup thread when the app starts.
+    """
     cleanup_thread = threading.Thread(target=cleanup_expired_cache, daemon=True)
     cleanup_thread.start()
     print("🗑️ Cache cleanup task started.")
 
-
-# ===============================
-# >>> TURNSTILE VERIFY <<<
-# ===============================
-
-TURNSTILE_SECRET = "0x4AAAAAACEfyIeAjGlYCXeasGsMxTuTlHU"  ### (added)
+SYMPTOM_MAP = {
+    "noSymptoms": "ไม่มีอาการ", "drinkAlcohol": "ดื่มเหล้า", "smoking": "สูบบุหรี่",
+    "chewBetelNut": "เคี้ยวหมาก", "eatSpicyFood": "กินเผ็ดแสบ", "wipeOff": "เช็ดออกได้",
+    "alwaysHurts": "เจ็บเมื่อโดนแผล"
+}
 def process_with_ai_model(image_path: str, prompt_text: str):
-    """Process the image with the DenseNet + text model and build the result."""
-
-    # --- Load image ---
-    img_pil = Image.open(image_path).convert("RGB")
-    img_tensor = transform(img_pil).unsqueeze(0).to(device)
-
-    # --- Tokenize text prompt ---
-    encoding = tokenizer(
-        prompt_text,
-        padding="max_length",
-        truncation=True,
-        max_length=128,
-        return_tensors="pt"
-    )
-    input_ids = encoding["input_ids"].to(device)
-    attention_mask = encoding["attention_mask"].to(device)
-
-    # --- Forward pass ---
-    with torch.no_grad():
-        fused_out, logits_img, logits_txt = fusion_model(
-            img_tensor, input_ids, attention_mask
-        )
-        probs = torch.softmax(fused_out, dim=1)[0]
-        pred_idx = torch.argmax(probs).item()
-
-    predicted_label = class_names[pred_idx]
-    confidence = probs[pred_idx].item()
-
-    name_out = f"{predicted_label} ({confidence:.2f})"
-
-    # --- Compute GradCAM ---
-    overlay_np = compute_gradcam_overlay(img_pil, img_tensor, pred_idx)
-    if overlay_np is not None:
-        overlay_img = Image.fromarray(overlay_np)
-        buffered = BytesIO()
-        overlay_img.save(buffered, format="PNG")
-        gradcam_b64 = base64.b64encode(buffered.getvalue()).decode()
-    else:
-        gradcam_b64 = None
-
-    # --- Encode original image ---
-    buffer2 = BytesIO()
-    img_pil.save(buffer2, format="PNG")
-    image_b64 = base64.b64encode(buffer2.getvalue()).decode()
-
-    return image_b64, gradcam_b64, predicted_label, confidence
-
+    try:
+        image_pil = Image.open(image_path)
+        image_pil = ImageOps.exif_transpose(image_pil)
+        image_pil = image_pil.convert("RGB")
+        image_tensor = transform(image_pil).unsqueeze(0).to(device)
+        enc = tokenizer(prompt_text, return_tensors="pt", padding="max_length",
+                        truncation=True, max_length=128)
+        ids, mask = enc["input_ids"].to(device), enc["attention_mask"].to(device)
+        with torch.no_grad():
+            fused_logits, _, _ = fusion_model(image_tensor, ids, mask)
+            probs_fused = torch.softmax(fused_logits, dim=1)[0].cpu().numpy()
+        pred_idx = int(np.argmax(probs_fused))
+        pred_label = class_names[pred_idx]
+        confidence = float(probs_fused[pred_idx]) * 100
+        gradcam_overlay_np = compute_gradcam_overlay(image_pil, image_tensor, pred_idx)
+        def image_to_base64(img):
+            buffered = BytesIO()
+            img.save(buffered, format="JPEG")
+            return base64.b64encode(buffered.getvalue()).decode('utf-8')
+        original_b64 = image_to_base64(image_pil)
+        if gradcam_overlay_np is not None:
+            gradcam_pil = Image.fromarray(gradcam_overlay_np)
+            gradcam_b64 = image_to_base64(gradcam_pil)
+        else:
+            gradcam_b64 = original_b64
+        return original_b64, gradcam_b64, pred_label, f"{confidence:.2f}"
+    except Exception as e:
+        print(f"❌ Error during AI processing: {e}")
+        return None, None, "Error", "0.00"
+
+@app.get("/", response_class=RedirectResponse)
+async def root():
+    return RedirectResponse(url="/detect")
+@app.get("/detect", response_class=HTMLResponse)
+async def show_upload_form(request: Request):
+    return templates.TemplateResponse("detect.html", {"request": request})
 
 @app.post("/uploaded")
 async def handle_upload(
     request: Request,
     file: UploadFile = File(...),
     checkboxes: List[str] = Form([]),
-    symptom_text: str = Form(""),
-    cf_turnstile_response: str = Form("")  ### (added)
+    symptom_text: str = Form("")
 ):
-
-    # ---------- (added) VERIFY TURNSTILE ---------
-
-    if not cf_turnstile_response:
-        return templates.TemplateResponse(
-            "detect.html",
-            {"request": request, "error": "Turnstile token missing"}
-        )
-
-    verify_resp = requests.post(
-        "https://challenges.cloudflare.com/turnstile/v0/siteverify",
-        data={
-            "secret": TURNSTILE_SECRET,
-            "response": cf_turnstile_response
-        }
-    )
-
-    cf_result = verify_resp.json()
-
-    if not cf_result.get("success"):
-        return templates.TemplateResponse(
-            "detect.html",
-            {"request": request, "error": "การยืนยันความปลอดภัยไม่สำเร็จ กรุณาลองใหม่อีกครั้ง"}
-        )
-
-    # ------------------------------------------------
-
-
-    # >>> all of the original code below (unchanged) <<<
-
     temp_filepath = os.path.join("uploads", f"{uuid.uuid4()}_{file.filename}")
     with open(temp_filepath, "wb") as buffer:
         shutil.copyfileobj(file.file, buffer)
-
-    SYMPTOM_MAP = {
-        "noSymptoms": "ไม่มีอาการ", "drinkAlcohol": "ดื่มเหล้า", "smoking": "สูบบุหรี่",
-        "chewBetelNut": "เคี้ยวหมาก", "eatSpicyFood": "กินเผ็ดแสบ", "wipeOff": "เช็ดออกได้",
-        "alwaysHurts": "เจ็บเมื่อโดนแผล"
-    }
-
     final_prompt_parts = []
     selected_symptoms_thai = {SYMPTOM_MAP.get(cb) for cb in checkboxes if SYMPTOM_MAP.get(cb)}
-
     if "ไม่มีอาการ" in selected_symptoms_thai:
         symptoms_group = {"เจ็บเมื่อโดนแผล", "กินเผ็ดแสบ"}
         lifestyles_group = {"ดื่มเหล้า", "สูบบุหรี่", "เคี้ยวหมาก"}
         patterns_group = {"เช็ดออกได้"}
         special_group = {"ไม่มีอาการ"}
-
         final_selected = (selected_symptoms_thai - symptoms_group) | \
                          (selected_symptoms_thai & (lifestyles_group | patterns_group | special_group))
-
         final_prompt_parts.append(" ".join(sorted(list(final_selected))))
-
     elif selected_symptoms_thai:
         final_prompt_parts.append(" ".join(sorted(list(selected_symptoms_thai))))
-
     if symptom_text and symptom_text.strip():
         final_prompt_parts.append(symptom_text.strip())
-
     final_prompt = "; ".join(final_prompt_parts) if final_prompt_parts else "ไม่มีอาการ"
-
     image_b64, gradcam_b64, name_out, eva_output = process_with_ai_model(
         image_path=temp_filepath, prompt_text=final_prompt
     )
-
     os.remove(temp_filepath)
-
     result_id = str(uuid.uuid4())
-
     result_data = {
-        "image_b64_data": image_b64,
-        "gradcam_b64_data": gradcam_b64,
-        "name_out": name_out,
-        "eva_output": eva_output,
+        "image_b64_data": image_b64, "gradcam_b64_data": gradcam_b64,
+        "name_out": name_out, "eva_output": eva_output,
     }
-
     with cache_lock:
         results_cache[result_id] = {
             "data": result_data,
-            "created_at": time.time()
+            "created_at": time.time()
         }
 
-    results_url = request.url_for("show_results", result_id=result_id)
+    results_url = request.url_for('show_results', result_id=result_id)
    return RedirectResponse(url=results_url, status_code=303)
 
-
-
 @app.get("/results/{result_id}", response_class=HTMLResponse)
 async def show_results(request: Request, result_id: str):
     with cache_lock:
         cached_item = results_cache.get(result_id)
-
     if not cached_item or (time.time() - cached_item["created_at"] > EXPIRATION_MINUTES * 60):
         if cached_item:
             with cache_lock:
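To make the checkbox-to-prompt logic in handle_upload above easier to follow, here is the same set arithmetic run on a hypothetical form submission (the English keys are the HTML checkbox values, the Thai strings are the vocabulary the text model expects):

    SYMPTOM_MAP = {
        "noSymptoms": "ไม่มีอาการ", "drinkAlcohol": "ดื่มเหล้า", "smoking": "สูบบุหรี่",
        "chewBetelNut": "เคี้ยวหมาก", "eatSpicyFood": "กินเผ็ดแสบ", "wipeOff": "เช็ดออกได้",
        "alwaysHurts": "เจ็บเมื่อโดนแผล"
    }

    # Hypothetical submission: "no symptoms" ticked together with two lifestyle boxes
    # and a pain box that contradicts it.
    checkboxes = ["noSymptoms", "smoking", "drinkAlcohol", "alwaysHurts"]
    selected = {SYMPTOM_MAP[cb] for cb in checkboxes if cb in SYMPTOM_MAP}

    symptoms_group = {"เจ็บเมื่อโดนแผล", "กินเผ็ดแสบ"}
    lifestyles_group = {"ดื่มเหล้า", "สูบบุหรี่", "เคี้ยวหมาก"}
    patterns_group = {"เช็ดออกได้"}
    special_group = {"ไม่มีอาการ"}

    # With "ไม่มีอาการ" present, the contradictory pain/irritation symptoms are dropped while
    # lifestyle habits, the wipe-off pattern, and the "no symptoms" flag itself are kept.
    final_selected = (selected - symptoms_group) | \
                     (selected & (lifestyles_group | patterns_group | special_group))
    print(" ".join(sorted(final_selected)))  # ดื่มเหล้า สูบบุหรี่ ไม่มีอาการ (code-point order)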
@@ -336,8 +251,6 @@ async def show_results(request: Request, result_id: str):
     context = {"request": request, **cached_item["data"]}
     return templates.TemplateResponse("detect.html", context)
 
-
-
 if __name__ == "__main__":
-    port = int(os.environ.get("PORT", 8000))
-    uvicorn.run(app, host="0.0.0.0", port=port)
+    port = int(os.environ.get("PORT", 8000))  # use the hosting provider's PORT if it is set
+    uvicorn.run(app, host="0.0.0.0", port=port)
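End to end, the rewritten upload flow can be exercised with a plain HTTP client. A hedged sketch, assuming the app runs locally on port 8000 and a test image exists at the hypothetical path below; the field names match the File/Form parameters of handle_upload, and the client follows the 303 redirect to /results/{result_id}:

    import requests  # client-side use here, unrelated to the server's own import

    with open("test_image.jpg", "rb") as fh:  # hypothetical local test image
        resp = requests.post(
            "http://localhost:8000/uploaded",
            files={"file": ("test_image.jpg", fh, "image/jpeg")},
            data={"checkboxes": ["smoking", "alwaysHurts"], "symptom_text": ""},
            allow_redirects=True,
        )
    print(resp.status_code)  # 200 once the results page renders
    print(resp.url)          # e.g. http://localhost:8000/results/<uuid>, cached for EXPIRATION_MINUTES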
 