#!/usr/bin/env python3
import argparse, os, sys, warnings, subprocess, tempfile, shutil
# Binarization mode read by binarize_any(): 'strict' (default) or 'easy'; main() sets it from --geom_easy_threshold.
BIN_MODE = "strict"
import numpy as np
from PIL import Image

def rgb2hsv(img):
    """Vectorized RGB -> HSV conversion.

    Args:
        img: float array with values in [0, 1] and a trailing channel axis of
            size 3 (R, G, B), e.g. shape (H, W, 3).

    Returns:
        Array of the same shape holding [H, S, V] along the last axis, with
        H in degrees in [0, 360), S and V in [0, 1]. Achromatic pixels
        (R == G == B) get H = 0 and S = 0, matching ``colorsys.rgb_to_hsv``.
    """
    img = np.asarray(img)
    r, g, b = img[..., 0], img[..., 1], img[..., 2]
    cmax = np.max(img, axis=-1)
    cmin = np.min(img, axis=-1)
    delta = cmax - cmin

    # Divide by the exact chroma; pixels with zero chroma use a dummy divisor
    # and have their hue forced to 0 below (instead of an arbitrary 240).
    chromatic = delta > 0
    safe_delta = np.where(chromatic, delta, 1)
    h_r = (60 * ((g - b) / safe_delta) + 360) % 360
    h_g = (60 * ((b - r) / safe_delta) + 120) % 360
    h_b = (60 * ((r - g) / safe_delta) + 240) % 360
    # Standard priority when channels tie for the maximum: R, then G, then B.
    h = np.select([cmax == r, cmax == g], [h_r, h_g], default=h_b)
    h = np.where(chromatic, h, 0)

    # Saturation is 0 for (near-)black pixels to avoid dividing by ~0.
    dark = cmax < 1e-6
    s = np.where(dark, 0.0, delta / np.where(dark, 1, cmax))
    v = cmax
    return np.stack([h, s, v], axis=-1)

def make_masks(img):
    """Split an overlay image into GT and Pred boolean masks.

    Input: RGB image (anything ``np.asarray`` turns into uint8 (H, W, 3)).
    Output: (gt_mask, pred_mask), boolean arrays of shape (H, W).
    Convention: colored (green/orange/purple...) fill is Pred, light-gray fill
    is GT; the white background and near-black outlines are excluded.

    SciPy is optional: if it cannot be imported, the morphological clean-up
    and the small-component filter are skipped and the raw threshold masks
    are returned.
    """
    arr = np.asarray(img).astype(np.float32) / 255.0
    hsv = rgb2hsv(arr)
    H, S, V = hsv[..., 0], hsv[..., 1], hsv[..., 2]

    # ---- 1) Background / outline rejection ----
    # Drop the near-white background and the near-black outlines.
    base_mask = (V < 0.95) & (V > 0.08)

    # ---- 2) Pred: colored fill (green, orange, purple, ...) ----
    R, G, B = arr[..., 0], arr[..., 1], arr[..., 2]
    near_gray = (np.abs(R - G) < 0.05) & (np.abs(R - B) < 0.05)  # R ~ G ~ B
    colored = (S >= 0.25) & (V > 0.15) & (V < 0.97)
    pred_mask = base_mask & colored & ~near_gray
    # Legacy compatibility: some renders use a less saturated green.
    greenish = (H >= 60) & (H <= 170) & (S >= 0.20)
    pred_mask |= base_mask & greenish

    # ---- 3) GT: light-gray fill ----
    gt_mask = base_mask & near_gray & (V > 0.30) & (V < 0.90)

    try:
        from scipy.ndimage import (binary_dilation, binary_opening,
                                   generate_binary_structure, label)
    except ImportError:
        # Without SciPy this is not fatal: use the threshold masks as-is.
        return gt_mask, pred_mask

    # ---- 4) Morphological clean-up ----
    selem = generate_binary_structure(2, 1)  # 3x3 cross
    for _ in range(2):
        # Two openings cut thin dashed grid lines out of both masks.
        pred_mask = binary_opening(pred_mask, structure=selem)
        gt_mask = binary_opening(gt_mask, structure=selem)
    # A tiny GT area after opening suggests the GT was drawn as an outline
    # only: use the dark-gray stroke, thickened by dilation, as GT.
    if gt_mask.sum() < 200:
        stroke = base_mask & (S < 0.20) & (V > 0.15) & (V < 0.55)
        gt_mask |= binary_dilation(stroke, iterations=2)

    # ---- 5) Drop tiny speckles ----
    def _remove_small(mask, min_pixels=80):
        """Remove 4-connected components smaller than ``min_pixels`` pixels."""
        labeled, n = label(mask.astype(np.uint8))
        if n == 0:
            return mask
        counts = np.bincount(labeled.ravel())
        small = counts < min_pixels
        small[0] = False  # label 0 is the background; never "remove" it
        return mask & ~small[labeled]

    return _remove_small(gt_mask), _remove_small(pred_mask)

