oriqqqqqqat commited on
Commit
b57ca85
·
1 Parent(s): a5e2a63

modifymain

Browse files
Files changed (1) hide show
  1. main.py +153 -65
main.py CHANGED
@@ -21,12 +21,14 @@ from fastapi import FastAPI, File, UploadFile, Form, Request, Depends
21
  from fastapi.responses import HTMLResponse, RedirectResponse
22
  from fastapi.templating import Jinja2Templates
23
  from fastapi.staticfiles import StaticFiles
 
 
24
 
25
  sys.path.append(os.path.abspath(os.path.dirname(__file__)))
26
  from models.densenet.preprocess.preprocessingwangchan import get_tokenizer, get_transforms
27
  from models.densenet.train_densenet_only import DenseNet121Classifier
28
  from models.densenet.train_text_only import TextClassifier
29
- import requests
30
 
31
  HF_MODEL_URL = "https://huggingface.co/qqqqqqat/densenet_wangchan/resolve/main/best_fusion_densenet.pth"
32
  LOCAL_MODEL_PATH = "models/densenet/best_fusion_densenet.pth"
@@ -48,10 +50,13 @@ FUSION_LABELMAP_PATH = "models/densenet/label_map_fusion_densenet.json"
48
  FUSION_WEIGHTS_PATH = "models/densenet/best_fusion_densenet.pth"
49
  with open(FUSION_LABELMAP_PATH, "r", encoding="utf-8") as f:
50
  label_map = json.load(f)
 
51
  class_names = [label for label, _ in sorted(label_map.items(), key=lambda x: x[1])]
52
  NUM_CLASSES = len(class_names)
53
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
  print(f"🧠 Using device: {device}")
 
 
55
  class FusionDenseNetText(nn.Module):
56
  def __init__(self, num_classes, dropout=0.3):
57
  super().__init__()
@@ -67,9 +72,11 @@ class FusionDenseNetText(nn.Module):
67
  fused_in = torch.cat([logits_img, logits_txt], dim=1)
68
  fused_out = self.fusion(fused_in)
69
  return fused_out, logits_img, logits_txt
 
 
70
  print("🔄 Loading AI model...")
71
  fusion_model = FusionDenseNetText(num_classes=NUM_CLASSES).to(device)
72
- # ดาวน์โหลดก่อน
73
  download_model_if_needed()
74
 
75
  fusion_model = FusionDenseNetText(num_classes=NUM_CLASSES).to(device)
@@ -77,38 +84,55 @@ fusion_model.load_state_dict(torch.load(LOCAL_MODEL_PATH, map_location=device))
77
  fusion_model.eval()
78
  print("✅ AI Model loaded successfully!")
79
  fusion_model.eval()
 
80
  tokenizer = get_tokenizer()
81
  transform = get_transforms((224, 224))
 
 
82
  def _find_last_conv2d(mod: torch.nn.Module):
83
  last = None
84
  for m in mod.modules():
85
  if isinstance(m, torch.nn.Conv2d): last = m
86
  return last
 
 
87
  def compute_gradcam_overlay(img_pil, image_tensor, target_class_idx):
88
  img_branch = fusion_model.image_model
89
  target_layer = _find_last_conv2d(img_branch)
90
  if target_layer is None: return None
 
91
  activations, gradients = [], []
 
92
  def fwd_hook(_m, _i, o): activations.append(o)
93
  def bwd_hook(_m, gin, gout): gradients.append(gout[0])
 
94
  h1 = target_layer.register_forward_hook(fwd_hook)
95
  h2 = target_layer.register_full_backward_hook(bwd_hook)
 
96
  try:
97
  img_branch.zero_grad()
98
  logits_img = img_branch(image_tensor)
99
  score = logits_img[0, target_class_idx]
100
  score.backward()
 
101
  act = activations[-1].detach()[0]
102
  grad = gradients[-1].detach()[0]
 
103
  weights = torch.mean(grad, dim=(1, 2))
104
  cam = torch.relu(torch.sum(weights[:, None, None] * act, dim=0))
 
105
  cam -= cam.min(); cam /= (cam.max() + 1e-8)
 
106
  cam_img = Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8)).resize(img_pil.size, Image.BILINEAR)
 
