# NOTE: the next three lines are upload-page residue, kept as comments so the file parses.
# maddigit's picture
# Upload 27 files
# ddbdbca verified
import os, json, csv, re, cv2, numpy as np, torch
from tqdm import tqdm
from editdistance import eval as edit_distance
from paddleocr import PaddleOCR
from datasets import load_dataset
# -------------------------------------------------------------------
# Benchmark source and output locations
benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark'  # Hugging Face repo id of the benchmark
benchmark = load_dataset(benchmark_repo, split="test")

# Directory holding the generated images; the main loop reads "<img_id>.jpg" from it.
root_gen = "outputs/CreatiDesign_benchmark/images"
# Sibling directory used for evaluation artefacts (debug crops).
save_root = root_gen.replace("images", "text_eval")
os.makedirs(save_root, exist_ok=True)

# When True, the main loop saves a crop of every ground-truth text region.
DEBUG = True

# -------------------------------------------------------------------
# 1. OCR engine (detection must stay enabled: det=True)
ocr = PaddleOCR(
    det=True,
    rec=True,
    cls=False,
    use_angle_cls=False,
    lang='en',
)

# -------------------------------------------------------------------
# 2. Compute device: CUDA when torch can see a GPU, otherwise CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# -------------------------------------------------------------------
# 3. Utility functions
def spatial_match_iou(det_res, gt_box, gt_text_fmt, iou_thr=0.5):
    """Return the best IoU between ``gt_box`` and any detection whose text matches.

    Args:
        det_res: Sequence of OCR detections (or None). Each item is indexed as
            ``item[0]`` = polygon points and ``item[1][0]`` = recognised text.
        gt_box: Ground-truth box as ``[x1, y1, x2, y2]``.
        gt_text_fmt: Ground-truth text, already passed through ``normalize_text``.
        iou_thr: Accepted for interface compatibility; not used by this function.

    Returns:
        The largest IoU over detections whose normalised text has a
        substring-NED of at most 0.7 against ``gt_text_fmt``; ``0.0`` when
        there are no detections or none of them matches.
    """
    if det_res is None or not len(det_res):
        return 0.0

    top_overlap = 0.0
    for detection in det_res:
        polygon = detection[0]
        recognised = detection[1][0]
        # Text is allowed to be somewhat wrong here: only location is being scored.
        text_error = min_ned_substring(normalize_text(recognised), gt_text_fmt)
        if text_error > 0.7:
            continue
        overlap = iou(quad2bbox(polygon), gt_box)
        top_overlap = max(top_overlap, overlap)
    return top_overlap
# ① Helper: minimum normalised edit distance over fixed-length substrings
def min_ned_substring(pred_fmt: str, tgt_fmt: str) -> float:
    """Return the lowest normalised edit distance of ``tgt_fmt`` against ``pred_fmt``.

    A window as long as ``tgt_fmt`` is slid over ``pred_fmt``; the smallest
    edit distance found is divided by ``len(tgt_fmt)``.

    Special cases:
        * empty ``tgt_fmt`` -> ``0.0``
        * ``pred_fmt`` shorter than ``tgt_fmt`` -> the plain
          ``normalized_edit_distance`` of the two whole strings
    """
    pred_len, tgt_len = len(pred_fmt), len(tgt_fmt)
    if tgt_len == 0:
        return 0.0
    if pred_len < tgt_len:
        # No full-length window exists, so compare the strings directly.
        return normalized_edit_distance(pred_fmt, tgt_fmt)

    lowest = tgt_len  # upper bound for two strings of length tgt_len
    for start in range(pred_len - tgt_len + 1):
        window = pred_fmt[start:start + tgt_len]
        lowest = min(lowest, edit_distance(window, tgt_fmt))
        if lowest == 0:
            # Exact match found; nothing can beat it.
            break
    return lowest / tgt_len
def normalize_text(txt: str) -> str:
    """Lower-case ``txt``, drop space characters, then strip punctuation.

    Characters matched by ``\\w`` (including ``_``) and whitespace other than
    the plain space (e.g. tabs) are kept.
    """
    lowered = txt.lower()
    without_spaces = lowered.replace(" ", "")
    return re.sub(r"[^\w\s]", "", without_spaces)
def normalized_edit_distance(pred: str, gt: str) -> float:
    """Return the edit distance of ``pred`` vs ``gt`` divided by the longer length.

    Two empty strings are defined to have distance ``0.0``.
    """
    if not (gt or pred):
        return 0.0
    longest = max(len(gt), len(pred))
    return edit_distance(pred, gt) / longest