def _auto_crop_axes(pil_img):
    """Auto-crop to the inner 0..10 plot square so GT/PRED renders align.

    The bounding-box heuristic (rows/columns with enough non-white and
    non-black pixels, plus a small pad) lives in detect_axes_bbox(); reusing
    it keeps both cropping paths in agreement.

    Returns:
        The cropped PIL image, or the original image object unchanged when no
        inner box is found.
    """
    bbox = detect_axes_bbox(pil_img)
    w, h = pil_img.size
    if bbox == (0, 0, w, h):
        # Nothing to crop (or the detector fell back to the full image).
        return pil_img
    return pil_img.crop(bbox)

def detect_axes_bbox(pil_img, *, pad=2):
    """Return the (x0, y0, x1, y1) crop box of the inner axes area.

    Heuristic: keep rows/columns where more than 3% (and at least 5) of the
    pixels are neither near-white nor near-black, take their tightest
    bounding box and grow it by ``pad`` pixels on every side (clipped to the
    image). The box follows PIL's ``crop`` convention: x1/y1 are exclusive.

    Args:
        pil_img: RGB image (PIL image or uint8 array of shape (H, W, 3)).
        pad: padding in pixels added on each side of the detected box.

    Returns:
        (x0, y0, x1, y1); the full image box (0, 0, w, h) if not found.
    """
    arr = np.asarray(pil_img).astype(np.float32) / 255.0
    v = arr.max(axis=-1)
    mask = (v < 0.97) & (v > 0.05)
    h, w = mask.shape
    r_thr = max(5, int(0.03 * w))
    c_thr = max(5, int(0.03 * h))
    rows = np.flatnonzero(mask.sum(axis=1) > r_thr)
    cols = np.flatnonzero(mask.sum(axis=0) > c_thr)
    if rows.size < 2 or cols.size < 2:
        return (0, 0, w, h)
    # rows/cols hold inclusive indices while the crop box's right/bottom
    # edges are exclusive: add 1 there so the padding is symmetric.
    x0 = max(0, int(cols[0]) - pad)
    y0 = max(0, int(rows[0]) - pad)
    x1 = min(w, int(cols[-1]) + 1 + pad)
    y1 = min(h, int(rows[-1]) + 1 + pad)
    return (x0, y0, x1, y1)

def crop_axes_and_match_size(img_a, img_b, target_size=None, return_meta=False):
    """Crop two independently rendered images to their inner axes box and
    resize both to one common size, so their pixels share a coordinate frame.

    Each image is cropped with its own detect_axes_bbox() box (the inner
    0..10 square). Both crops are then resized (bilinear) to ``target_size``;
    when that is None, the size of the first (GT) crop is used.

    Returns:
        (crop_a, crop_b), or (crop_a, crop_b, bbox_a, bbox_b, size) when
        ``return_meta`` is true.
    """
    bbox_a = detect_axes_bbox(img_a)
    bbox_b = detect_axes_bbox(img_b)
    cropped_a = img_a.crop(bbox_a)
    cropped_b = img_b.crop(bbox_b)

    size = cropped_a.size if target_size is None else target_size

    def _fit(im):
        # Only resample when the size actually differs.
        return im if im.size == size else im.resize(size, Image.BILINEAR)

    cropped_a = _fit(cropped_a)
    cropped_b = _fit(cropped_b)
    if not return_meta:
        return cropped_a, cropped_b
    return cropped_a, cropped_b, bbox_a, bbox_b, size