107
  cam_np = np.asarray(cam_img).astype(np.float32) / 255.0
108
  heatmap = cm.get_cmap("jet")(cam_np)[:, :, :3]
 
109
  img_np = np.asarray(img_pil.convert("RGB")).astype(np.float32) / 255.0
110
  overlay = (0.6 * img_np + 0.4 * heatmap)
 
111
  return np.clip(overlay * 255, 0, 255).astype(np.uint8)
 
112
  finally:
113
  h1.remove(); h2.remove(); img_branch.zero_grad()
114
 
@@ -118,129 +142,191 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
118
  templates = Jinja2Templates(directory="templates")
119
  os.makedirs("uploads", exist_ok=True)
120
 
121
- EXPIRATION_MINUTES = 10
122
- results_cache = {}
123
- cache_lock = threading.Lock()
 
 
 
124
 
125
  def cleanup_expired_cache():
126
- """
127
- ฟังก์ชันนี้จะทำงานใน Background Thread เพื่อตรวจสอบและลบ Cache ที่หมดอายุ
128
- """
129
  while True:
130
- with cache_lock: # ล็อคเพื่อความปลอดภัยในการเข้าถึง cache
131
- # สร้าง list ของ key ที่จะลบ เพื่อไม่ให้แก้ไข dict ขณะวน loop
132
  expired_keys = []
133
  current_time = time.time()
134
  for key, value in results_cache.items():
135
  if current_time - value["created_at"] > EXPIRATION_MINUTES * 60:
136
  expired_keys.append(key)
137
-
138
- # ลบ key ที่หมดอายุ
139
  for key in expired_keys:
140
  del results_cache[key]
141
  print(f"🧹 Cache expired and removed for key: {key}")
142
-
143
- time.sleep(60) # ตรวจสอบทุกๆ 60 วินาที
 
144
 
145
  @app.on_event("startup")
146
  async def startup_event():
147
- """
148
- เริ่ม Background Thread สำหรับทำความสะอาด Cache เมื่อแอปเริ่มทำงาน
149
- """
150
  cleanup_thread = threading.Thread(target=cleanup_expired_cache, daemon=True)
151
  cleanup_thread.start()
152
  print("🗑️ Cache cleanup task started.")
153
 
154
- SYMPTOM_MAP = {
155
- "noSymptoms": "ไม่มีอาการ", "drinkAlcohol": "ดื่มเหล้า", "smoking": "สูบบุหรี่",
156
- "chewBetelNut": "เคี้ยวหมาก", "eatSpicyFood": "กินเผ็ดแสบ", "wipeOff": "เช็ดออกได้",
157
- "alwaysHurts": "เจ็บเมื่อโดนแผล"
158
- }
 
159
  def process_with_ai_model(image_path: str, prompt_text: str):
160
- try:
161
- image_pil = Image.open(image_path)
162
- image_pil = ImageOps.exif_transpose(image_pil)
163
- image_pil = image_pil.convert("RGB")
164
- image_tensor = transform(image_pil).unsqueeze(0).to(device)
165
- enc = tokenizer(prompt_text, return_tensors="pt", padding="max_length",
166
- truncation=True, max_length=128)
167
- ids, mask = enc["input_ids"].to(device), enc["attention_mask"].to(device)
168
- with torch.no_grad():
169
- fused_logits, _, _ = fusion_model(image_tensor, ids, mask)
170
- probs_fused = torch.softmax(fused_logits, dim=1)[0].cpu().numpy()
171
- pred_idx = int(np.argmax(probs_fused))
172
- pred_label = class_names[pred_idx]
173
- confidence = float(probs_fused[pred_idx]) * 100
174
- gradcam_overlay_np = compute_gradcam_overlay(image_pil, image_tensor, pred_idx)
175
- def image_to_base64(img):
176
- buffered = BytesIO()
177
- img.save(buffered, format="JPEG")
178
- return base64.b64encode(buffered.getvalue()).decode('utf-8')
179
- original_b64 = image_to_base64(image_pil)
180
- if gradcam_overlay_np is not None:
181
- gradcam_pil = Image.fromarray(gradcam_overlay_np)
182
- gradcam_b64 = image_to_base64(gradcam_pil)
183
- else:
184
- gradcam_b64 = original_b64
185
- return original_b64, gradcam_b64, pred_label, f"{confidence:.2f}"
186
- except Exception as e:
187
- print(f"❌ Error during AI processing: {e}")
188
- return None, None, "Error", "0.00"
189
-
190
- @app.get("/", response_class=RedirectResponse)
191
- async def root():
192
- return RedirectResponse(url="/detect")
193
- @app.get("/detect", response_class=HTMLResponse)
194
- async def show_upload_form(request: Request):
195
- return templates.TemplateResponse("detect.html", {"request": request})
 
 
 
 
 
 
 
 
 
 
 
