"""Evaluate text rendering in generated CreatiDesign benchmark images.

Every generated image is OCR-ed with PaddleOCR. Each ground-truth (GT) text
entry of the benchmark case is then scored on:

* Sentence accuracy -- 1.0 if the normalized GT text occurs verbatim inside the
  concatenated, normalized OCR output, otherwise 0.0.
* NED -- the minimum normalized edit distance between the normalized GT text
  and any same-length substring of the normalized OCR output (lower = better).
* Spatial score -- the best IoU between the GT box (rescaled to the generated
  image size) and any OCR detection box whose text roughly matches the GT.

Per-image averages go to ``text_results_per_image.csv`` and global averages to
``text_overall.txt``, both in the parent directory of ``root_gen``.
"""

import csv
import json
import os
import re

import cv2
import numpy as np
import torch
from datasets import load_dataset
from editdistance import eval as edit_distance
from paddleocr import PaddleOCR
from tqdm import tqdm

# -------------------------------------------------------------------
# 1. Paths / configuration
benchmark_repo = "HuiZhang0812/CreatiDesign_benchmark"  # Hugging Face repo of the benchmark
root_gen = "outputs/CreatiDesign_benchmark/images"  # directory holding <img_id>.jpg files
# Parent of the image directory; result files are written here. dirname() is
# used instead of str.replace("images", ...) so that an "images" substring
# elsewhere in the path is not mangled.
result_root = os.path.dirname(os.path.normpath(root_gen))
save_root = os.path.join(result_root, "text_eval")  # debug crops and text pairs
DEBUG = True  # when True, save a crop of every GT text box for inspection

# Not used by the evaluation below (PaddleOCR selects its own device).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Precompiled patterns (used inside the per-image loop).
_PUNCT_RE = re.compile(r"[^\w\s]")
# Anything outside a conservative ASCII set is replaced in debug file names;
# cv2.imwrite fails silently on unsupported or overly long names.
_UNSAFE_FILENAME_RE = re.compile(r"[^A-Za-z0-9_\-. ]+")


# -------------------------------------------------------------------
# 2. Utility functions
def normalize_text(txt: str) -> str:
    """Lower-case *txt*, remove spaces and strip punctuation (non-word chars)."""
    txt = txt.lower().replace(" ", "")
    return _PUNCT_RE.sub("", txt)


def normalized_edit_distance(pred: str, gt: str) -> float:
    """Return edit distance divided by the longer string's length (0..1).

    Two empty strings have distance 0.0.
    """
    if not gt and not pred:
        return 0.0
    return edit_distance(pred, gt) / max(len(gt), len(pred))


def min_ned_substring(pred_fmt: str, tgt_fmt: str) -> float:
    """Minimum normalized edit distance of *tgt_fmt* against *pred_fmt*.

    Slides a window of ``len(tgt_fmt)`` characters over *pred_fmt* and returns
    the smallest edit distance divided by ``len(tgt_fmt)``. The result is in
    0..1, and 0.0 means an exact occurrence.

    * An empty target returns 0.0.
    * A prediction shorter than the target falls back to
      normalized_edit_distance() on the whole strings.
    """
    len_pred, len_tgt = len(pred_fmt), len(tgt_fmt)
    if len_tgt == 0:
        return 0.0
    if len_pred < len_tgt:
        return normalized_edit_distance(pred_fmt, tgt_fmt)
    # Fast path: an exact occurrence is a zero-distance window.
    if tgt_fmt in pred_fmt:
        return 0.0
    best = len_tgt  # a same-length window can never be farther than this
    for start in range(len_pred - len_tgt + 1):
        dist = edit_distance(pred_fmt[start:start + len_tgt], tgt_fmt)
        if dist < best:
            best = dist
    return best / len_tgt


def iou(boxA, boxB) -> float:
    """Intersection-over-union of two ``[x1, y1, x2, y2]`` boxes."""
    xA, yA = max(boxA[0], boxB[0]), max(boxA[1], boxB[1])
    xB, yB = min(boxA[2], boxB[2]), min(boxA[3], boxB[3])
    inter = max(0, xB - xA) * max(0, yB - yA)
    if inter == 0:
        return 0.0
    areaA = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    areaB = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    return inter / (areaA + areaB - inter)


def quad2bbox(quad):
    """Axis-aligned ``[x1, y1, x2, y2]`` bounding box of a list of (x, y) points."""
    xs = [p[0] for p in quad]
    ys = [p[1] for p in quad]
    return [min(xs), min(ys), max(xs), max(ys)]


def crop(img, box):
    """Return the part of *img* inside ``[x1, y1, x2, y2]``, clipped to bounds.

    Coordinates are truncated to int. A 1x1 black image is returned when the
    clipped box is empty.
    """
    h, w = img.shape[:2]
    x1, y1, x2, y2 = map(int, box)
    x1, y1 = max(0, x1), max(0, y1)
    # Slice ends are exclusive, so clamp to w/h (not w-1/h-1); otherwise the
    # last column/row of the image could never be included.
    x2, y2 = min(w, x2), min(h, y2)
    if x2 <= x1 or y2 <= y1:
        return np.zeros((1, 1, 3), np.uint8)
    return img[y1:y2, x1:x2]