def binarize_any(img, already_cropped=False, min_area=150):
    """Extract a boolean foreground mask from a single GT or Pred render.

    In 'strict' mode, the white background and black outlines are dropped and
    thin gray dashed grid lines are removed aggressively:
      1) brightness thresholds reject the white background and black strokes;
      2) two morphological openings with a 3x3 cross cut 1-2 px grid lines;
      3) connected components smaller than ``min_area`` pixels are discarded.
    SciPy is optional; without it only the brightness thresholds are applied.

    The module-level BIN_MODE selects the behavior; 'easy' keeps every pixel
    that is neither almost pure white nor almost pure black, with no clean-up.

    Args:
        img: RGB image (PIL image or uint8 array of shape (H, W, 3)).
        already_cropped: skip _auto_crop_axes() when the caller already cropped.
        min_area: strict-mode minimum component size in pixels.

    Returns:
        Boolean array of shape (H, W).
    """
    # Crop the outer margins first (unless the caller already did).
    if not already_cropped:
        img = _auto_crop_axes(img)
    arr = np.asarray(img).astype(np.float32) / 255.0
    v = arr.max(axis=-1)  # rough brightness

    if globals().get("BIN_MODE", "strict") == "easy":
        # Extremely permissive: anything not pure white / pure black.
        return (v < 0.99) & (v > 0.01)

    mask = (v < 0.97) & (v > 0.08)
    try:
        from scipy.ndimage import binary_opening, generate_binary_structure, label
    except ImportError:
        # Without SciPy, fall back to the plain threshold segmentation.
        return mask

    selem = generate_binary_structure(2, 1)  # 3x3 cross
    mask = binary_opening(mask, structure=selem)
    mask = binary_opening(mask, structure=selem)

    # Keep only components of at least min_area pixels (grid residue is small).
    labeled, n = label(mask.astype(np.uint8))
    if n > 0:
        counts = np.bincount(labeled.ravel())
        big = counts >= min_area
        big[0] = False  # label 0 is the background
        mask = big[labeled]
    return mask

def compute_iou(m1, m2):
    """Return (iou, intersection, union) for two boolean masks.

    Intersection and union are pixel counts returned as Python ints; the IoU
    is reported as 0.0 when the union is empty.
    """
    overlap = np.logical_and(m1, m2).sum(dtype=np.int64)
    combined = np.logical_or(m1, m2).sum(dtype=np.int64)
    if combined > 0:
        ratio = overlap / combined
    else:
        ratio = 0.0
    return ratio, int(overlap), int(combined)

def maybe_dilate(mask, iters=0):
    """Optionally grow a boolean mask by ``iters`` pixels (3x3-cross steps).

    Returns ``mask`` itself when ``iters`` is None or <= 0. If SciPy cannot be
    imported, a warning is emitted and ``mask`` is returned undilated. Other
    errors (e.g. a mask that is not 2-D) are raised instead of being misreported
    as a missing SciPy.
    """
    if iters is None or iters <= 0:
        return mask
    try:
        from scipy.ndimage import binary_dilation, generate_binary_structure
    except ImportError:
        warnings.warn("[WARN] SciPy not available for dilation; continuing without dilation.")
        return mask
    selem = generate_binary_structure(2, 1)
    return binary_dilation(mask, structure=selem, iterations=int(iters))