196
 
197
  @app.post("/uploaded")
198
  async def handle_upload(
199
  request: Request,
200
  file: UploadFile = File(...),
201
  checkboxes: List[str] = Form([]),
202
- symptom_text: str = Form("")
 
203
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  temp_filepath = os.path.join("uploads", f"{uuid.uuid4()}_{file.filename}")
205
  with open(temp_filepath, "wb") as buffer:
206
  shutil.copyfileobj(file.file, buffer)
 
 
 
 
 
 
 
207
  final_prompt_parts = []
208
  selected_symptoms_thai = {SYMPTOM_MAP.get(cb) for cb in checkboxes if SYMPTOM_MAP.get(cb)}
 
209
  if "ไม่มีอาการ" in selected_symptoms_thai:
210
  symptoms_group = {"เจ็บเมื่อโดนแผล", "กินเผ็ดแสบ"}
211
  lifestyles_group = {"ดื่มเหล้า", "สูบบุหรี่", "เคี้ยวหมาก"}
212
  patterns_group = {"เช็ดออกได้"}
213
  special_group = {"ไม่มีอาการ"}
 
214
  final_selected = (selected_symptoms_thai - symptoms_group) | \
215
  (selected_symptoms_thai & (lifestyles_group | patterns_group | special_group))
 
216
  final_prompt_parts.append(" ".join(sorted(list(final_selected))))
 
217
  elif selected_symptoms_thai:
218
  final_prompt_parts.append(" ".join(sorted(list(selected_symptoms_thai))))
 
219
  if symptom_text and symptom_text.strip():
220
  final_prompt_parts.append(symptom_text.strip())
 
221
  final_prompt = "; ".join(final_prompt_parts) if final_prompt_parts else "ไม่มีอาการ"
 
222
  image_b64, gradcam_b64, name_out, eva_output = process_with_ai_model(
223
  image_path=temp_filepath, prompt_text=final_prompt
224
  )
 
225
  os.remove(temp_filepath)
 
226
  result_id = str(uuid.uuid4())
 
227
  result_data = {
228
- "image_b64_data": image_b64, "gradcam_b64_data": gradcam_b64,
229
- "name_out": name_out, "eva_output": eva_output,
 
 
230
  }
 
231
  with cache_lock:
232
  results_cache[result_id] = {
233
  "data": result_data,
234
- "created_at": time.time()
235
  }
236
 
237
- results_url = request.url_for('show_results', result_id=result_id)
238
  return RedirectResponse(url=results_url, status_code=303)
239
 
 
 
240
  @app.get("/results/{result_id}", response_class=HTMLResponse)
241
  async def show_results(request: Request, result_id: str):
242
  with cache_lock:
243
  cached_item = results_cache.get(result_id)
 
244
  if not cached_item or (time.time() - cached_item["created_at"] > EXPIRATION_MINUTES * 60):
245
  if cached_item:
246
  with cache_lock:
@@ -250,6 +336,8 @@ async def show_results(request: Request, result_id: str):
250
  context = {"request": request, **cached_item["data"]}
251
  return templates.TemplateResponse("detect.html", context)
252
 
 
 
253
  if __name__ == "__main__":
254
- port = int(os.environ.get("PORT", 8000)) # ใช้ PORT ของ hosting ถ้ามี
255
- uvicorn.run(app, host="0.0.0.0", port=port)
 
21
  from fastapi.responses import HTMLResponse, RedirectResponse
22
  from fastapi.templating import Jinja2Templates
23
  from fastapi.staticfiles import StaticFiles
24
+ import requests
25
+ import uvicorn
26
 
27
  sys.path.append(os.path.abspath(os.path.dirname(__file__)))
28
  from models.densenet.preprocess.preprocessingwangchan import get_tokenizer, get_transforms
29
  from models.densenet.train_densenet_only import DenseNet121Classifier
30
  from models.densenet.train_text_only import TextClassifier
31
+
32
 
33
  HF_MODEL_URL = "https://huggingface.co/qqqqqqat/densenet_wangchan/resolve/main/best_fusion_densenet.pth"
34
  LOCAL_MODEL_PATH = "models/densenet/best_fusion_densenet.pth"
 
50
  FUSION_WEIGHTS_PATH = "models/densenet/best_fusion_densenet.pth"
51
  with open(FUSION_LABELMAP_PATH, "r", encoding="utf-8") as f:
52
  label_map = json.load(f)
53
+
54
  class_names = [label for label, _ in sorted(label_map.items(), key=lambda x: x[1])]
55
  NUM_CLASSES = len(class_names)
56
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
57
  print(f"🧠 Using device: {device}")
58
+
59
+
60
  class FusionDenseNetText(nn.Module):
61
  def __init__(self, num_classes, dropout=0.3):
62
  super().__init__()
 
72
  fused_in = torch.cat([logits_img, logits_txt], dim=1)
73
  fused_out = self.fusion(fused_in)
74
  return fused_out, logits_img, logits_txt
75
+
76
+
77
  print("🔄 Loading AI model...")
78
  fusion_model = FusionDenseNetText(num_classes=NUM_CLASSES).to(device)
79
+
80
  download_model_if_needed()
81
 
82
  fusion_model = FusionDenseNetText(num_classes=NUM_CLASSES).to(device)
 
84
  fusion_model.eval()
85
  print("✅ AI Model loaded successfully!")
86
  fusion_model.eval()
87
+
88
  tokenizer = get_tokenizer()
89
  transform = get_transforms((224, 224))
90
+
91
+
92
  def _find_last_conv2d(mod: torch.nn.Module):
93
  last = None
94
  for m in mod.modules():
95
  if isinstance(m, torch.nn.Conv2d): last = m
96
  return last
97
+
98
+
99
  def compute_gradcam_overlay(img_pil, image_tensor, target_class_idx):
100
  img_branch = fusion_model.image_model
101
  target_layer = _find_last_conv2d(img_branch)
102
  if target_layer is None: return None
103
+
104
  activations, gradients = [], []
105
+
106
  def fwd_hook(_m, _i, o): activations.append(o)
107
  def bwd_hook(_m, gin, gout): gradients.append(gout[0])
108
+
109
  h1 = target_layer.register_forward_hook(fwd_hook)
110
  h2 = target_layer.register_full_backward_hook(bwd_hook)
111
+
112
  try:
113
  img_branch.zero_grad()
114
  logits_img = img_branch(image_tensor)
115
  score = logits_img[0, target_class_idx]
116
  score.backward()
117
+
118
  act = activations[-1].detach()[0]
119
  grad = gradients[-1].detach()[0]
120
+
121
  weights = torch.mean(grad, dim=(1, 2))
122
  cam = torch.relu(torch.sum(weights[:, None, None] * act, dim=0))
123
+
124
  cam -= cam.min(); cam /= (cam.max() + 1e-8)
125
+
126
  cam_img = Image.fromarray((cam.cpu().numpy() * 255).astype(np.uint8)).resize(img_pil.size, Image.BILINEAR)
127
+
128
  cam_np = np.asarray(cam_img).astype(np.float32) / 255.0
129
  heatmap = cm.get_cmap("jet")(cam_np)[:, :, :3]
130
+
131
  img_np = np.asarray(img_pil.convert("RGB")).astype(np.float32) / 255.0
132
  overlay = (0.6 * img_np + 0.4 * heatmap)
133
+
134
  return np.clip(overlay * 255, 0, 255).astype(np.uint8)
135
+
136
  finally:
137
  h1.remove(); h2.remove(); img_branch.zero_grad()
138
 
 
142
  templates = Jinja2Templates(directory="templates")
143
  os.makedirs("uploads", exist_ok=True)
144
 
145
+
146
+ # ===== Cache Cleanup Thread =====
147
+ EXPIRATION_MINUTES = 10
148
+ results_cache = {}
149
+ cache_lock = threading.Lock()
150
+
151
 
152
  def cleanup_expired_cache():
 
 
 
153
  while True:
154
+ with cache_lock:
 
155
  expired_keys = []
156
  current_time = time.time()
157
  for key, value in results_cache.items():
158
  if current_time - value["created_at"] > EXPIRATION_MINUTES * 60:
159
  expired_keys.append(key)
160
+
 
161
  for key in expired_keys:
162
  del results_cache[key]
163
  print(f"🧹 Cache expired and removed for key: {key}")
164
+
165
+ time.sleep(60)
166
+
167
 
168
  @app.on_event("startup")
169
  async def startup_event():
 
 
 
170
  cleanup_thread = threading.Thread(target=cleanup_expired_cache, daemon=True)
171
  cleanup_thread.start()
172
  print("🗑️ Cache cleanup task started.")
173
 
174
+
175
+ # ===============================
176
+ # >>> TURNSTILE VERIFY <<<
177
+ # ===============================
178
+
179
+ TURNSTILE_SECRET = "0x4AAAAAACEfyIeAjGlYCXeasGsMxTuTlHU" ### (เพิ่ม)
180
  def process_with_ai_model(image_path: str, prompt_text: str):
181
+ """ประมวลผลภาพด้วย DenseNet + Text model และสร้างผลลัพธ์"""
182
+
183
+ # --- Load image ---
184
+ img_pil = Image.open(image_path).convert("RGB")
185
+ img_tensor = transform(img_pil).unsqueeze(0).to(device)
186
+
187
+ # --- Tokenize text prompt ---
188
+ encoding = tokenizer(
189
+ prompt_text,
190
+ padding="max_length",
191
+ truncation=True,
192
+ max_length=128,
193
+ return_tensors="pt"
194
+ )
195
+ input_ids = encoding["input_ids"].to(device)
196
+ attention_mask = encoding["attention_mask"].to(device)
197
+
198
+ # --- Forward pass ---
199
+ with torch.no_grad():
200
+ fused_out, logits_img, logits_txt = fusion_model(
201
+ img_tensor, input_ids, attention_mask
202
+ )
203
+ probs = torch.softmax(fused_out, dim=1)[0]
204
+ pred_idx = torch.argmax(probs).item()
205
+
206
+ predicted_label = class_names[pred_idx]
207
+ confidence = probs[pred_idx].item()
208
+
209
+ name_out = f"{predicted_label} ({confidence:.2f})"
210
+
211
+ # --- Compute GradCAM ---
212
+ overlay_np = compute_gradcam_overlay(img_pil, img_tensor, pred_idx)
213
+ if overlay_np is not None:
214
+ overlay_img = Image.fromarray(overlay_np)
215
+ buffered = BytesIO()
216
+ overlay_img.save(buffered, format="PNG")
217
+ gradcam_b64 = base64.b64encode(buffered.getvalue()).decode()
218
+ else:
219
+ gradcam_b64 = None
220
+
221
+ # --- Encode original image ---
222
+ buffer2 = BytesIO()
223
+ img_pil.save(buffer2, format="PNG")
224
+ image_b64 = base64.b64encode(buffer2.getvalue()).decode()
225
+
226
+ return image_b64, gradcam_b64, predicted_label, confidence
227
+
228
 
229
  @app.post("/uploaded")
230
  async def handle_upload(
231
  request: Request,
232
  file: UploadFile = File(...),
233
  checkboxes: List[str] = Form([]),
234
+ symptom_text: str = Form(""),
235
+ cf_turnstile_response: str = Form("") ### (เพิ่ม)
236
  ):
237
+
238
+ # ---------- (เพิ่ม) VERIFY TURNSTILE ---------
239
+
240
+ if not cf_turnstile_response:
241
+ return templates.TemplateResponse(
242
+ "detect.html",
243
+ {"request": request, "error": "Turnstile token missing"}
244
+ )
245
+
246
+ verify_resp = requests.post(
247
+ "https://challenges.cloudflare.com/turnstile/v0/siteverify",
248
+ data={
249
+ "secret": TURNSTILE_SECRET,
250
+ "response": cf_turnstile_response
251
+ }
252
+ )
253
+
254
+ cf_result = verify_resp.json()
255
+
256
+ if not cf_result.get("success"):
257
+ return templates.TemplateResponse(
258
+ "detect.html",
259
+ {"request": request, "error": "การยืนยันความปลอดภัยไม่สำเร็จ กรุณาลองใหม่อีกครั้ง"}
260
+ )
261
+
262
+ # ------------------------------------------------
263
+
264
+
265
+ # >>> โค้ดเดิมของคุณทั้งหมดด้านล่าง (ไม่แก้) <<<
266
+
267
  temp_filepath = os.path.join("uploads", f"{uuid.uuid4()}_{file.filename}")
268
  with open(temp_filepath, "wb") as buffer:
269
  shutil.copyfileobj(file.file, buffer)
270
+
271
+ SYMPTOM_MAP = {
272
+ "noSymptoms": "ไม่มีอาการ", "drinkAlcohol": "ดื่มเหล้า", "smoking": "สูบบุหรี่",
273
+ "chewBetelNut": "เคี้ยวหมาก", "eatSpicyFood": "กินเผ็ดแสบ", "wipeOff": "เช็ดออกได้",
274
+ "alwaysHurts": "เจ็บเมื่อโดนแผล"
275
+ }
276
+
277
  final_prompt_parts = []
278
  selected_symptoms_thai = {SYMPTOM_MAP.get(cb) for cb in checkboxes if SYMPTOM_MAP.get(cb)}
279
+
280
  if "ไม่มีอาการ" in selected_symptoms_thai:
281
  symptoms_group = {"เจ็บเมื่อโดนแผล", "กินเผ็ดแสบ"}
282
  lifestyles_group = {"ดื่มเหล้า", "สูบบุหรี่", "เคี้ยวหมาก"}
283
  patterns_group = {"เช็ดออกได้"}
284
  special_group = {"ไม่มีอาการ"}
285
+
286
  final_selected = (selected_symptoms_thai - symptoms_group) | \
287
  (selected_symptoms_thai & (lifestyles_group | patterns_group | special_group))
288
+
289
  final_prompt_parts.append(" ".join(sorted(list(final_selected))))
290
+
291
  elif selected_symptoms_thai:
292
  final_prompt_parts.append(" ".join(sorted(list(selected_symptoms_thai))))
293
+
294
  if symptom_text and symptom_text.strip():
295
  final_prompt_parts.append(symptom_text.strip())
296
+
297
  final_prompt = "; ".join(final_prompt_parts) if final_prompt_parts else "ไม่มีอาการ"
298
+
299
  image_b64, gradcam_b64, name_out, eva_output = process_with_ai_model(
300
  image_path=temp_filepath, prompt_text=final_prompt
301
  )
302
+
303
  os.remove(temp_filepath)
304
+
305
  result_id = str(uuid.uuid4())
306
+
307
  result_data = {
308
+ "image_b64_data": image_b64,
309
+ "gradcam_b64_data": gradcam_b64,
310
+ "name_out": name_out,
311
+ "eva_output": eva_output,
312
  }
313
+
314
  with cache_lock:
315
  results_cache[result_id] = {
316
  "data": result_data,
317
+ "created_at": time.time()
318
  }
319
 
320
+ results_url = request.url_for("show_results", result_id=result_id)
321
  return RedirectResponse(url=results_url, status_code=303)
322
 
323
+
324
+
325
  @app.get("/results/{result_id}", response_class=HTMLResponse)
326
  async def show_results(request: Request, result_id: str):
327
  with cache_lock:
328
  cached_item = results_cache.get(result_id)
329
+
330
  if not cached_item or (time.time() - cached_item["created_at"] > EXPIRATION_MINUTES * 60):
331
  if cached_item:
332
  with cache_lock:
 
336
  context = {"request": request, **cached_item["data"]}
337
  return templates.TemplateResponse("detect.html", context)
338
 
339
+
340
+
341
  if __name__ == "__main__":
342
+ port = int(os.environ.get("PORT", 8000))
343
+ uvicorn.run(app, host="0.0.0.0", port=port)