def iou(boxA, boxB) -> float:
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Returns ``0.0`` when the boxes do not overlap (including boxes that only
    touch along an edge).
    """
    left = max(boxA[0], boxB[0])
    top = max(boxA[1], boxB[1])
    right = min(boxA[2], boxB[2])
    bottom = min(boxA[3], boxB[3])

    overlap_w = max(0, right - left)
    overlap_h = max(0, bottom - top)
    inter = overlap_w * overlap_h
    if inter == 0:
        return 0.0

    areaA = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
    areaB = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
    union = areaA + areaB - inter
    return inter / union
def quad2bbox(quad):
    """Return the axis-aligned ``[min_x, min_y, max_x, max_y]`` box enclosing ``quad``.

    ``quad`` is a sequence of points, each indexed as ``point[0]`` = x and
    ``point[1]`` = y.
    """
    xs = []
    for point in quad:
        xs.append(point[0])
    ys = []
    for point in quad:
        ys.append(point[1])
    return [min(xs), min(ys), max(xs), max(ys)]
def crop(img, box):
    """Return the part of ``img`` covered by ``box``, clipped to the image.

    Args:
        img: Image array whose first two dimensions are (height, width).
        box: ``[x1, y1, x2, y2]``; values are truncated to int. ``x2``/``y2``
            are exclusive, matching the slice used to cut the crop.

    Returns:
        ``img[y1:y2, x1:x2]`` after clipping, or a 1x1x3 black ``uint8`` image
        when the clipped box is empty.
    """
    h, w = img.shape[:2]
    x1, y1, x2, y2 = map(int, box)
    x1, y1 = max(0, x1), max(0, y1)
    # Slice ends are exclusive, so the upper bounds clamp to w / h. Clamping
    # to w-1 / h-1 would always drop the last column and row of the image.
    x2, y2 = min(w, x2), min(h, y2)
    if x2 <= x1 or y2 <= y1:
        return np.zeros((1, 1, 3), np.uint8)
    return img[y1:y2, x1:x2]
# -------------------------------------------------------------------
# 4. Main loop: score every benchmark case against its generated image.
#    Fills per_img_rows / all_sen_acc / all_ned / all_spatial / text_pairs.
per_img_rows, all_sen_acc, all_ned, all_spatial, text_pairs = [], [], [], [], []
for case in tqdm(benchmark):
    json_data = json.loads(case["metadata"])
    case_info = json_data["img_info"]
    case_id = case_info["img_id"]
    gt_list = json_data["text_list"]  # [{'text':..., 'bbox':[x1,y1,x2,y2]}, ...]
    if not gt_list:
        # Nothing to score. np.mean([]) would yield NaN and poison the global
        # averages, so cases without ground-truth text are left out.
        continue
    ori_w, ori_h = case_info["img_width"], case_info["img_height"]

    img_path = os.path.join(root_gen, f"{case_id}.jpg")
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None;
        # fail with the offending path instead of an AttributeError on .shape.
        raise FileNotFoundError(f"Generated image missing or unreadable: {img_path}")
    H, W = img.shape[:2]
    wr, hr = W / ori_w, H / ori_h  # GT → generated image scaling ratio

    # ---------- 1) Full image OCR ----------
    ocr_res = ocr.ocr(img, cls=False)
    # Detections for this image; may be None or empty when nothing was found.
    det_res = ocr_res[0] if ocr_res else []
    pred_lines = [txt.strip() for _, (txt, _) in det_res] if det_res else []
    # Concatenate into full text and normalize
    pred_full_fmt = normalize_text(" ".join(pred_lines))

    # ---------- 2) Score each GT sentence ----------
    img_sen_hits, img_neds, img_spatials = [], [], []
    for t_idx, gt in enumerate(gt_list):
        gt_text_orig = gt["text"].replace("\n", " ").strip()
        gt_text_fmt = normalize_text(gt_text_orig)

        # ---- Pure text matching: substring minimum NED over the whole page ----
        text_ned = min_ned_substring(pred_full_fmt, gt_text_fmt)
        img_sen_hits.append(1.0 if text_ned == 0 else 0.0)
        img_neds.append(text_ned)

        # ---- Spatial consistency, using IoU ----
        # Even indices of the bbox are x values (scaled by wr), odd are y (scaled by hr).
        gt_box = [v * wr if i % 2 == 0 else v * hr for i, v in enumerate(gt["bbox"])]
        spatial_score = spatial_match_iou(det_res, gt_box, gt_text_fmt)
        img_spatials.append(spatial_score)  # Can be used directly or binarized

        if DEBUG:
            # Save the GT region cut from the generated image for inspection.
            img_crop = crop(img, list(map(int, gt_box)))
            img_crop_for_ocr_save_root = os.path.join(save_root, case_id)
            os.makedirs(img_crop_for_ocr_save_root, exist_ok=True)
            safe_text = gt_text_orig.replace('/', '_').replace('\\', '_')
            safe_filename = f"{t_idx}_{safe_text}.jpg"
            cv2.imwrite(os.path.join(img_crop_for_ocr_save_root, safe_filename), img_crop)

        # --------- Record text pairs ----------
        text_pairs.append({
            "image_id": case_id,
            "text_id": t_idx,
            "gt_original": gt_text_orig,
            "gt_formatted": gt_text_fmt,
        })

    # ---------- 3) Summarize to image level ----------
    sen_acc = float(np.mean(img_sen_hits))
    ned = float(np.mean(img_neds))
    spatial = float(np.mean(img_spatials))
    per_img_rows.append([case_id, sen_acc, ned, spatial])
    all_sen_acc.append(sen_acc)
    all_ned.append(ned)
    all_spatial.append(spatial)
# -------------------------------------------------------------------
# 5. Write results: one CSV row per image plus a small text summary
result_root = root_gen.replace("images","")

csv_perimg = os.path.join(result_root, "text_results_per_image.csv")
with open(csv_perimg, "w", newline='', encoding="utf-8") as f:
    w = csv.writer(f)
    w.writerow(["image_id","sen_acc","ned","score_spatial"])
    w.writerows(per_img_rows)

overall_path = os.path.join(result_root, "text_overall.txt")
with open(overall_path, "w", encoding="utf-8") as f:
    # Global numbers are plain means of the per-image values.
    f.writelines([
        f"Images evaluated : {len(per_img_rows)}\n",
        f"Global Sen ACC : {np.mean(all_sen_acc):.4f}\n",
        f"Global NED : {np.mean(all_ned):.4f}\n",
        f"Global Spatial : {np.mean(all_spatial):.4f}\n",
    ])

print("✓ Done! Results saved to", result_root)