def iou_from_two_renders(gt_path, pr_path, args):
    """Compute and print the IoU between a GT-only and a Pred-only render.

    Pipeline: load both images as RGB, optionally flip PRED vertically
    (--force_flip_pred), bring both onto a common canvas (pad with --no_crop,
    otherwise crop to the inner axes box and resize), binarize, optionally
    dilate (--dilate) and, with --try_flip_pred, retry with a vertically
    flipped PRED when the IoU is exactly 0. Debug images (--save_debug,
    --save_overlay_masks) are written next to the GT image.

    Args:
        gt_path: path to the GT render.
        pr_path: path to the Pred render.
        args: argparse namespace built by main().

    Returns:
        (iou, intersection_px, union_px) -- the numbers that were printed.

    Raises:
        FileNotFoundError: if either image path does not exist.
    """
    if not os.path.exists(gt_path):
        raise FileNotFoundError(gt_path)
    if not os.path.exists(pr_path):
        raise FileNotFoundError(pr_path)

    # Materialize RGB copies and close the underlying files right away.
    with Image.open(gt_path) as im:
        gt = im.convert("RGB")
    with Image.open(pr_path) as im:
        pr = im.convert("RGB")

    if args.force_flip_pred:
        pr = pr.transpose(Image.FLIP_TOP_BOTTOM)
        if args.verbose:
            print("[INFO] force_flip_pred: flipped PRED vertically", file=sys.stderr)

    # Debug files go next to the GT image. abspath() matters for a bare
    # filename, whose dirname is "" (os.makedirs("") would raise).
    out_dir = os.path.dirname(os.path.abspath(gt_path))
    save_debug = getattr(args, "save_debug", False)
    save_overlay = getattr(args, "save_overlay_masks", False)

    if args.verbose:
        print(f"[INFO] opened GT size={gt.size}, Pred size={pr.size}", file=sys.stderr)

    if args.no_crop:
        if gt.size != pr.size:
            # Pad the smaller image (white, anchored top-left) to the larger size.
            W = max(gt.size[0], pr.size[0])
            H = max(gt.size[1], pr.size[1])

            def pad_to(im):
                bg = Image.new("RGB", (W, H), (255, 255, 255))
                bg.paste(im, (0, 0))
                return bg

            gt, pr = pad_to(gt), pad_to(pr)
            if args.verbose:
                print(f"[INFO] no_crop: padded to common size {gt.size}", file=sys.stderr)
    else:
        target = (args.fixed_canvas, args.fixed_canvas) if getattr(args, "fixed_canvas", None) else None
        gt, pr, bbox_a, bbox_b, tgt = crop_axes_and_match_size(gt, pr, target_size=target, return_meta=True)
        if args.verbose:
            print(f"[INFO] crop_axes: gt_bbox={bbox_a}, pred_bbox={bbox_b}, target_size={tgt}", file=sys.stderr)
        if save_debug:
            gt.save(os.path.join(out_dir, "_gt_cropped.png"))
            pr.save(os.path.join(out_dir, "_pred_cropped.png"))

    gt_mask = binarize_any(gt, already_cropped=True)
    pred_mask = binarize_any(pr, already_cropped=True)

    # Optional dilation tolerance
    gt_mask_d = maybe_dilate(gt_mask, args.dilate)
    pred_mask_d = maybe_dilate(pred_mask, args.dilate)

    iou, inter, union = compute_iou(gt_mask_d, pred_mask_d)
    tried_flip = False
    if args.try_flip_pred and iou == 0.0:
        pr_flip = pr.transpose(Image.FLIP_TOP_BOTTOM)
        pred_mask_flip = binarize_any(pr_flip, already_cropped=True)
        pred_mask_flip_d = maybe_dilate(pred_mask_flip, args.dilate)
        iou_flip, inter_f, union_f = compute_iou(gt_mask_d, pred_mask_flip_d)
        tried_flip = True
        if iou_flip > iou:
            iou, inter, union = iou_flip, inter_f, union_f
            # Keep the debug output consistent with the scored PRED.
            pred_mask = pred_mask_flip
            if args.verbose:
                print("[INFO] try_flip_pred: flipped version improved IoU", file=sys.stderr)
        elif args.verbose:
            print("[INFO] try_flip_pred: flip did not improve IoU", file=sys.stderr)

    print(f"[IoU] {iou:.4f}  (intersection={inter} px, union={union} px)")
    if tried_flip and args.verbose:
        print("[INFO] (reported IoU already reflects the better of normal vs flipped PRED)", file=sys.stderr)

    def _mask_vis(gt_m, pred_m):
        # GT in gray, Pred in green drawn on top.
        vis = np.zeros((*gt_m.shape, 3), dtype=np.uint8)
        vis[gt_m] = (180, 180, 180)
        vis[pred_m] = (80, 170, 90)
        return Image.fromarray(vis)

    if save_debug:
        Image.fromarray((gt_mask * 255).astype(np.uint8)).save(os.path.join(out_dir, "_gt_mask.png"))
        Image.fromarray((pred_mask * 255).astype(np.uint8)).save(os.path.join(out_dir, "_pred_mask.png"))
        _mask_vis(gt_mask, pred_mask).save(os.path.join(out_dir, "_mix_vis.png"))
    if save_debug or save_overlay:
        _mask_vis(gt_mask, pred_mask).save(os.path.join(out_dir, "_overlay_masks.png"))

    return (iou, inter, union)