def spatial_match_iou(det_res, gt_box, gt_text_fmt, iou_thr=0.5, ned_thr=0.7):
    """Best IoU between *gt_box* and any OCR detection whose text matches.

    Args:
        det_res: OCR detections for one image. Each item is indexed as
            ``item[0]`` = box points and ``item[1][0]`` = recognized text.
            May be None or empty.
        gt_box: ``[x1, y1, x2, y2]`` in generated-image pixel coordinates.
        gt_text_fmt: GT text already passed through normalize_text().
        iou_thr: currently unused; kept for backward compatibility.
        ned_thr: a detection is considered only when min_ned_substring() of
            its normalized text against the GT is <= this value. This
            tolerates some OCR errors when judging position.

    Returns:
        The maximum IoU over qualifying detections, or 0.0 if there are none.
    """
    best_iou = 0.0
    if det_res is None or len(det_res) == 0:
        return best_iou
    for item in det_res:
        poly = item[0]     # detection box points
        txt = item[1][0]   # recognized text
        if min_ned_substring(normalize_text(txt), gt_text_fmt) <= ned_thr:
            best_iou = max(best_iou, iou(quad2bbox(poly), gt_box))
    return best_iou


def _safe_filename(text: str, max_len: int = 100) -> str:
    """Turn *text* into a short file-name stem that is safe on common filesystems."""
    stem = _UNSAFE_FILENAME_RE.sub("_", text).strip(" .")
    return stem[:max_len] or "empty"


def _mean_or_nan(values) -> float:
    """Mean of *values* as a float, or NaN (without a numpy warning) if empty."""
    return float(np.mean(values)) if len(values) else float("nan")


# -------------------------------------------------------------------
# 3. Main loop
def main():
    """Run OCR on every generated image, score it and write the result files."""
    benchmark = load_dataset(benchmark_repo, split="test")
    os.makedirs(save_root, exist_ok=True)
    # OCR initialization: detection must stay enabled, because the spatial
    # score needs detection boxes.
    ocr = PaddleOCR(det=True, rec=True, cls=False, use_angle_cls=False, lang='en')

    per_img_rows, all_sen_acc, all_ned, all_spatial, text_pairs = [], [], [], [], []
    for case in tqdm(benchmark):
        json_data = json.loads(case["metadata"])
        case_info = json_data["img_info"]
        case_id = case_info["img_id"]
        gt_list = json_data["text_list"]  # [{'text': ..., 'bbox': [x1, y1, x2, y2]}, ...]
        if not gt_list:
            # Nothing to score. Averaging an empty list would yield NaN and
            # poison every global metric.
            continue

        ori_w, ori_h = case_info["img_width"], case_info["img_height"]
        img_path = os.path.join(root_gen, f"{case_id}.jpg")
        img = cv2.imread(img_path)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f"Cannot read generated image: {img_path}")
        H, W = img.shape[:2]
        wr, hr = W / ori_w, H / ori_h  # GT -> generated image scaling ratio

        # ---------- Full-image OCR ----------
        ocr_res = ocr.ocr(img, cls=False)
        # The per-page result is None when nothing was detected.
        det_res = ocr_res[0] if ocr_res and ocr_res[0] else []
        pred_lines = [txt.strip() for _quad, (txt, _conf) in det_res]
        pred_full_fmt = normalize_text(" ".join(pred_lines))

        # ---------- Score each GT text entry ----------
        img_sen_hits, img_neds, img_spatials = [], [], []
        for t_idx, gt in enumerate(gt_list):
            gt_text_orig = gt["text"].replace("\n", " ").strip()
            gt_text_fmt = normalize_text(gt_text_orig)

            # Pure text matching: minimum NED over substrings (position-free).
            ned = min_ned_substring(pred_full_fmt, gt_text_fmt)
            img_sen_hits.append(1.0 if ned == 0 else 0.0)
            img_neds.append(ned)

            # Spatial consistency: IoU against matching detections.
            gt_box = [v * wr if i % 2 == 0 else v * hr for i, v in enumerate(gt["bbox"])]
            img_spatials.append(spatial_match_iou(det_res, gt_box, gt_text_fmt))

            if DEBUG:  # save the GT-box crop for visual inspection
                crop_dir = os.path.join(save_root, str(case_id))
                os.makedirs(crop_dir, exist_ok=True)
                crop_path = os.path.join(crop_dir, f"{t_idx}_{_safe_filename(gt_text_orig)}.jpg")
                if not cv2.imwrite(crop_path, crop(img, gt_box)):
                    tqdm.write(f"warning: failed to write debug crop {crop_path}")

            text_pairs.append({
                "image_id": case_id,
                "text_id": t_idx,
                "gt_original": gt_text_orig,
                "gt_formatted": gt_text_fmt,
            })

        # ---------- Summarize to image level ----------
        sen_acc = float(np.mean(img_sen_hits))
        ned = float(np.mean(img_neds))
        spatial = float(np.mean(img_spatials))
        per_img_rows.append([case_id, sen_acc, ned, spatial])
        all_sen_acc.append(sen_acc)
        all_ned.append(ned)
        all_spatial.append(spatial)

    # -------------------------------------------------------------------
    # 4. Write results
    csv_perimg = os.path.join(result_root, "text_results_per_image.csv")
    with open(csv_perimg, "w", newline='', encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["image_id", "sen_acc", "ned", "score_spatial"])
        w.writerows(per_img_rows)

    with open(os.path.join(save_root, "text_pairs.json"), "w", encoding="utf-8") as f:
        json.dump(text_pairs, f, ensure_ascii=False, indent=2)

    with open(os.path.join(result_root, "text_overall.txt"), "w", encoding="utf-8") as f:
        f.write(f"Images evaluated : {len(per_img_rows)}\n")
        f.write(f"Global Sen ACC   : {_mean_or_nan(all_sen_acc):.4f}\n")
        f.write(f"Global NED       : {_mean_or_nan(all_ned):.4f}\n")
        f.write(f"Global Spatial   : {_mean_or_nan(all_spatial):.4f}\n")

    print("✓ Done! Results saved to", result_root)


if __name__ == "__main__":
    main()