# --- Helper: render JSON with geometry.py to PNG for geometric IoU ---
def _render_with_geometry(geom_script, json_path, out_path, flip_y=False,
                          map_to_axes=False, anchor=None, size_mode=None):
    """
    Use geometry.py as a renderer to turn a JSON piece into a clean PNG render.
    This avoids any color-threshold heuristics and makes IoU robust.

    Args:
        geom_script: path to geometry.py
        json_path: annotation json
        out_path: png to save
        flip_y: whether to add --flip_y when rendering
        map_to_axes: if True, pass --map_to_axes to geometry.py
        anchor: 'corner' or 'centroid' (if provided)
        size_mode: 'keep' | 'fold' | 'auto' (if provided)
    """
    if not os.path.exists(geom_script):
        raise FileNotFoundError(f"[geom] geometry.py not found at: {geom_script}")
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"[geom] JSON not found: {json_path}")
    cmd = [sys.executable, geom_script, "--ann", json_path, "--save_png", out_path, "--no_show"]
    if flip_y:
        cmd.append("--flip_y")
    if map_to_axes:
        cmd.append("--map_to_axes")
    if anchor in ("corner", "centroid"):
        cmd += ["--anchor", anchor]
    if size_mode in ("keep", "fold", "auto"):
        cmd += ["--size_into_scale", size_mode]
    # run and ensure the png appears
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if proc.returncode != 0:
        raise RuntimeError(f"[geom] rendering failed:\nCMD: {' '.join(cmd)}\nSTDERR:\n{proc.stderr}")
    if not os.path.exists(out_path):
        raise FileNotFoundError(f"[geom] expected output not found: {out_path}")

def compute_iou_from_json_pair(args, gt_json_path, pred_json_path, out_dir):
    """Render two annotation JSONs with geometry.py, compute and print their IoU.

    Both JSONs are rendered into a temporary directory, which is kept on disk
    with --keep_tmp_geom; the renders are also copied to --save_geom_pngs_dir
    if set. They are then scored like iou_from_two_renders(): optional PRED
    flip, pad or crop-and-resize, binarize, optional dilation, optional flip
    retry.

    The module-level BIN_MODE is set from --geom_easy_threshold for the
    duration of the call and always restored afterwards.

    Args:
        args: argparse namespace built by main().
        gt_json_path: GT annotation JSON.
        pred_json_path: Pred annotation JSON.
        out_dir: directory for --save_debug / --save_overlay_masks outputs
            ("" means the current directory).

    Returns:
        (iou, intersection_px, union_px).

    Raises:
        FileNotFoundError, RuntimeError: propagated from _render_with_geometry().
    """
    import contextlib

    global BIN_MODE
    prev_bin_mode = BIN_MODE
    try:
        BIN_MODE = "easy" if getattr(args, "geom_easy_threshold", False) else "strict"
        with contextlib.ExitStack() as stack:
            if args.keep_tmp_geom:
                tdir = tempfile.mkdtemp(prefix="geom_iou_")
            else:
                # Deleted (with its contents) when the ExitStack closes.
                tdir = stack.enter_context(tempfile.TemporaryDirectory())

            gt_png = os.path.join(tdir, "gt_render.png")
            pr_png = os.path.join(tdir, "pred_render.png")
            _render_with_geometry(args.geom_script, gt_json_path, gt_png,
                                  flip_y=args.flip_y_geom,
                                  map_to_axes=args.map_to_axes_geom,
                                  anchor=args.gt_anchor,
                                  size_mode=args.size_mode_geom)
            _render_with_geometry(args.geom_script, pred_json_path, pr_png,
                                  flip_y=args.flip_y_geom,
                                  map_to_axes=args.map_to_axes_geom,
                                  anchor=args.pred_anchor,
                                  size_mode=args.size_mode_geom)
            if args.verbose:
                print(f"[INFO] geom-rendered -> GT:{gt_png}  PRED:{pr_png}", file=sys.stderr)
            # Optionally copy renders for inspection
            if args.save_geom_pngs_dir:
                os.makedirs(args.save_geom_pngs_dir, exist_ok=True)
                shutil.copy(gt_png, os.path.join(args.save_geom_pngs_dir, "gt_render.png"))
                shutil.copy(pr_png, os.path.join(args.save_geom_pngs_dir, "pred_render.png"))

            # Load RGB copies and close the files before the temp dir is removed.
            with Image.open(gt_png) as im:
                gt = im.convert("RGB")
            with Image.open(pr_png) as im:
                pr = im.convert("RGB")

            save_debug = getattr(args, "save_debug", False)
            save_overlay = getattr(args, "save_overlay_masks", False)
            if (save_debug or save_overlay) and out_dir:
                os.makedirs(out_dir, exist_ok=True)

            if args.force_flip_pred:
                pr = pr.transpose(Image.FLIP_TOP_BOTTOM)
                if args.verbose:
                    print("[INFO] force_flip_pred: flipped PRED vertically", file=sys.stderr)

            # Cropping/padding logic
            if args.no_crop:
                if gt.size != pr.size:
                    # Pad the smaller render (white, anchored top-left).
                    W = max(gt.size[0], pr.size[0])
                    H = max(gt.size[1], pr.size[1])

                    def pad_to(im):
                        bg = Image.new("RGB", (W, H), (255, 255, 255))
                        bg.paste(im, (0, 0))
                        return bg

                    gt, pr = pad_to(gt), pad_to(pr)
                    if args.verbose:
                        print(f"[INFO] no_crop: padded to common size {gt.size}", file=sys.stderr)
            else:
                target = (args.fixed_canvas, args.fixed_canvas) if getattr(args, "fixed_canvas", None) else None
                gt, pr, bbox_a, bbox_b, tgt = crop_axes_and_match_size(gt, pr, target_size=target, return_meta=True)
                if args.verbose:
                    print(f"[INFO] crop_axes: gt_bbox={bbox_a}, pred_bbox={bbox_b}, target_size={tgt}", file=sys.stderr)
                if save_debug:
                    gt.save(os.path.join(out_dir, "_gt_cropped.png"))
                    pr.save(os.path.join(out_dir, "_pred_cropped.png"))

            gt_mask = binarize_any(gt, already_cropped=True)
            pred_mask = binarize_any(pr, already_cropped=True)
            # Optional dilation tolerance
            gt_mask_d = maybe_dilate(gt_mask, args.dilate)
            pred_mask_d = maybe_dilate(pred_mask, args.dilate)
            iou, inter, union = compute_iou(gt_mask_d, pred_mask_d)

            tried_flip = False
            if args.try_flip_pred and iou == 0.0:
                pr_flip = pr.transpose(Image.FLIP_TOP_BOTTOM)
                pred_mask_flip = binarize_any(pr_flip, already_cropped=True)
                pred_mask_flip_d = maybe_dilate(pred_mask_flip, args.dilate)
                iou_flip, inter_f, union_f = compute_iou(gt_mask_d, pred_mask_flip_d)
                tried_flip = True
                if iou_flip > iou:
                    iou, inter, union = iou_flip, inter_f, union_f
                    # Keep the debug output consistent with the scored PRED.
                    pred_mask = pred_mask_flip
                    if args.verbose:
                        print("[INFO] try_flip_pred: flipped version improved IoU", file=sys.stderr)
                elif args.verbose:
                    print("[INFO] try_flip_pred: flip did not improve IoU", file=sys.stderr)

            print(f"[IoU] {iou:.4f}  (intersection={inter} px, union={union} px)")
            if tried_flip and getattr(args, "verbose", False):
                print("[INFO] (reported IoU already reflects the better of normal vs flipped PRED)", file=sys.stderr)

            def _mask_vis(gt_m, pred_m):
                # GT in gray, Pred in green drawn on top.
                vis = np.zeros((*gt_m.shape, 3), dtype=np.uint8)
                vis[gt_m] = (180, 180, 180)
                vis[pred_m] = (80, 170, 90)
                return Image.fromarray(vis)

            if save_debug:
                Image.fromarray((gt_mask * 255).astype(np.uint8)).save(os.path.join(out_dir, "_gt_mask.png"))
                Image.fromarray((pred_mask * 255).astype(np.uint8)).save(os.path.join(out_dir, "_pred_mask.png"))
                _mask_vis(gt_mask, pred_mask).save(os.path.join(out_dir, "_mix_vis.png"))
            if save_debug or save_overlay:
                _mask_vis(gt_mask, pred_mask).save(os.path.join(out_dir, "_overlay_masks.png"))

            return (iou, inter, union)
    finally:
        BIN_MODE = prev_bin_mode

def _build_arg_parser():
    """Build the command-line parser covering every IoU mode."""
    ap = argparse.ArgumentParser(
        description="Compute IoU in three ways: (A) from two single renders (gt_render.png & pred_render.png), (B) from an overlay image (gray vs green), or (G) geometrically by rendering two JSONs with geometry.py."
    )
    ap.add_argument("--overlay", help="Path to overlay image (PNG).")
    ap.add_argument("--gt_img",  help="Path to GT-only render image (PNG).")
    ap.add_argument("--pred_img", help="Path to Pred-only render image (PNG).")
    ap.add_argument("--save_debug", action="store_true", help="Save debug masks next to the image.")
    ap.add_argument("--no_crop", action="store_true", help="Do NOT auto-crop to axes; assume the two renders are already on the same canvas.")
    ap.add_argument("--verbose", action="store_true", help="Print debug info (bbox, sizes, mask pixels).")
    ap.add_argument("--force_flip_pred", action="store_true",
                    help="Force a vertical flip (top-bottom) on PRED image before IoU.")
    ap.add_argument("--try_flip_pred", action="store_true",
                    help="If IoU==0, automatically try vertical flip on PRED and report the better IoU.")
    ap.add_argument("--dilate", type=int, default=0,
                    help="Dilate both masks by N pixels before IoU (tolerance for 1-2px misalign).")
    ap.add_argument("--save_overlay_masks", action="store_true",
                    help="Additionally save a mask overlay visualization next to GT image.")
    ap.add_argument("--gt_json", help="Path to GT JSON (geometry mode).")
    ap.add_argument("--pred_json", help="Path to Pred JSON (geometry mode).")
    ap.add_argument("--geom_script", default=os.path.join(os.path.dirname(__file__), "geometry.py"),
                    help="Path to geometry.py to render JSON to PNG for geometric IoU.")
    ap.add_argument("--flip_y_geom", action="store_true",
                    help="Pass --flip_y to geometry.py when rendering both JSONs.")
    ap.add_argument("--map_to_axes_geom", action="store_true",
                    help="Pass --map_to_axes to geometry.py when rendering both JSONs.")
    ap.add_argument("--gt_anchor", choices=["corner","centroid"], default=None,
                    help="Anchor to use for GT JSON when rendering via geometry.py.")
    ap.add_argument("--pred_anchor", choices=["corner","centroid"], default=None,
                    help="Anchor to use for Pred JSON when rendering via geometry.py.")
    ap.add_argument("--size_mode_geom", choices=["keep","fold","auto"], default=None,
                    help="Size policy to pass to geometry.py (--size_into_scale).")
    ap.add_argument("--geom_easy_threshold", action="store_true",
                    help="Use an easy, permissive binarization for geometry renders (treat any non-white/non-black as foreground; no opening).")
    ap.add_argument("--keep_tmp_geom", action="store_true",
                    help="Keep the temporary geometry render PNGs for inspection.")
    ap.add_argument("--save_geom_pngs_dir", default=None,
                    help="If set, copy the intermediate GT/PRED geometry renders to this directory.")
    ap.add_argument("--fixed_canvas", type=int, default=None,
                    help="If set (e.g., 256), after cropping to the inner axes both GT and Pred renders are resized to NxN before mask/IoU.")
    ap.add_argument("--batch_my_run", default=None, help="Path to directory containing subfolders, each with a pred.json.")
    ap.add_argument("--gt_dir", default=None, help="Path to directory containing GT jsons (named <subfolder>.json).")
    ap.add_argument("--write_iou_txt", action="store_true", help="If set, write iou.txt in each subfolder with IoU details.")
    ap.add_argument("--summary_csv", default=None, help="If set, write CSV summary of results.")
    return ap


def _run_batch(args):
    """Batch mode: score every <subfolder>/pred.json against <gt_dir>/<subfolder>.json.

    Items with missing inputs or a failed render are skipped with a warning so
    one bad item does not abort the whole batch. The mean IoU is printed and
    optionally written to --summary_csv.
    """
    import csv

    results = []
    batch_dir = args.batch_my_run
    subdirs = sorted(d for d in os.listdir(batch_dir) if os.path.isdir(os.path.join(batch_dir, d)))
    for name in subdirs:
        folder = os.path.join(batch_dir, name)
        pred_json = os.path.join(folder, "pred.json")
        gt_json = os.path.join(args.gt_dir, f"{name}.json")
        if not (os.path.exists(pred_json) and os.path.exists(gt_json)):
            print(f"[WARN] Skipping {name}: missing pred.json or GT json", file=sys.stderr)
            continue
        # Save PNGs in the item folder unless a directory was given explicitly;
        # always restore the shared args afterwards.
        prev_save_geom_pngs_dir = args.save_geom_pngs_dir
        if args.save_geom_pngs_dir is None:
            args.save_geom_pngs_dir = folder
        try:
            iou, inter, union = compute_iou_from_json_pair(args, gt_json, pred_json, folder)
        except (RuntimeError, OSError) as exc:
            print(f"[WARN] Skipping {name}: {exc}", file=sys.stderr)
            continue
        finally:
            args.save_geom_pngs_dir = prev_save_geom_pngs_dir
        if args.write_iou_txt:
            with open(os.path.join(folder, "iou.txt"), "w") as f:
                f.write(f"IoU={iou:.6f}\n")
                f.write(f"intersection={inter}\n")
                f.write(f"union={union}\n")
        results.append((name, iou, inter, union))

    N = len(results)
    mean_iou = sum(x[1] for x in results) / N if N > 0 else 0.0
    print(f"Batch processed {N} items, mean IoU = {mean_iou:.6f}")
    if args.summary_csv:
        with open(args.summary_csv, "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["name","iou","intersection","union"])
            for name, iou, inter, union in results:
                writer.writerow([name, f"{iou:.6f}", inter, union])
            # Same column count as the header row.
            writer.writerow(["__MEAN__", f"{mean_iou:.6f}", "", ""])


def _run_overlay(args):
    """Overlay mode: split one overlay image into GT/Pred masks and print the IoU."""
    if not os.path.exists(args.overlay):
        raise FileNotFoundError(args.overlay)
    img = Image.open(args.overlay).convert("RGB")
    gt_mask, pred_mask = make_masks(img)
    gt_mask = maybe_dilate(gt_mask, args.dilate)
    pred_mask = maybe_dilate(pred_mask, args.dilate)
    if args.verbose:
        print(f"[INFO] overlay-mode masks: gt_pixels={int(gt_mask.sum())}, pred_pixels={int(pred_mask.sum())}, shape={gt_mask.shape}", file=sys.stderr)
    iou, inter, union = compute_iou(gt_mask, pred_mask)
    print(f"[IoU] {iou:.4f}  (intersection={inter} px, union={union} px)")
    if args.save_debug:
        debug_dir = os.path.join(os.path.dirname(args.overlay), "_overlay_debug")
        os.makedirs(debug_dir, exist_ok=True)
        Image.fromarray((gt_mask*255).astype(np.uint8)).save(os.path.join(debug_dir, "gt_mask.png"))
        Image.fromarray((pred_mask*255).astype(np.uint8)).save(os.path.join(debug_dir, "pred_mask.png"))
        mix = np.zeros((*gt_mask.shape, 3), dtype=np.uint8)
        mix[gt_mask]   = np.array([180,180,180], dtype=np.uint8)
        mix[pred_mask] = np.array([ 80,170, 90], dtype=np.uint8)
        Image.fromarray(mix).save(os.path.join(debug_dir, "mix_vis.png"))
        print(f"[DEBUG] Saved masks to: {debug_dir}")


def main():
    """Parse the CLI and dispatch to batch, geometry (JSON), two-render or overlay mode."""
    args = _build_arg_parser().parse_args()

    # Set global binarization mode based on CLI
    global BIN_MODE
    BIN_MODE = "easy" if getattr(args, "geom_easy_threshold", False) else "strict"

    # Batch: many subfolders, each with a pred.json and a matching GT json.
    if args.batch_my_run and args.gt_dir:
        _run_batch(args)
        return

    # Branch G: geometric IoU via rendering both JSONs with geometry.py.
    if args.gt_json and args.pred_json:
        out_dir = args.save_geom_pngs_dir or os.path.dirname(args.pred_json)
        compute_iou_from_json_pair(args, args.gt_json, args.pred_json, out_dir)
        return

    # Branch A: two single renders.
    if args.gt_img and args.pred_img:
        iou_from_two_renders(args.gt_img, args.pred_img, args)
        return

    # Branch B: overlay image parsing.
    if args.overlay:
        _run_overlay(args)
        return

    raise SystemExit("Please pass --batch_my_run with --gt_dir, both --gt_json and --pred_json, "
                     "both --gt_img and --pred_img, or --overlay.")

# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()