# ICL_qwen_onepiece.py
import os, json, argparse, re, textwrap, math, datetime
import subprocess
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageOps
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info

# plotting (optional)
try:
    import matplotlib.pyplot as plt
    _HAS_MPL = True
except Exception:
    _HAS_MPL = False

# optional external geometry backend (to ensure JSON→PNG rendering matches your earlier pipeline)
try:
    import geometry as GEOM
    _HAS_GEOM = True
except Exception:
    GEOM = None
    _HAS_GEOM = False

# --- Geometry backend capabilities debug helper ---
def _geom_caps():
    """Map every attribute name of the optional geometry backend to True.

    Returns {} when the backend could not be imported. As a debugging aid it
    also tries to record a sorted, comma-separated list of the public
    (non-underscore) names on ``GEOM.__CAPS_STR__``; any failure there is
    silently ignored.
    """
    if not _HAS_GEOM:
        return {}
    capabilities = dict.fromkeys(dir(GEOM), True)
    try:
        public_names = sorted(name for name in capabilities if not name.startswith('_'))
        GEOM.__CAPS_STR__ = ", ".join(public_names)
    except Exception:
        pass
    return capabilities
_GEOM_CAPS = _geom_caps()

# =========================
# PROMPTS
# =========================
# Prompt text for full-piece extraction: asks the model for a single JSON object
# carrying all six schema fields (type, size, pos, angle, flip, scale).
INSTR = textwrap.dedent("""
You are now a tangram piece extractor.
Task: From the given 10x10 grid plot image that contains exactly ONE piece, output STRICT JSON for that single piece only. The JSON MUST contain exactly these fields:
  type: one of [triangle, square, parallelogram]
  size: one of [large, medium, small, na]
  pos: [x,y] two floats (0..10 canvas grid, **pos is the polygon centroid / center-of-mass**)
  angle: float, MUST be chosen from {-135,-90,-45,0,45,90,135,180}
  flip: boolean
  scale: float
Return PURE JSON only, no extra text, no Markdown.
Rules:
- Coordinates follow the 0..10 axes in the image.
- The pos you output is the centroid of the shape.
- Use non-zero angles when needed to match slanted edges.
- Output exactly ONE JSON object (not a list).
- Do NOT use names like "small_triangle" or "big_triangle"; always use type in {triangle,square,parallelogram} and put size into the `size` field.
- Your optimization goal is to MAXIMIZE IoU with the black silhouette in the image.
- If uncertain, SNAP coordinates to .0 or .5 grid steps; use the allowed angle set exactly.
""").strip()

# Prompt text for position-only prediction (presumably used with --pos_only):
# the piece meta info is provided to the model, which must answer {"pos": [x, y]}.
INSTR_POS_ONLY = textwrap.dedent("""
You are now a tangram piece locator.
Task: From the given 10x10 grid plot image that contains exactly ONE piece, output STRICT JSON for the piece POSITION only.
We will provide the piece meta info (type/size/angle, and possibly flip/scale). You MUST return only the position in this schema:
  {"pos": [x, y]}
Rules:
- Coordinates follow the 0..10 axes and **pos is the polygon centroid / center-of-mass**.
- Return PURE JSON only, no extra text, no Markdown.
- Output exactly one field: `pos` with two floats.
- If uncertain, SNAP coordinates to .0 or .5 grid steps.
""").strip()

# Example JSON lines illustrating the strict schema.
# NOTE(review): where this is inserted into the prompt is not visible in this chunk.
CHEATSHEET = """
Example JSONs (strict schema):
{"type":"triangle","size":"small","pos":[5.2,3.8],"angle":45,"flip":false,"scale":1.0}
{"type":"square","size":"na","pos":[5.0,2.5],"angle":45,"flip":false,"scale":1.0}
{"type":"parallelogram","size":"na","pos":[6.5,5.0],"angle":0,"flip":false,"scale":1.0}
"""

# ---------- CLI ----------
def parse_args(argv=None):
    """Build the CLI parser and parse the command line.

    Args:
        argv: optional list of argument strings. None (the default) means
            ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace with all options below.

    Note: ``--use_geometry_backend`` defaults to True, so on its own it could
    never be switched off. ``--no_geometry_backend`` writes False to the same
    destination to make the backend disableable.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--onepiece_img_dir", required=True, help="目录：onepiece_images（PNG）")
    ap.add_argument("--onepiece_json_dir", required=True, help="目录：onepiece_from_svg（与 PNG 同名的 JSON）")
    ap.add_argument("--k_shots", type=int, default=10)
    ap.add_argument("--k_tests", type=int, default=5)
    ap.add_argument("--all_tests", action="store_true",
                    help="使用所有剩余样本作为测试集（在选出 k_shots 作为few-shot后，余下全部做测试）")
    ap.add_argument("--no_shuffle", action="store_true",
                    help="不打乱顺序；按文件名的字典序依次选取样本")
    ap.add_argument("--seed", type=int, default=2025)
    ap.add_argument("--teach", action="store_true")
    ap.add_argument("--max_new", type=int, default=512)
    ap.add_argument("--save_dir", default="runs/icl_onepiece", help="输出结果目录")
    ap.add_argument("--iterative", action="store_true", help="逐题公布答案+误差，自纠错评测")
    ap.add_argument("--iters", type=int, default=None, help="测试轮数。默认等于 --k_tests")
    ap.add_argument("--refine", action="store_true", help="启用几何局部搜索，对模型预测做离散网格/角度微调以提升IoU")
    ap.add_argument("--pos_only", action="store_true", help="仅预测位置：给模型提供 type/size/angle（可含flip/scale），模型只输出 pos")
    ap.add_argument("--map_to_bbox", action="store_true",
                    help="Map 0..10 grid to the PNG foreground bbox (default: off; use full-image 0..10 grid)")
    ap.add_argument("--map_to_axes", action="store_true",
                    help="Map 0..10 grid to the detected inner plotting axes (grid square) rather than full image or foreground bbox.")
    ap.add_argument("--anchor", type=str, default="centroid",
                    choices=["corner","centroid"],
                    help="Interpretation of pos: 'centroid' aligns polygon centroid; 'corner' aligns template origin (default: centroid).")
    ap.add_argument("--auto_align", action="store_true",
                    help="Per-sample alignment search over {map_to_bbox in [False,True]} × {anchor in ['corner','centroid']} to maximize upper-bound IoU (GT-JSON vs PNG).")
    ap.add_argument("--use_geometry_backend", action="store_true", default=True,
                    help="Use external geometry backend for JSON→mask rendering if available (module `geometry`). (default: on)")
    # store_true with default=True can never yield False; this is the off switch.
    ap.add_argument("--no_geometry_backend", dest="use_geometry_backend", action="store_false",
                    help="Disable the external geometry backend. Only effective together with the "
                         "environment variable YESTERDAY_OFF=1, which lifts the hard-lock that forces it on.")
    ap.add_argument("--diagnose", action="store_true",
                    help="输出额外诊断：分离模型预测/坐标映射/识别误差")
    ap.add_argument("--silhouette_mode", type=str, default="auto",
                    choices=["auto", "bw", "dark", "light"],
                    help="轮廓提取方式: 'auto'（默认，彩色更稳），'bw'（黑白二值图），'dark'（灰度<阈值），'light'（灰度>阈值）")
    ap.add_argument("--size_into_scale", type=str, default="keep", choices=["keep","fold","auto"],
                    help="How to handle triangle size vs scale when rendering. 'keep': use JSON scale as-is; 'fold': fold size into scale; 'auto': choose per-sample by maximizing upper-bound IoU against PNG silhouette.")
    ap.add_argument("--iou_mode", type=str, default="json",
                    choices=["json", "png"],
                    help="如何计算 IoU：'json' 用 GT-JSON 与 Pred-JSON 在同一评测框架栅格化后比较（与几何叠加图一致）；'png' 用 Pred-JSON 与 PNG 轮廓比较（旧逻辑）。")
    ap.add_argument("--prealign_gt", action="store_true",
                    help="Pre-transform each GT JSON's pos into the user-selected (map_to_bbox, anchor) frame while keeping size/angle/flip/scale unchanged; use this both for few-shot examples and evaluation.")
    ap.add_argument("--sanity_png", type=str, default=None,
                    help="单对自检：指定一张 PNG 路径，与 --sanity_json 搭配使用")
    ap.add_argument("--sanity_json", type=str, default=None,
                    help="单对自检：指定与该 PNG 对应的 JSON 路径（原始 from SVG 的 JSON）")
    ap.add_argument("--calibrate", action="store_true",
                    help="Per-sample calibration against PNG silhouette: choose size policy (keep/fold) and triangle template variant to maximize upper-bound IoU in the chosen evaluation frame.")
    ap.add_argument("--tri_variants", type=int, default=2, choices=[1,2],
                    help="How many base triangle template variants to consider during calibration (1: right angle at origin; 2: also try unit square's top-right).")
    ap.add_argument("--verbose_calib", action="store_true",
                    help="Print detailed calibration choices per sample.")
    return ap.parse_args(argv)
# --- Debug panel helper ---
def save_debug_panel_png(rgb, gt_mask, pred_mask, out_path):
    """Save a side-by-side debug panel: [GT mask | Pred mask | overlay].

    The overlay is drawn on a copy of ``rgb`` using these colors:
      - gray: both masks are set;
      - black: only Pred is set (extra);
      - white: only GT is set (missing).

    Masks may be bool or numeric; any non-zero value counts as foreground.
    They are normalized to bool first, for two reasons:
      - ``~`` is then a logical NOT. On uint8 it is a bitwise NOT, which maps
        0/1 to 255/254, both truthy.
      - ``*255`` can no longer overflow. For 0/255 uint8 masks it would wrap
        255*255 to 1.

    A 2-D (grayscale) ``rgb`` is replicated to 3 channels, and any channel
    beyond the third (e.g. alpha) is dropped.

    Best-effort: any failure (shape mismatch, PIL unavailable, I/O error) is
    printed as a warning instead of raised.
    """
    try:
        import numpy as _np
        from PIL import Image as _Image
        gt = _np.asarray(gt_mask).astype(bool)
        pr = _np.asarray(pred_mask).astype(bool)
        base = _np.asarray(rgb)
        if base.ndim == 2:
            base = _np.dstack([base] * 3)
        ov = base[..., :3].astype(_np.uint8)  # astype copies, so `rgb` is never modified
        gt_rgb = _np.dstack([gt.astype(_np.uint8) * 255] * 3)
        pr_rgb = _np.dstack([pr.astype(_np.uint8) * 255] * 3)
        ov[pr & ~gt] = [0, 0, 0]
        ov[gt & ~pr] = [255, 255, 255]
        ov[pr & gt] = [128, 128, 128]
        panel = _np.concatenate([gt_rgb, pr_rgb, ov], axis=1)
        _Image.fromarray(panel).save(out_path)
    except Exception as _e:
        print(f"[DIAG] WARN: save_debug_panel_png failed: {_e}")


# --- External geometry renderer helper (CLI-style, mirrors your terminal commands) ---

def run_geometry_render(gt_json: str, pred_json: str, save_dir: str):
    """Render GT, Pred and an overlay PNG by running ``geometry.py`` as a subprocess.

    This mirrors the manual terminal commands used earlier. ``geometry.py`` is
    looked up next to this script; if it is missing, nothing is rendered.
    Three files are written into ``save_dir``:
      - ``gt_render.png``;
      - ``pred_render.png``;
      - ``overlay_json_vs_pred.png`` (GT gray + Pred green).

    The child is started with the current interpreter (``sys.executable``), so
    it sees the same environment/venv. A bare ``"python"`` could resolve to a
    different interpreter.

    Best-effort: a non-zero exit status is reported but does not stop the
    remaining renders, and exceptions are printed rather than raised.
    """
    import sys
    try:
        script_dir = os.path.dirname(os.path.abspath(__file__))
        geo_py = os.path.join(script_dir, "geometry.py")
        if not os.path.exists(geo_py):
            print(f"[GEOM] geometry.py not found at {geo_py}; skip external rendering.")
            return
        python_exe = sys.executable or "python"
        jobs = [
            # 1) GT only
            ("gt_render.png", ["--ann", gt_json]),
            # 2) Pred only
            ("pred_render.png", ["--ann", pred_json]),
            # 3) Overlay GT (gray) + Pred (green)
            ("overlay_json_vs_pred.png",
             ["--overlay_gt_json", gt_json, "--overlay_pred_json", pred_json]),
        ]
        saved, failed = [], []
        for fname, extra_args in jobs:
            out_png = os.path.join(save_dir, fname)
            proc = subprocess.run(
                [python_exe, geo_py, *extra_args, "--save_png", out_png, "--no_show"],
                check=False,
            )
            (saved if proc.returncode == 0 else failed).append(fname)
        if saved:
            print(f"[GEOM] saved: {', '.join(saved)}")
        if failed:
            print(f"[GEOM] WARN: geometry.py exited with an error for: {', '.join(failed)}")
    except Exception as e:
        print(f"[GEOM] external geometry render failed: {e}")


ARGS = parse_args()
# --- Hard-lock "yesterday final" style unless explicitly disabled ---
# Unless the environment variable YESTERDAY_OFF=1 is set, the settings below are
# forced regardless of the command line, so outputs match the preferred
# overlay & IoU definition:
#   • Frame: 0..10 axes square (map_to_axes=True)
#   • Anchor: centroid (pos is center-of-mass)
#   • IoU: JSON vs JSON in the same frame
#   • External geometry backend enabled for pixel overlays
if os.environ.get("YESTERDAY_OFF", "0") != "1":
    ARGS.map_to_axes = True
    ARGS.anchor = "centroid"
    ARGS.iou_mode = "json"
    ARGS.use_geometry_backend = True
    print("[INFO] Hard-lock active (set YESTERDAY_OFF=1 to disable): "
          "map_to_axes=True, anchor=centroid, iou_mode=json, use_geometry_backend=True.")
# Report the settings actually in effect (after the hard-lock), not the raw CLI
# values: the lock above may have overridden --anchor / the frame choice.
print("[INFO] POS-ONLY mode:", "ON" if ARGS.pos_only else "OFF",
      f" — effective frame={'axes' if ARGS.map_to_axes else ('bbox' if ARGS.map_to_bbox else 'full')}, "
      f"anchor={ARGS.anchor}.")

# --- Hard locks in POS-ONLY mode ---
if ARGS.pos_only:
    if getattr(ARGS, "auto_align", False):
        print("[INFO] POS-ONLY: ignoring --auto_align (disabled by design).")
        ARGS.auto_align = False
    # Do NOT force size policy here; allow user to choose keep/fold/auto,
    # or use the --calibrate switch to decide per-sample.

# --- Clarify frame precedence ---
if getattr(ARGS, 'map_to_axes', False) and getattr(ARGS, 'map_to_bbox', False):
    print("[INFO] --map_to_axes is set; it overrides --map_to_bbox. Using axes frame for 0..10 mapping.")

# Output directory: a timestamped subfolder created under save_dir
ROOT_SAVE_DIR = os.path.abspath(ARGS.save_dir)
RUN_STAMP = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_DIR = os.path.join(ROOT_SAVE_DIR, RUN_STAMP)
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"[OUT] results will be saved to: {SAVE_DIR}")
# --- Overlay plot with a coordinate grid (matches the user's expected style) ---
def _poly_in_grid_coords(piece: dict, anchor: str):
    """Return the piece's polygon vertices in 0..10 grid units.

    Uses the same templates and ``_apply_transform`` as the pixel renderer, but
    with unit scale, no offset and no y-inversion. The result can therefore be
    plotted directly on a 0..10 axis.

    Missing fields fall back to these defaults:
      - pos = [5, 5]
      - angle = 0
      - flip = False
      - scale = 1
    """
    kind = _canon_type(piece.get("type", "triangle"))
    if kind == "triangle":
        template = _poly_triangle(piece.get("size", "na"))
    else:
        template = {"square": _poly_square}.get(kind, _poly_parallelogram)()
    return _apply_transform(
        template,
        piece.get("pos", [5.0, 5.0]),
        piece.get("angle", 0.0),
        piece.get("flip", False),
        piece.get("scale", 1.0),
        px=1.0,
        offx=0.0,
        offy=0.0,
        anchor=anchor,
    )

def save_overlay_axes_with_grid(pred_piece: dict, gt_piece: dict, out_path: str, title: str = "",
                                anchor: str = None):
    """Plot Pred and GT pieces in the 0..10 grid frame and save the figure.

    GT is drawn first, with a light gray fill and dark-gray edge. Pred is drawn
    on top, with a green fill and black edge. Both sit on a 0..10 x 0..10 axis
    with dotted grid lines.

    Args:
        pred_piece: predicted piece dict (type/size/pos/angle/flip/scale).
        gt_piece: ground-truth piece dict, same fields as ``pred_piece``.
        out_path: destination image path.
        title: optional axes title.
        anchor: pos interpretation passed to ``_poly_in_grid_coords``. None
            (default) uses ``ARGS.anchor``, or "corner" if ARGS lacks it.

    The figure is always closed, even when polygon construction or saving
    raises, so repeated calls cannot accumulate open matplotlib figures.
    """
    import matplotlib.pyplot as _plt
    from matplotlib.patches import Polygon as _Poly

    anch = anchor if anchor is not None else getattr(ARGS, 'anchor', 'corner')
    fig = _plt.figure(figsize=(6, 6), dpi=160)
    try:
        ax = fig.add_subplot(111)
        ax.set_xlim(0, 10)
        ax.set_ylim(0, 10)
        ax.set_aspect('equal', adjustable='box')
        ax.set_xticks(range(0, 11))
        ax.set_yticks(range(0, 11))
        ax.grid(True, linestyle=':', linewidth=0.8, alpha=0.6)
        if title:
            ax.set_title(title)

        # GT first (underneath)
        gt_poly = _Poly(_poly_in_grid_coords(gt_piece, anch),
                        closed=True, facecolor=(0.82, 0.82, 0.82, 1.0),
                        edgecolor=(0.25, 0.25, 0.25, 1.0), linewidth=2.0)
        ax.add_patch(gt_poly)

        # Pred on top
        pr_poly = _Poly(_poly_in_grid_coords(pred_piece, anch),
                        closed=True, facecolor=(0.32, 0.63, 0.39, 1.0),
                        edgecolor=(0, 0, 0, 1.0), linewidth=2.5)
        ax.add_patch(pr_poly)

        ax.set_xlabel("")
        ax.set_ylabel("")
        fig.tight_layout()
        fig.savefig(out_path, dpi=160, bbox_inches="tight")
    finally:
        _plt.close(fig)

# --- Single-pair self-check helper ---
def run_sanity_check(png_path: str, json_path: str):
    """Run an "upper-bound" self-check on a single (PNG, JSON) pair.

    Reports:
      - D-UB_native: best IoU between the GT-JSON rendering and the PNG
        silhouette over {full/bbox} x {corner/centroid}, as chosen by
        ``_pick_alignment_fixed_size`` under the configured size policy.
      - EF-UB: IoU in the *current* evaluation frame. map_to_axes takes
        precedence, then map_to_bbox; the anchor comes from ARGS.
      - An overlay image ``sanity_overlay.png`` in SAVE_DIR. Gray means overlap;
        black means rendered but not in the PNG (extra); white means in the PNG
        but not rendered (missing).

    The GT JSON is loaded once up front. If that fails, the check stops with a
    clear warning instead of cascading into confusing errors. Later stages are
    best-effort: failures are printed as warnings.
    """
    print("\n[SANITY] start ...")
    png_abs = _abspath(png_path); _must_exist(png_abs)
    js_abs  = _abspath(json_path); _must_exist(js_abs)
    print(f"[SANITY] png = {png_abs}")
    print(f"[SANITY] json = {js_abs}")

    rgb = np.array(Image.open(png_abs).convert("RGB"))
    H, W = rgb.shape[:2]
    gt_mask_png = load_silhouette(png_abs)

    # Size policy per current config; "auto" is resolved per-sample elsewhere,
    # so the sanity check falls back to "keep" for it.
    size_mode_used = (ARGS.size_into_scale if getattr(ARGS, "size_into_scale", "keep") in ("keep","fold") else "keep")

    try:
        gt_meta_piece = _load_piece_raw(js_abs)
    except Exception as e:
        print(f"[SANITY][WARN] failed to load GT JSON: {e}")
        print("[SANITY] done.\n")
        return

    # 1) D-UB_native: best of full/bbox x corner/centroid (fixed size policy)
    try:
        use_bbox_nat, anchor_nat, dub_nat, _bbox_nat = _pick_alignment_fixed_size(
            gt_meta_piece, gt_mask_png, W, H, size_mode=size_mode_used
        )
        print(f"[SANITY] D-UB_native = {dub_nat:.4f}  (native=({use_bbox_nat},{anchor_nat}))")
    except Exception as e:
        print(f"[SANITY][WARN] native upper bound failed: {e}")

    # 2) EF-UB: current evaluation frame (map_to_axes first, then map_to_bbox/anchor)
    anchor_used = getattr(ARGS, 'anchor', 'corner')
    ef_mask = None
    try:
        if getattr(ARGS, 'map_to_axes', False):
            bbox_used, frame_str = detect_axes_bbox(rgb), "axes"
        elif bool(ARGS.map_to_bbox):
            bbox_used, frame_str = foreground_bbox(gt_mask_png), "bbox"
        else:
            bbox_used, frame_str = None, "full"
        ef_mask = render_mask(_apply_size_policy(gt_meta_piece, size_mode_used), W, H,
                              bbox=bbox_used, anchor=anchor_used)
        ef_ub = iou_score(ef_mask, gt_mask_png)
        print(f"[SANITY] EF-UB({frame_str},{anchor_used}, size_mode={size_mode_used}) = {ef_ub:.4f}")
    except Exception as e:
        print(f"[SANITY][WARN] fixed-frame upper bound failed: {e}")
        ef_mask = None

    # 3) Overlay of the EF rendering vs the PNG silhouette (reuses the mask from step 2)
    if ef_mask is None:
        print("[SANITY][WARN] overlay skipped: no evaluation-frame rendering available")
    else:
        try:
            # Bool masks so that `~` is a logical NOT (on uint8 it is bitwise).
            pred_b = np.asarray(ef_mask).astype(bool)
            gt_b = np.asarray(gt_mask_png).astype(bool)
            ov = rgb.copy()
            ov[pred_b & ~gt_b] = [0, 0, 0]          # extra (black)
            ov[gt_b & ~pred_b] = [255, 255, 255]    # missing (white)
            ov[pred_b & gt_b] = [128, 128, 128]     # overlap (gray)
            out_overlay = os.path.join(SAVE_DIR, "sanity_overlay.png")
            Image.fromarray(ov).save(out_overlay)
            print(f"[SANITY] overlay saved -> {out_overlay}")
        except Exception as e:
            print(f"[SANITY][WARN] overlay failed: {e}")

    print("[SANITY] done.\n")


# ---------- Path helpers ----------
def _abspath(p):
    """Resolve ``p`` (str or path-like, '~' expanded) to an absolute path.

    Absolute inputs are returned unchanged after '~' expansion. A relative path
    is tried against the current working directory first, then against this
    script's directory, and the first candidate that exists wins. If neither
    exists, the CWD-based candidate is returned so the caller can report it.
    """
    path = os.path.expanduser(p if isinstance(p, str) else str(p))
    if os.path.isabs(path):
        return path
    here = os.path.dirname(os.path.abspath(__file__))
    candidates = [os.path.abspath(os.path.join(base, path)) for base in (os.getcwd(), here)]
    return next((cand for cand in candidates if os.path.exists(cand)), candidates[0])

def _must_exist(p):
    """Return silently if path ``p`` exists, else raise FileNotFoundError with hints.

    The error message also shows where ``p`` would point if interpreted
    relative to this script's directory. That helps when the script was
    launched from an unexpected working directory.

    Raises:
        FileNotFoundError: if ``p`` does not exist.
    """
    if os.path.exists(p):
        return
    script_dir = os.path.dirname(os.path.abspath(__file__))
    try:
        rel = os.path.relpath(p, os.getcwd())
    except ValueError:
        # On Windows relpath raises when p and the CWD are on different drives;
        # fall back to p itself so we still raise the intended FileNotFoundError.
        rel = p
    script_cand = os.path.abspath(os.path.join(script_dir, rel))
    msg = (
        "\n[PATH ERROR] Required path not found:\n"
        f"  → {p}\n"
        "Troubleshooting:\n"
        "  1) Are you running the script from the project root (the folder that contains 'dataset/')?\n"
        "  2) If your data lives elsewhere, pass absolute paths, e.g.:\n"
        "     --onepiece_img_dir /full/path/to/onepiece_images \\\n"
        "     --onepiece_json_dir /full/path/to/onepiece_from_svg\n"
        f"  3) Also tried (script dir): {script_cand}\n"
        "  4) Quick check (bash): ls -al ./dataset | head\n"
    )
    raise FileNotFoundError(msg)

def _list_pairs(img_dir, json_dir):
    """Collect (png_path, json_path) pairs whose file stems match.

    Both directories are resolved with ``_abspath``, printed, and validated with
    ``_must_exist``. A file whose name ends in ".png" (case-insensitive) is kept
    only if ``<json_dir>/<stem>.json`` exists. The pairs are returned sorted.
    """
    roots = [_abspath(img_dir), _abspath(json_dir)]
    img_root, json_root = roots
    print(f"[PATH] images dir = {img_root}")
    print(f"[PATH] json dir   = {json_root}")
    for root in roots:
        _must_exist(root)
    matched = []
    for name in os.listdir(img_root):
        if name.lower().endswith('.png'):
            json_path = os.path.join(json_root, os.path.splitext(name)[0] + '.json')
            if os.path.exists(json_path):
                matched.append((os.path.join(img_root, name), json_path))
    return sorted(matched)

# If the user only wants the single-pair self-check, run it and exit without
# entering the model evaluation pipeline.
if ARGS.sanity_png and ARGS.sanity_json:
    run_sanity_check(ARGS.sanity_png, ARGS.sanity_json)
    import sys as _sys
    _sys.exit(0)
elif ARGS.sanity_png or ARGS.sanity_json:
    # Only one half of the pair was given: previously this was silently ignored
    # and the full evaluation ran instead; make that visible.
    print("[SANITY][WARN] --sanity_png and --sanity_json must be given together; "
          "skipping the sanity check and continuing with the normal evaluation.")

# ---------- JSON schema normalize ----------
def _coerce_to_schema(p: dict):
    """Normalize a (possibly model-generated) piece dict into the strict schema.

    Output fields:
      - type: one of {triangle, square, parallelogram}. Names like
        "small_triangle" set the size; unknown types become "triangle".
      - size: one of {large, medium, small, na}. A missing or "na" size becomes
        "small" for triangles and "na" otherwise. Unrecognized size strings pass
        through lower-cased.
      - pos: clamped to [0, 10]². Unparsable or non-finite values become 5.0.
      - angle: snapped to the nearest allowed value *on the circle*, so 270
        becomes -90 and -170 becomes 180. Unparsable or non-finite values
        become 0.
      - flip: bool. Strings are parsed textually ("false"/"0" → False), since
        ``bool("false")`` would be True.
      - scale: float. Unparsable or non-finite values become 1.0.

    Returns None when ``p`` is None.
    """
    if p is None:
        return None
    t = str(p.get("type", "triangle")).strip().lower()
    size = p.get("size", p.get("sz", "na"))
    tri_map = {
        "small_triangle": "small",
        "medium_triangle": "medium",
        "big_triangle": "large",
        "large_triangle": "large",
        "triangle": None,
    }
    if t in tri_map:
        if tri_map[t] is not None:
            size = tri_map[t]
        t = "triangle"
    elif t in ("square", "box", "sq"):
        t = "square"
    elif t in ("parallelogram", "para", "rhombus", "diamond"):
        t = "parallelogram"
    else:
        t = "triangle"

    size = str(size).strip().lower()
    if size in ("big", "large", "l"):
        size = "large"
    elif size in ("medium", "m"):
        size = "medium"
    elif size in ("small", "s"):
        size = "small"
    elif size in ("na", "none", ""):
        size = "small" if t == "triangle" else "na"

    try:
        angle = float(p.get("angle", 0.0))
    except Exception:
        angle = 0.0
    if not math.isfinite(angle):
        angle = 0.0
    allowed = [-135.0, -90.0, -45.0, 0.0, 45.0, 90.0, 135.0, 180.0]
    # Circular distance: angles are equivalent modulo 360, so a plain
    # |a - angle| would snap e.g. 270 to 180 instead of the exact -90.
    # Ties keep the earlier entry of `allowed`, as before.
    angle = min(allowed, key=lambda a: abs((a - angle + 180.0) % 360.0 - 180.0))

    pos = p.get("pos", [5.0, 5.0])
    try:
        x = float(pos[0]); y = float(pos[1])
    except Exception:
        x, y = 5.0, 5.0
    if not math.isfinite(x):
        x = 5.0
    if not math.isfinite(y):
        y = 5.0
    x = max(0.0, min(10.0, x))
    y = max(0.0, min(10.0, y))

    flip_raw = p.get("flip", False)
    if isinstance(flip_raw, str):
        flip = flip_raw.strip().lower() in ("true", "1", "yes", "y", "t")
    else:
        flip = bool(flip_raw)

    try:
        scale = float(p.get("scale", 1.0))
    except Exception:
        scale = 1.0
    if not math.isfinite(scale):
        scale = 1.0

    return {"type": t, "size": size, "pos": [x, y], "angle": angle, "flip": flip, "scale": scale}

def _load_piece_dict(jpath, coerce: bool = True):
    """Load the first piece from a JSON file and normalize its naming.

    The file may hold a single object or a non-empty list, in which case the
    first element is used. Normalization steps:
      - Type/size naming is canonicalized, e.g. "medium_triangle" becomes type
        "triangle" with size "medium".
      - A missing size becomes "small" for triangles and "na" otherwise.
      - pos is clamped to [0, 10]².
      - Unparsable numeric fields fall back to defaults.
      - A string flip ("false"/"true") is parsed textually.

    Args:
        jpath: path to the JSON file.
        coerce: if True, the piece is additionally passed through
            ``_coerce_to_schema`` (angle snapping etc.) for model consumption.
            If False, angle and scale are kept unsnapped and size is not
            folded into scale.

    Raises:
        ValueError: if the file contains an empty list.
    """
    with open(jpath, 'r', encoding='utf-8') as fh:
        obj = json.load(fh)
    if isinstance(obj, list):
        if len(obj) == 0:
            raise ValueError(f"empty list in {jpath}")
        obj = obj[0]

    # normalize type/size naming, but keep raw angle/scale when coerce=False
    raw_type = str(obj.get("type", "triangle")).lower()
    size_in = obj.get("size", obj.get("sz", "na"))

    tri_map = {
        "small_triangle": "small",
        "medium_triangle": "medium",
        "big_triangle": "large",
        "large_triangle": "large",
        "triangle": None,
    }
    if raw_type in tri_map:
        if tri_map[raw_type] is not None:
            size_in = tri_map[raw_type]
        t = "triangle"
    elif raw_type in ("square", "box", "sq"):
        t = "square"
    elif raw_type in ("parallelogram", "para", "rhombus", "diamond"):
        t = "parallelogram"
    else:
        t = "triangle"

    size = str(size_in).lower()  # str(None) -> "none", handled below
    if size in ("big", "large", "l"):
        size = "large"
    elif size in ("medium", "m"):
        size = "medium"
    elif size in ("small", "s"):
        size = "small"
    elif size in ("na", "none", ""):
        size = "small" if t == "triangle" else "na"

    # position (clamped to the 0..10 grid)
    pos = obj.get("pos", [5.0, 5.0])
    try:
        x = float(pos[0]); y = float(pos[1])
    except Exception:
        x, y = 5.0, 5.0
    x = max(0.0, min(10.0, x))
    y = max(0.0, min(10.0, y))

    # orientation/scale (not snapped)
    try:
        ang = float(obj.get("angle", 0.0))
    except Exception:
        ang = 0.0
    flip_raw = obj.get("flip", False)
    if isinstance(flip_raw, str):
        # bool("false") is True; interpret common textual booleans instead.
        flip = flip_raw.strip().lower() in ("true", "1", "yes", "y", "t")
    else:
        flip = bool(flip_raw)
    try:
        scale = float(obj.get("scale", 1.0))
    except Exception:
        scale = 1.0

    piece = {"type": t, "size": size, "pos": [x, y], "angle": ang, "flip": flip, "scale": scale}
    # coerce=True: snap/normalize for model consumption
    # coerce=False: keep angle/scale as parsed (no snapping, no size→scale folding)
    return _coerce_to_schema(piece) if coerce else piece


# --- Load JSON with light normalization only (no snapping, clamping, or size folding) ---

def _load_piece_raw(jpath):
    """Load the first JSON object from a file with minimal normalization.

    What it does:
      - Canonicalizes ``type`` to triangle/square/parallelogram.
      - Extracts size keywords the author embedded in the type string (e.g.
        "medium_triangle"). This happens when the explicit ``size`` is missing
        or a placeholder ("", "na", "none", any case).
      - Normalizes size tokens to small/medium/large/na.
      - Converts pos to floats, falling back to [5, 5] if unparsable.
      - Parses a string ``flip`` textually ("false" → False).

    What it does not do:
      - Angle and scale are converted to float when possible, otherwise left
        as-is; they are not snapped.
      - pos is not clamped.
      - Size is not folded into scale.

    Raises:
        ValueError: if the file contains an empty list.
    """
    with open(jpath, 'r', encoding='utf-8') as fh:
        obj = json.load(fh)
    if isinstance(obj, list):
        if not obj:
            raise ValueError(f"empty list in {jpath}")
        obj = obj[0]

    # raw fields
    t_raw = str(obj.get("type", "triangle"))
    size = obj.get("size", obj.get("sz", "na"))
    pos = obj.get("pos", [5.0, 5.0])
    angle = obj.get("angle", 0.0)
    flip = obj.get("flip", False)
    scale = obj.get("scale", 1.0)

    # If the type embeds size info (e.g., "small_triangle"), extract it.
    tl = t_raw.strip().lower()
    if "triang" in tl:
        t = "triangle"
        if any(k in tl for k in ["small_", "_small"]):
            size_inferred = "small"
        elif any(k in tl for k in ["medium_", "_medium"]):
            size_inferred = "medium"
        elif any(k in tl for k in ["big_", "large_", "_big", "_large"]):
            size_inferred = "large"
        else:
            size_inferred = None
        # Prefer an explicit `size`; only a missing/placeholder value (any case,
        # e.g. "NA" or "none") is replaced by the size inferred from the type.
        size_is_placeholder = size is None or str(size).strip().lower() in ("", "na", "none")
        if size_is_placeholder and size_inferred is not None:
            size = size_inferred
    elif "square" in tl or tl in ("sq", "box"):
        t = "square"
    elif any(k in tl for k in ["parallelogram", "parallel", "para", "rhomb", "diamond"]):
        t = "parallelogram"
    else:
        t = "triangle"

    # coerce basic numeric types only (no snapping)
    try:
        pos = [float(pos[0]), float(pos[1])]
    except Exception:
        pos = [5.0, 5.0]
    try:
        angle = float(angle)
    except Exception:
        pass
    try:
        scale = float(scale)
    except Exception:
        pass

    # bool("false") is True; interpret common textual booleans instead.
    if isinstance(flip, str):
        flip = flip.strip().lower() in ("true", "1", "yes", "y", "t")

    # normalize size tokens to {small, medium, large, na}
    sl = (str(size).strip().lower() if size is not None else "na")
    if sl in ("l", "large", "big"):
        size = "large"
    elif sl in ("m", "medium"):
        size = "medium"
    elif sl in ("s", "small"):
        size = "small"
    elif sl in ("", "none", "na"):
        size = "na"
    else:
        size = sl

    return {"type": t, "size": size, "pos": pos, "angle": angle, "flip": bool(flip), "scale": scale}

# --- Canonicalize type without touching size/angle/scale ---
def _canon_type(t: str) -> str:
    """Map a free-form type name to "triangle", "square" or "parallelogram".

    Matching is case- and whitespace-insensitive and checked in this order:
      1. anything containing "triang" → triangle;
      2. "sq", "box" or anything containing "square" → square;
      3. anything containing parallelogram/parallel/para/rhomb/diamond
         → parallelogram.
    Non-strings and unmatched names fall back to "triangle".
    """
    if not isinstance(t, str):
        return "triangle"
    key = t.strip().lower()
    para_tokens = ("parallelogram", "parallel", "para", "rhomb", "diamond")
    rules = (
        ("triangle", lambda s: "triang" in s),
        ("square", lambda s: s in ("sq", "box") or "square" in s),
        ("parallelogram", lambda s: any(tok in s for tok in para_tokens)),
    )
    return next((name for name, matches in rules if matches(key)), "triangle")



# ---------- Geometry templates ----------
# --- Triangle base variants (grid units) ---
# v0: [(0,0),(1,0),(0,1)] — right angle at the origin (0,0).
# v1: [(0,0),(1,0),(1,1)] — right angle at (1,0); a mirrored base that some
#     exporters effectively use (flip/angle can emulate it).
_TRI_VARIANTS = [
    [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)],
    [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)],
]
TRI_VARIANT_IDX = 0  # default: first variant; reassigned by _calibrate_eval_shapes

def _poly_triangle(size):
    """Return the unit right-triangle template selected by TRI_VARIANT_IDX.

    ``size`` is accepted for signature symmetry but ignored: triangle size is
    expressed via ``scale`` (see _normalize_size_into_scale). An out-of-range
    TRI_VARIANT_IDX falls back to variant 0. A fresh list is returned so
    callers cannot accidentally mutate the shared template table.
    """
    idx = TRI_VARIANT_IDX if 0 <= TRI_VARIANT_IDX < len(_TRI_VARIANTS) else 0
    return list(_TRI_VARIANTS[idx])

def _poly_square():
    """Return the unit-square template (side 1, corner at the origin) in grid units."""
    corners = ((0, 0), (1, 0), (1, 1), (0, 1))
    return [(float(cx), float(cy)) for cx, cy in corners]

def _poly_parallelogram():
    """Return the parallelogram template in grid units.

    Unit-length base on y=0; the top edge is the base shifted by (+0.5, +0.5),
    listed in reverse so the four vertices go around the outline.
    """
    bottom = [(0.0, 0.0), (1.0, 0.0)]
    dx, dy = 0.5, 0.5
    top = [(bx + dx, by + dy) for bx, by in reversed(bottom)]
    return bottom + top

def _normalize_size_into_scale(piece: dict):
    """Fold the triangle ``size`` multiplier into ``scale`` and mark size as "na".

    This avoids double-scaling when the dataset JSON already encodes geometric
    size in ``scale``.

    Triangles: scale is multiplied by
      - large → 1
      - medium → √2/2
      - small → 0.5
      - anything else → 1
    Squares and parallelograms use unit templates sized purely by scale, so
    only their size is reset to "na".

    Non-dict inputs are returned unchanged. The input dict is never mutated;
    a shallow copy is returned.
    """
    if not isinstance(piece, dict):
        return piece
    out = dict(piece)
    kind = str(out.get("type", "triangle")).lower()
    size_key = str(out.get("size", "na")).lower()
    current_scale = float(out.get("scale", 1.0))
    if kind == "triangle":
        factors = {"large": 1.0, "medium": math.sqrt(2) / 2.0, "small": 0.5}
        out["scale"] = float(current_scale * factors.get(size_key, 1.0))
        out["size"] = "na"
    elif kind in ("square", "parallelogram"):
        out["size"] = "na"
    return out

# --- Size policy helper ---
def _apply_size_policy(piece: dict, mode: str):
    """Apply a size policy and return a new piece dict.

    "fold" delegates to _normalize_size_into_scale. Any other mode (normally
    "keep"; "auto" is resolved per-sample by the caller) returns a shallow
    copy unchanged.
    """
    return _normalize_size_into_scale(piece) if mode == "fold" else dict(piece)

# --- Calibration for evaluation shapes ---
def _calibrate_eval_shapes(gt_piece_raw: dict, gt_mask_png: np.ndarray, W: int, H: int,
                           bbox_used, anchor_used: str, try_variants: int = 2,
                           respect_user_size_mode: str = None):
    """Pick the triangle template variant and size policy that best match the PNG.

    Search procedure:
      - Each (variant, size-policy) pair is rendered in the current evaluation
        frame (``bbox_used``, ``anchor_used``).
      - Each rendering is scored by IoU against ``gt_mask_png``. A rendering
        or scoring error counts as IoU 0.0.
      - Variants are visited in index order; within each variant, policies
        are visited in list order.
      - A later candidate replaces the incumbent only if it is better by more
        than 1e-6.

    Args:
        gt_piece_raw: GT piece as loaded from JSON, without size folding.
        try_variants: how many leading templates of _TRI_VARIANTS to try,
            clamped to [1, len(_TRI_VARIANTS)].
        respect_user_size_mode: "keep" or "fold" restricts the search to that
            policy; None or "auto" tries ["keep", "fold"].

    Returns:
        (tri_idx, size_mode, best_iou).

    Side effect: the global TRI_VARIANT_IDX is left set to the winning tri_idx.
    """
    global TRI_VARIANT_IDX
    if respect_user_size_mode in (None, "auto"):
        policies = ["keep", "fold"]
    else:
        policies = [respect_user_size_mode]
    n_variants = max(1, min(int(try_variants), len(_TRI_VARIANTS)))

    best_iou, best_idx, best_mode = -1.0, 0, "keep"
    candidates = [(idx, mode) for idx in range(n_variants) for mode in policies]
    for idx, mode in candidates:
        TRI_VARIANT_IDX = idx
        try:
            rendered = render_mask(_apply_size_policy(gt_piece_raw, mode), W, H,
                                   bbox=bbox_used, anchor=anchor_used)
            score = iou_score(rendered, gt_mask_png)
        except Exception:
            score = 0.0
        if score > best_iou + 1e-6:
            best_iou, best_idx, best_mode = score, idx, mode
    TRI_VARIANT_IDX = best_idx
    return best_idx, best_mode, best_iou

def _apply_transform(poly, pos, angle_deg, flip, scale, px, offx=0.0, offy=0.0, anchor="corner", invert_y=False):
    """Place a template polygon and map it from grid units to output coordinates.

    Steps per vertex, in order:
      1. Shift by the anchor. "centroid" uses the mean of the vertices; any
         other value leaves the template origin in place.
      2. Mirror x if ``flip``.
      3. Multiply by ``scale``.
      4. Rotate by ``angle_deg`` (standard x→y rotation).
      5. Translate by ``pos``.
      6. If ``invert_y``, replace y with 10 - y. This turns a y-up 0..10 grid
         into pixel y-down.
      7. Map with ``v * px + off``.

    Returns a list of (X, Y) tuples, one per input vertex.
    """
    theta = math.radians(float(angle_deg))
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    if anchor == "centroid" and len(poly) > 0:
        count = float(len(poly))
        ox = sum(v[0] for v in poly) / count
        oy = sum(v[1] for v in poly) / count
    else:
        ox = oy = 0.0

    def _map_vertex(vx, vy):
        lx = vx - ox
        ly = vy - oy
        if flip:
            lx = -lx
        k = float(scale)
        lx, ly = lx * k, ly * k
        gx = (lx * cos_t - ly * sin_t) + float(pos[0])
        gy = (lx * sin_t + ly * cos_t) + float(pos[1])
        if invert_y:
            gy = 10.0 - gy
        return (gx * px + offx, gy * px + offy)

    return [_map_vertex(vx, vy) for (vx, vy) in poly]

def _pieces_to_polys_in_pixels(pieces, w, h, bbox=None, anchor=None):
    """Convert pieces to pixel-space polygons (one vertex list per piece).

    Frame mapping:
      - With a ``bbox`` (x0, y0, x1, y1): one grid unit equals the bbox's
        shorter integer side / 10, with the origin offset to (x0, y0).
      - Without a bbox: one grid unit equals min(w, h) / 10, with no offset.
    The grid is treated as y-up and flipped to pixel y-down.

    ``anchor`` defaults to ARGS.anchor (or "corner"). Each piece must provide
    "size", "pos", "angle", "flip" and "scale"; "type" defaults to triangle.
    """
    if bbox is None:
        # Use min side for full-frame mapping to reduce over-scaling when axes occupy a sub-rectangle
        unit = min(w, h) / 10.0
        shift_x = shift_y = 0.0
    else:
        x0, y0, x1, y1 = bbox
        unit = min(max(1, int(x1 - x0)), max(1, int(y1 - y0))) / 10.0
        shift_x, shift_y = float(x0), float(y0)
    mode = getattr(ARGS, 'anchor', 'corner') if anchor is None else anchor
    other_templates = {"square": _poly_square}
    result = []
    for piece in pieces:
        kind = _canon_type(piece.get("type", "triangle"))
        if kind == "triangle":
            template = _poly_triangle(piece["size"])
        else:
            template = other_templates.get(kind, _poly_parallelogram)()
        result.append(_apply_transform(
            template, piece["pos"], piece["angle"], piece["flip"], piece["scale"],
            unit, shift_x, shift_y, anchor=mode, invert_y=True
        ))
    return result


def _render_mask_native(pieces, w, h, bbox=None, anchor=None):
    """Rasterize one piece (dict) or a list of pieces with PIL.

    Returns a uint8 numpy array built from a w x h "L" image. Polygon interiors
    are 1 and everything else is 0. Frame and anchor handling is delegated to
    _pieces_to_polys_in_pixels.
    """
    batch = [pieces] if isinstance(pieces, dict) else pieces
    canvas = Image.new("L", (w, h), 0)
    pen = ImageDraw.Draw(canvas)
    outlines = _pieces_to_polys_in_pixels(batch, w, h, bbox=bbox, anchor=anchor)
    for outline in outlines:
        pen.polygon(outline, fill=1)
    return np.array(canvas, dtype=np.uint8)

# Thin wrapper to select backend
def _rasterize_pixel_polys(polys, w, h):
    """
    Rasterize pixel-space polygons into a (h, w) uint8 mask with values {0, 1}.

    Each polygon may be a sequence of (x, y) pairs, a flat [x0, y0, x1, y1, ...]
    sequence, or a NumPy array in either layout; all are normalized to (N, 2)
    floats before drawing (a NumPy (N, 2) array is no longer mistaken for a flat
    list). Empty polygons are skipped.
    """
    canvas = Image.new("L", (w, h), 0)
    painter = ImageDraw.Draw(canvas)
    for poly in polys:
        pts = np.asarray(poly, dtype=np.float64).reshape(-1, 2)
        if pts.shape[0] == 0:
            continue
        painter.polygon([(float(x), float(y)) for x, y in pts], fill=1)
    return np.array(canvas, dtype=np.uint8)


def _geom_polys_as_list(segs):
    """
    Normalize one geometry-backend polygon result into a list of polygons.

    list/tuple -> already a list of polygons; 3-D array -> split along axis 0;
    None -> no polygons; anything else (e.g. one (N, 2) array) -> a single polygon.
    """
    if segs is None:
        return []
    if isinstance(segs, (list, tuple)):
        return list(segs)
    if getattr(segs, "ndim", None) == 3:
        return list(segs)
    return [segs]


def render_mask(pieces, w, h, bbox=None, anchor=None):
    """
    Render a binary mask for one piece (dict) or several pieces (list/tuple).

    If ARGS.use_geometry_backend is set and the external `geometry` module was
    imported, the first available backend API is used, in this order:
      1) GEOM.render_mask(pieces, w, h, bbox=..., anchor=...)
      2) GEOM.piece_to_mask(piece, ...) per piece, OR-combined
      3) GEOM.piece_to_polys(piece, ...) -> local rasterizer
      4) GEOM.render(pieces, ...)
      5) GEOM.polys_from_params(piece, ...) -> local rasterizer
    If the backend exposes none of these, or any call raises, a warning is
    printed and the native renderer (_render_mask_native) is used instead.
    """
    use_geom = getattr(ARGS, "use_geometry_backend", False) and _HAS_GEOM
    if use_geom:
        try:
            # Accept both dict and list[dict]
            pcs = pieces if isinstance(pieces, (list, tuple)) else [pieces]
            if hasattr(GEOM, "render_mask"):
                return GEOM.render_mask(pcs, w, h, bbox=bbox, anchor=anchor)
            if hasattr(GEOM, "piece_to_mask"):
                if len(pcs) == 1:
                    return GEOM.piece_to_mask(pcs[0], w, h, bbox=bbox, anchor=anchor)
                # multi-piece: union of the per-piece masks
                acc = np.zeros((h, w), dtype=np.uint8)
                for pp in pcs:
                    acc |= np.asarray(GEOM.piece_to_mask(pp, w, h, bbox=bbox, anchor=anchor)).astype(np.uint8)
                return acc
            if hasattr(GEOM, "piece_to_polys"):
                polys = []
                for pp in pcs:
                    polys.extend(_geom_polys_as_list(GEOM.piece_to_polys(pp, w, h, bbox=bbox, anchor=anchor)))
                return _rasterize_pixel_polys(polys, w, h)
            if hasattr(GEOM, "render"):
                return GEOM.render(pcs, w, h, bbox=bbox, anchor=anchor)
            if hasattr(GEOM, "polys_from_params"):
                polys = []
                for pp in pcs:
                    polys.extend(_geom_polys_as_list(GEOM.polys_from_params(pp, w, h, bbox=bbox, anchor=anchor)))
                return _rasterize_pixel_polys(polys, w, h)
            # If we got here, backend lacks known APIs
            print(f"[WARN] geometry backend present but no compatible API. Available: {getattr(GEOM,'__CAPS_STR__','?')}")
        except Exception as e:
            print(f"[WARN] geometry backend failed: {e}; falling back to native renderer.")
    return _render_mask_native(pieces, w, h, bbox=bbox, anchor=anchor)

def save_pred_json(piece_obj: dict, out_path: str):
    """Write `piece_obj` to `out_path` as UTF-8 JSON, 2-space indented, non-ASCII kept verbatim."""
    with open(out_path, mode="w", encoding="utf-8") as fh:
        fh.write(json.dumps(piece_obj, ensure_ascii=False, indent=2))

def render_pred_image(piece_obj: dict, w: int, h: int, bbox, out_path: str, face=(255,0,0), anchor=None):
    """
    Save an RGB preview of a single predicted piece to `out_path`.

    The piece is drawn filled with `face` and outlined in black on a white
    w x h canvas, using the same grid-to-pixel mapping (bbox/anchor) as the
    native mask renderer.
    """
    canvas = Image.new("RGB", (w, h), (255,255,255))
    pen = ImageDraw.Draw(canvas)
    for outline_pts in _pieces_to_polys_in_pixels([piece_obj], w, h, bbox=bbox, anchor=anchor):
        pen.polygon(outline_pts, fill=face, outline=(0,0,0))
    canvas.save(out_path)


# --- Overlay helper using geometry backend for pixel-perfect overlays ---
def _mask_boundary_pixels(mask_bool):
    """Return the pixels of `mask_bool` that have at least one 4-neighbour outside the mask."""
    padded = np.pad(mask_bool, 1, mode="constant", constant_values=False)
    interior = padded[:-2, 1:-1] & padded[2:, 1:-1] & padded[1:-1, :-2] & padded[1:-1, 2:]
    return mask_bool & ~interior


def _paint_overlay_layer(img, draw, mask, fill, outline):
    """
    Paint one mask onto `img`: solid `fill` inside and `outline` along its border.

    The border is traced with skimage.measure.find_contours (2-px lines) when
    scikit-image is importable and succeeds; otherwise the mask's 1-px
    4-connected boundary pixels are painted, so the outline is never silently lost.
    """
    mask_bool = np.asarray(mask).astype(bool)
    img.paste(fill, mask=Image.fromarray(mask_bool.astype(np.uint8) * 255))
    try:
        from skimage import measure
        contours = measure.find_contours(mask_bool.astype(float), 0.5)
    except Exception:
        contours = None
    if contours is None:
        edge = _mask_boundary_pixels(mask_bool)
        img.paste(outline, mask=Image.fromarray(edge.astype(np.uint8) * 255))
        return
    for contour in contours:
        pts = [(float(p[1]), float(p[0])) for p in contour]  # (row, col) -> (x, y)
        if len(pts) >= 2:
            draw.line(pts, fill=outline, width=2)


def save_overlay_geom_pred_vs_gt(pred_piece: dict, gt_piece: dict, w: int, h: int, bbox, anchor: str, out_path: str):
    """
    Overlay GT and prediction using masks from the geometry backend, if available.

    - GT: gray fill (210,210,210) with dark outline (60,60,60), drawn first.
    - Pred: green fill (80,160,100) with black outline (0,0,0), drawn on top.
    Masks come from GEOM.piece_to_mask, else GEOM.render_mask, else render_mask().
    If the geometry module is unavailable, or anything here raises (a warning is
    printed in that case), falls back to save_overlay_pred_vs_gt.
    """
    if _HAS_GEOM:
        try:
            if hasattr(GEOM, "piece_to_mask"):
                gt_mask = GEOM.piece_to_mask(gt_piece, w, h, bbox=bbox, anchor=anchor)
                pred_mask = GEOM.piece_to_mask(pred_piece, w, h, bbox=bbox, anchor=anchor)
            elif hasattr(GEOM, "render_mask"):
                gt_mask = GEOM.render_mask([gt_piece], w, h, bbox=bbox, anchor=anchor)
                pred_mask = GEOM.render_mask([pred_piece], w, h, bbox=bbox, anchor=anchor)
            else:
                # backend lacks a mask API: use the file's own renderer
                gt_mask = render_mask(gt_piece, w, h, bbox=bbox, anchor=anchor)
                pred_mask = render_mask(pred_piece, w, h, bbox=bbox, anchor=anchor)

            img = Image.new("RGB", (w, h), (255, 255, 255))
            draw = ImageDraw.Draw(img)
            _paint_overlay_layer(img, draw, gt_mask, (210, 210, 210), (60, 60, 60))
            _paint_overlay_layer(img, draw, pred_mask, (80, 160, 100), (0, 0, 0))
            img.save(out_path)
            return
        except Exception as e:
            print(f"[WARN] save_overlay_geom_pred_vs_gt failed: {e}; falling back to save_overlay_pred_vs_gt.")
    # fallback: use polygon rasterizer
    save_overlay_pred_vs_gt(pred_piece, gt_piece, w, h, bbox, anchor, out_path)

# --- Overlay helper: user-preferred style (GT gray with outline, Pred green with black outline) ---
def save_overlay_pred_vs_gt(pred_piece: dict, gt_piece: dict, w: int, h: int, bbox, anchor: str, out_path: str):
    """
    Save a GT-vs-prediction overlay using the native polygon mapping.

    Both pieces are mapped with the same evaluation frame (bbox, anchor) onto a
    white w x h RGB canvas, bottom layer first:
      - GT piece: opaque light-gray fill (210,210,210), dark-gray outline (60,60,60).
      - Predicted piece: opaque green fill (80,160,100), black outline (0,0,0).
    The fills are opaque, so the GT is only visible where the prediction does not
    cover it. The image is written to `out_path`.
    """
    canvas = Image.new("RGB", (w, h), (255, 255, 255))
    pen = ImageDraw.Draw(canvas)

    pred_polys = _pieces_to_polys_in_pixels([pred_piece], w, h, bbox=bbox, anchor=anchor)
    gt_polys = _pieces_to_polys_in_pixels([gt_piece], w, h, bbox=bbox, anchor=anchor)

    layers = (
        (gt_polys, (210, 210, 210), (60, 60, 60)),   # bottom: ground truth
        (pred_polys, (80, 160, 100), (0, 0, 0)),     # top: prediction
    )
    for polys, face, edge in layers:
        for poly in polys:
            pen.polygon(poly, fill=face, outline=edge)

    canvas.save(out_path)

def _largest_component_on_grid(mask, margin, step):
    """
    Return the largest 4-connected foreground component of a 2-D `mask` as a bool array.

    Flood fills are seeded only from pixels on a coarse grid (rows/cols
    margin, margin+step, ... below size-margin), so a component containing no
    grid pixel is never found. On equal size, the first component found in
    row-major seed order wins. Pure Python/NumPy (no scipy).
    """
    h, w = mask.shape[:2]
    visited = np.zeros((h, w), dtype=bool)
    best = np.zeros((h, w), dtype=bool)
    best_area = 0
    for y in range(margin, h - margin, step):
        for x in range(margin, w - margin, step):
            if not mask[y, x] or visited[y, x]:
                continue
            stack = [(y, x)]
            visited[y, x] = True
            blob = []
            while stack:
                cy, cx = stack.pop()
                blob.append((cy, cx))
                for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                    if 0 <= ny < h and 0 <= nx < w and mask[ny, nx] and not visited[ny, nx]:
                        visited[ny, nx] = True
                        stack.append((ny, nx))
            if len(blob) > best_area:
                best_area = len(blob)
                best[:] = False
                rows, cols = zip(*blob)
                best[list(rows), list(cols)] = True
    return best


def load_silhouette(path, dark_thr=80, light_thr=220, margin_ratio=0.03):
    """
    Extract a single-piece silhouette mask from an image file.

    Intended for black-on-white, white-on-black, and colored pieces on a light
    (matplotlib-style) background.

    Mode comes from the global ARGS.silhouette_mode (default "auto"):
      - "bw" / "dark": foreground = gray < dark_thr.
      - "light":       foreground = gray > light_thr.
        In these forced modes an outer frame of width
        m = max(1, int(margin_ratio * min(h, w))) is cleared to background and
        only the largest 4-connected component (grid-seeded flood fill) is kept.
      - any other value (e.g. "auto"): the strategy below.

    Auto strategy:
      1) Background color = per-channel median over the border frame of width m.
      2) Candidates: DARK (gray < dark_thr), LIGHT (gray > light_thr) and COLOR
         (distance to background + 0.15 * (max-min channel) >
         max(20, 0.6 * 85th percentile of the distance)).
      3) Clear the outer frame, then drop pixels with no 4-neighbour in the mask.
      4) Among candidates covering 0.05%..50% of the image, return the one with
         the largest area fraction; ties by higher median distance to
         background, exact ties by the later of DARK, LIGHT, COLOR.
      5) If none qualifies, return the largest connected blob of the COLOR mask.

    Returns:
        (h, w) uint8 mask with values in {0, 1}.
    """
    mode = getattr(ARGS, "silhouette_mode", "auto")
    with Image.open(path) as im:  # close the file handle once decoded
        rgb = np.array(im.convert("RGB"))
    if mode != "auto":
        print(f"[SILH] mode={mode}")

    h, w = rgb.shape[:2]
    m = max(1, int(margin_ratio * min(h, w)))      # outer frame width (treated as background)
    step = max(4, int(min(h, w) * 0.04))           # seed spacing for the flood fill

    # === Fast path for explicitly black/white (or forced dark/light) silhouettes ===
    if mode in ("bw", "dark", "light"):
        gray_fast = (0.2989 * rgb[..., 0] + 0.5870 * rgb[..., 1] + 0.1140 * rgb[..., 2]).astype(np.float32)
        if mode == "light":
            mk = gray_fast > float(light_thr)
        else:  # "bw" / "dark"
            mk = gray_fast < float(dark_thr)
        # Clear the frame AFTER thresholding so it is background for both
        # polarities (writing gray=255 there made it foreground in "light" mode).
        mk[:m, :] = False
        mk[-m:, :] = False
        mk[:, :m] = False
        mk[:, -m:] = False
        return _largest_component_on_grid(mk, m, step).astype(np.uint8)

    r = rgb[..., 0].astype(np.float32)
    g = rgb[..., 1].astype(np.float32)
    b = rgb[..., 2].astype(np.float32)
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    # --- 1) Background color from border strips ---
    border_mask = np.zeros((h, w), dtype=bool)
    border_mask[:m, :] = True; border_mask[-m:, :] = True
    border_mask[:, :m] = True; border_mask[:, -m:] = True
    bg = np.array(
        [np.median(r[border_mask]), np.median(g[border_mask]), np.median(b[border_mask])],
        dtype=np.float32,
    )

    # --- 2) Candidate masks ---
    cand_dark = gray < float(dark_thr)
    cand_light = gray > float(light_thr)
    rgb_stack = np.stack([r, g, b], axis=-1)
    diff = np.linalg.norm(rgb_stack - bg[None, None, :], axis=-1)
    sat = np.max(rgb_stack, axis=-1) - np.min(rgb_stack, axis=-1)  # simple saturation proxy
    color_thr = max(20.0, np.percentile(diff, 85) * 0.6)  # adaptive-ish
    cand_color = (diff + 0.15 * sat) > color_thr

    # --- 3) Kill outer margin & despeckle (4-neighborhood)
    def _post(mask: np.ndarray) -> np.ndarray:
        mk = mask.copy()
        mk[:m, :] = False; mk[-m:, :] = False
        mk[:, :m] = False; mk[:, -m:] = False
        if mk.any():
            up    = np.pad(mk[:-1, :], ((1, 0), (0, 0)), mode='constant')
            down  = np.pad(mk[1:,  :], ((0, 1), (0, 0)), mode='constant')
            left  = np.pad(mk[:, :-1], ((0, 0), (1, 0)), mode='constant')
            right = np.pad(mk[:, 1: ], ((0, 0), (0, 1)), mode='constant')
            mk = mk & (up | down | left | right)
        return mk

    cand_dark = _post(cand_dark)
    cand_light = _post(cand_light)
    cand_color = _post(cand_color)

    # --- 4) Pick the largest "sane" candidate (ties: higher median distance to bg) ---
    total = float(max(1, h * w))
    sane = []
    for cand in (cand_dark, cand_light, cand_color):
        frac = float(cand.sum()) / total
        if 0.0005 <= frac <= 0.5:
            med_diff = float(np.median(diff[cand])) if cand.any() else 0.0
            sane.append((frac, med_diff, cand))
    if sane:
        sane.sort(key=lambda t: (t[0], t[1]))  # stable: exact ties keep the later candidate last
        return sane[-1][2].astype(np.uint8)

    # --- 5) Fallback: keep the largest connected blob from the color mask ---
    if not cand_color.any():
        return cand_color.astype(np.uint8)
    return _largest_component_on_grid(cand_color, m, step).astype(np.uint8)

def foreground_bbox(mask):
    """
    Tight bounding box (x0, y0, x1, y1) of the pixels where mask > 0.

    x1 and y1 are exclusive. An all-background mask yields the whole image,
    (0, 0, width, height).
    """
    height, width = mask.shape
    fg = mask > 0
    rows = np.flatnonzero(fg.any(axis=1))
    cols = np.flatnonzero(fg.any(axis=0))
    if rows.size == 0:
        return (0, 0, width, height)
    return (int(cols[0]), int(rows[0]), int(cols[-1]) + 1, int(rows[-1]) + 1)


# --- Detect inner plotting axes square (0..10 grid) ---
def detect_axes_bbox(rgb):
    """
    Estimate the inner plotting square (the 0..10 axes area) of a Matplotlib-like figure.

    Pixels whose gray level lies strictly between 200 and 245 (light-gray grid
    and axes lines) are collected, and their tight box grown by one pixel per
    side is taken, clamped to the image. This is the same box a 3x3 dilation
    would give. Without such pixels, a box inset by 2% of the short image side
    is used. The box is then cut to a square of its shorter side, anchored at
    its top-left corner, and clamped to the image.

    Returns:
        (x0, y0, x1, y1) ints; x1 / y1 are exclusive.
    """
    h, w = rgb.shape[:2]
    r = rgb[..., 0].astype(np.float32)
    g = rgb[..., 1].astype(np.float32)
    b = rgb[..., 2].astype(np.float32)
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    ys, xs = np.nonzero((gray > 200) & (gray < 245))
    if xs.size:
        x0 = max(0, int(xs.min()) - 1)
        y0 = max(0, int(ys.min()) - 1)
        x1 = min(w, int(xs.max()) + 2)
        y1 = min(h, int(ys.max()) + 2)
    else:
        inset = int(0.02 * min(w, h))
        x0, y0, x1, y1 = inset, inset, w - inset, h - inset

    side = max(1, min(x1 - x0, y1 - y0))
    x1, y1 = x0 + side, y0 + side
    # Slide the square back inside the image if it overshoots, then clamp.
    if x1 > w:
        x0, x1 = x0 - (x1 - w), w
    if y1 > h:
        y0, y1 = y0 - (y1 - h), h
    x0, y0 = max(0, x0), max(0, y0)
    x1, y1 = min(w, x1), min(h, y1)
    return (int(x0), int(y0), int(x1), int(y1))


# --- Helper functions for frame/anchor transformation ---
def _grid_px_params(use_bbox, W, H, bbox):
    """Return (px, offx, offy) for the chosen grid frame."""
    if use_bbox and bbox is not None:
        x0, y0, x1, y1 = bbox
        bw = max(1, int(x1 - x0))
        bh = max(1, int(y1 - y0))
        px = min(bw, bh) / 10.0
        return float(px), float(x0), float(y0)
    # Use the smaller side to match the plotted 0..10 axes area scale better
    px = min(W, H) / 10.0
    return float(px), 0.0, 0.0

def _anchor_offset_grid(piece, anchor_name):
    """
    Compute, in GRID units, the offset vector from template-origin to the chosen anchor
    after applying flip/scale/rotation to the base polygon.
    - If anchor_name=='corner': offset is (0,0).
    - If anchor_name=='centroid': offset is the transformed centroid of base polygon.
    """
    if anchor_name == "corner":
        return (0.0, 0.0)
    ptype = piece.get("type", "triangle")
    size  = piece.get("size", "na")
    if ptype == "triangle":
        base = _poly_triangle(size if size in ("large","medium","small") else "small")
    elif ptype == "square":
        base = _poly_square()
    else:
        base = _poly_parallelogram()
    # centroid of base (in grid units)
    cx = sum(x for x, _ in base) / float(len(base))
    cy = sum(y for _, y in base) / float(len(base))
    # transform centroid like any vertex
    ang = math.radians(float(piece.get("angle", 0.0)))
    ca, sa = math.cos(ang), math.sin(ang)
    x, y = cx, cy
    if bool(piece.get("flip", False)):
        x = -x
    x *= float(piece.get("scale", 1.0))
    y *= float(piece.get("scale", 1.0))
    xr = x * ca - y * sa
    yr = x * sa + y * ca
    return (float(xr), float(yr))

def transform_pos_between_frames(piece, pos_xy, W, H,
                                 from_use_bbox, from_anchor,
                                 to_use_bbox, to_anchor,
                                 bbox=None, bbox_from=None, bbox_to=None):
    """
    Convert a position expressed in one frame/anchor to another, keeping the rendered shape
    at the SAME pixel location.
      - Frame = whether the 0..10 grid is mapped to full image or to a bbox (foreground/axes).
      - Anchor = whether pos refers to template-origin ('corner') or polygon centroid ('centroid').
    Supports different bounding boxes for source/target frames via `bbox_from` and `bbox_to`.
    If only `bbox` is given, it is used for both.
    """
    # Resolve per-frame bboxes
    bf = bbox_from if bbox_from is not None else bbox
    bt = bbox_to   if bbox_to   is not None else bbox

    # ---- Step A: from grid → pixels (reference point implied by from_anchor)
    from_px, from_offx, from_offy = _grid_px_params(from_use_bbox, W, H, bf)
    dx_from, dy_from = _anchor_offset_grid(piece, from_anchor)
    pos_origin_from = (float(pos_xy[0]) - dx_from, float(pos_xy[1]) - dy_from)
    X_pix = pos_origin_from[0] * from_px + from_offx
    Y_pix = pos_origin_from[1] * from_px + from_offy

    # ---- Step B: pixels → target grid (with target anchor)
    to_px, to_offx, to_offy = _grid_px_params(to_use_bbox, W, H, bt)
    pos_origin_to_x = (X_pix - to_offx) / max(1e-9, to_px)
    pos_origin_to_y = (Y_pix - to_offy) / max(1e-9, to_px)
    dx_to, dy_to = _anchor_offset_grid(piece, to_anchor)
    pos_anchor_to = (pos_origin_to_x + dx_to, pos_origin_to_y + dy_to)
    return [float(pos_anchor_to[0]), float(pos_anchor_to[1])]

def iou_score(pred, gt):
    """
    Intersection-over-union of two masks (any nonzero value counts as foreground).

    Computed as |pred AND gt| / max(1, |pred OR gt|), so two empty masks score 0.0.
    """
    a, b = pred.astype(bool), gt.astype(bool)
    union = (a | b).sum()
    return (a & b).sum() / max(1, union)

def geometry_metrics(pred_mask, gt_mask):
    """
    Return (iou, overflow, undercov) for two masks (nonzero = foreground).

    iou      = |pred AND gt| / max(1, |pred OR gt|)
    overflow = predicted-but-not-GT pixels / total pixel count
    undercov = GT-but-not-predicted pixels / total pixel count
    Both penalties are normalized by the total number of pixels, not by mask area.
    """
    p = pred_mask.astype(bool)
    g = gt_mask.astype(bool)
    total = p.size
    inter = (p & g).sum()
    union = (p | g).sum()
    return inter / max(1, union), (p & ~g).sum() / total, (g & ~p).sum() / total

# ---------- Geometric refinement ----------
ALLOWED_ANGLES = [-135.0, -90.0, -45.0, 0.0, 45.0, 90.0, 135.0, 180.0]

def _snap_xy(pos):
    x, y = float(pos[0]), float(pos[1])
    def snap(v): return round(v*2.0)/2.0
    return [max(0.0, min(10.0, snap(x))), max(0.0, min(10.0, snap(y)))]

def refine_piece_by_search(piece, gt_mask, w, h, bbox, max_radius=1.0, step=0.5, try_flip=True, lock_angle=False, lock_flip=False, anchor=None):
    """
    Brute-force local search over (angle, flip, pos) that maximizes IoU with `gt_mask`.

    Candidates:
      - angle: the piece's own angle only when `lock_angle`; otherwise all
        ALLOWED_ANGLES, with the current angle tried first if it is one of them.
      - flip: the current flip, plus its negation when `try_flip` and not `lock_flip`.
      - pos: offsets in [-max_radius, max_radius] with spacing `step` on each
        axis, clamped to [0, 10] and then snapped to the 0.5 grid by _snap_xy.
    Masks are produced by render_mask(bbox=bbox, anchor=anchor). A candidate
    replaces the incumbent only if its IoU is strictly higher, so the earliest
    candidate wins ties.

    Returns:
        (best_piece, best_iou). best_piece is a copy of `piece` when nothing beats it.
    """
    seed = dict(piece)

    def _score(candidate):
        cand_mask = render_mask(candidate, w, h, bbox=bbox, anchor=anchor)
        overlap = np.logical_and(cand_mask, gt_mask).sum()
        union = np.logical_or(cand_mask, gt_mask).sum()
        if not union > 0:
            return 0.0
        return float(overlap / union)

    best_piece = dict(seed)
    best_iou = _score(best_piece)

    offsets = np.arange(-max_radius, max_radius + 1e-6, step)

    if lock_angle:
        angles = [float(seed.get("angle", 0.0))]
    elif seed.get("angle") in ALLOWED_ANGLES:
        angles = [seed["angle"]] + [a for a in ALLOWED_ANGLES if a != seed["angle"]]
    else:
        angles = ALLOWED_ANGLES

    current_flip = bool(seed.get("flip", False))
    flips = [current_flip]
    if not lock_flip and try_flip:
        flips.append(not current_flip)

    for angle in angles:
        for flip in flips:
            for dx in offsets:
                for dy in offsets:
                    cand = dict(seed)
                    cand["angle"] = float(angle)
                    cand["flip"] = bool(flip)
                    bx, by = seed["pos"]
                    cand["pos"] = _snap_xy([max(0.0, min(10.0, bx + dx)),
                                            max(0.0, min(10.0, by + dy))])
                    score = _score(cand)
                    if score > best_iou:
                        best_iou, best_piece = score, cand
    return best_piece, best_iou

def _auto_pick_alignment(gt_piece, gt_mask_png, W, H):
    """
    Find the (map_to_bbox, anchor, size_mode) under which the GT JSON best reproduces the PNG.

    Tries use_bbox in {True, False} (bbox = foreground bbox of `gt_mask_png`),
    anchor in {"corner", "centroid"} and size_mode in {"keep", "fold"}. A
    candidate whose rendering raises scores 0.0.

    IoUs within 1e-6 count as ties and are resolved by a fixed priority:
    bbox over full frame, then "fold" over "keep", then "corner" over "centroid".
    The old pairwise rules let the corner preference override the bbox/fold
    preferences.

    Returns:
        (use_bbox, anchor, best_iou, best_bbox, size_mode)
    """
    eps = 1e-6
    png_bbox = foreground_bbox(gt_mask_png)
    best = None      # (iou, use_bbox, anchor, bbox, size_mode)
    best_key = None  # tie-break priority of `best`
    for use_bbox in (True, False):
        bbox = png_bbox if use_bbox else None
        for anchor in ("corner", "centroid"):
            for size_mode in ("keep", "fold"):
                try:
                    piece_try = _apply_size_policy(gt_piece, size_mode)
                    m = render_mask(piece_try, W, H, bbox=bbox, anchor=anchor)
                    iou_u = iou_score(m, gt_mask_png)
                except Exception:
                    iou_u = 0.0  # best-effort: an unrenderable combination just loses
                tie_key = (use_bbox, size_mode == "fold", anchor == "corner")
                better = best is None or iou_u > best[0] + eps
                tied_but_preferred = (
                    best is not None
                    and abs(iou_u - best[0]) <= eps
                    and tie_key > best_key
                )
                if better or tied_but_preferred:
                    best = (iou_u, use_bbox, anchor, bbox, size_mode)
                    best_key = tie_key
    return best[1], best[2], best[0], best[3], best[4]

# --- Helper for picking alignment under fixed size policy (keep) ---
def _pick_alignment_fixed_size(gt_piece, gt_mask_png, W, H, size_mode="keep"):
    """
    Pick (use_bbox, anchor) maximizing IoU under a fixed size policy (default "keep").

    Tries use_bbox in {True, False} (bbox = foreground bbox of `gt_mask_png`) and
    anchor in {"corner", "centroid"}. A combination whose rendering raises scores
    0.0. IoUs within 1e-6 are ties, resolved by priority: bbox over full frame,
    then "corner" over "centroid". The corner preference can no longer override
    the bbox preference.

    Returns:
        (use_bbox, anchor, best_iou, best_bbox)
    """
    eps = 1e-6
    png_bbox = foreground_bbox(gt_mask_png)
    best = None      # (iou, use_bbox, anchor, bbox)
    best_key = None  # tie-break priority of `best`
    for use_bbox in (True, False):
        bbox = png_bbox if use_bbox else None
        for anchor in ("corner", "centroid"):
            try:
                m = render_mask(_apply_size_policy(gt_piece, size_mode), W, H, bbox=bbox, anchor=anchor)
                iou_u = iou_score(m, gt_mask_png)
            except Exception:
                iou_u = 0.0  # best-effort: an unrenderable combination just loses
            tie_key = (use_bbox, anchor == "corner")
            better = best is None or iou_u > best[0] + eps
            tied_but_preferred = (
                best is not None
                and abs(iou_u - best[0]) <= eps
                and tie_key > best_key
            )
            if better or tied_but_preferred:
                best = (iou_u, use_bbox, anchor, bbox)
                best_key = tie_key
    return best[1], best[2], best[0], best[3]

# ---------- Model ----------
# Load the VLM and its processor once, at import time, on CPU (device_map="cpu").
print("[INFO] Loading Qwen2.5-VL...")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", device_map="cpu").eval()
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

# =========================
# Few-shot sampling
# =========================
# Seeded RNG so the shuffle (and therefore the few-shot/test split) is reproducible.
rng = np.random.default_rng(ARGS.seed)
# (image_path, json_path) pairs discovered in the configured directories.
ALL = _list_pairs(ARGS.onepiece_img_dir, ARGS.onepiece_json_dir)

# Shuffle only if requested
if ARGS.no_shuffle:
    ALL_ordered = ALL  # keep the order returned by _list_pairs (original note: already sorted by filename)
else:
    perm = rng.permutation(len(ALL))
    ALL_ordered = [ALL[i] for i in perm]

# Split into few-shot examples and test examples.
# The RuntimeError messages below say (in Chinese): "not enough samples: found N pairs, need >= K".
if ARGS.all_tests:
    if len(ALL_ordered) < ARGS.k_shots:
        raise RuntimeError(f"样本不足：找到 {len(ALL_ordered)} 对，需要 >= {ARGS.k_shots}")
    SHOT_PAIRS = ALL_ordered[:ARGS.k_shots]
    TEST_PAIRS = ALL_ordered[ARGS.k_shots:]  # every remaining pair is used as a test
else:
    if len(ALL_ordered) < ARGS.k_shots + ARGS.k_tests:
        raise RuntimeError(f"样本不足：找到 {len(ALL_ordered)} 对，需要 >= {ARGS.k_shots + ARGS.k_tests}")
    SHOT_PAIRS = ALL_ordered[:ARGS.k_shots]
    TEST_PAIRS = ALL_ordered[ARGS.k_shots:ARGS.k_shots + ARGS.k_tests]

# Human-readable summary of the split and of the grid-frame configuration.
print("="*40)
print(f"[SAMPLE SUMMARY] seed={ARGS.seed}  few-shot={len(SHOT_PAIRS)}  tests={len(TEST_PAIRS)}")
print("- Few-shot examples (used for learning):")
for i,(ip,jp) in enumerate(SHOT_PAIRS,1):
    print(f"  FS#{i:02d}: image={os.path.basename(ip)}  json={os.path.basename(jp)}")
print("- Test examples (to evaluate):")
for i,(ip,jp) in enumerate(TEST_PAIRS,1):
    print(f"  TE#{i:02d}: image={os.path.basename(ip)}  json={os.path.basename(jp)}")
print("="*40)
print(f"[CONFIG] map_to_bbox={ARGS.map_to_bbox} (False → 0..10 maps to FULL image; True → maps to silhouette bbox)")
# NOTE: in --pos_only mode the test loop evaluates with anchor='centroid' regardless of this value.
print(f"[CONFIG] anchor={getattr(ARGS,'anchor','corner')} (pos refers to {'centroid' if getattr(ARGS,'anchor','corner')=='centroid' else 'template corner'})")
print(f"[CONFIG] map_to_axes={getattr(ARGS,'map_to_axes', False)} (if True → 0..10 maps to detected inner plot square)")

# ---------- Build conversation ----------
messages = []
if ARGS.pos_only:
    messages.append({"role":"system","content":[{"type":"text","text": INSTR_POS_ONLY}]})
else:
    messages.append({"role":"system","content":[{"type":"text","text": INSTR + ("\n"+CHEATSHEET if ARGS.teach else "")} ]})

# Always add few-shot examples
for ip, jp in SHOT_PAIRS:
    _must_exist(ip); _must_exist(jp)
    if ARGS.pos_only:
        # Load raw GT meta (no coercion) and optionally prealign pos to the target frame
        raw_piece = _load_piece_raw(jp)
        piece_used = dict(raw_piece)
        if ARGS.prealign_gt:
            rgb_tmp = np.array(Image.open(ip).convert("RGB"))
            H_tmp, W_tmp = rgb_tmp.shape[:2]
            gt_mask_tmp = load_silhouette(ip)
            use_bbox_native, anchor_native, _iou_nat, bbox_native = _pick_alignment_fixed_size(
                raw_piece, gt_mask_tmp, W_tmp, H_tmp,
                size_mode=("fold" if getattr(ARGS, "size_into_scale", "keep") == "fold" else "keep")
            )
            anchor_target = 'centroid'
            if getattr(ARGS, 'map_to_axes', False):
                axes_bbox_tmp = detect_axes_bbox(rgb_tmp)
                use_bbox_target = True
                bbox_target = axes_bbox_tmp
            else:
                use_bbox_target = bool(ARGS.map_to_bbox)
                bbox_target = foreground_bbox(gt_mask_tmp) if use_bbox_target else None
            pos_t = transform_pos_between_frames(
                raw_piece, raw_piece["pos"], W_tmp, H_tmp,
                from_use_bbox=use_bbox_native, from_anchor=anchor_native,
                to_use_bbox=use_bbox_target, to_anchor=anchor_target,
                bbox_from=bbox_native, bbox_to=bbox_target
            )
            piece_used["pos"] = pos_t
        given_meta = {
            "type": _canon_type(piece_used.get("type","triangle")),
            "size": piece_used.get("size","na"),
            "angle": float(piece_used.get("angle",0)),
            "flip": bool(piece_used.get("flip", False)),
            "scale": float(piece_used.get("scale", 1.0)),
        }
        user_txt = (
            "Given meta: " + json.dumps(given_meta, ensure_ascii=False) + "\n"
            'Predict ONLY the position as {"pos":[x,y]} on 0..10 grid.'
        )
        messages.append({"role":"user","content":[{"type":"image","image": ip},{"type":"text","text": user_txt}]})
        # Assistant provides the (possibly prealigned) GT pos for few-shot supervision
        messages.append({"role":"assistant","content":[{"type":"text","text": json.dumps({"pos": piece_used["pos"]}, ensure_ascii=False)}]})
    else:
        ans_piece = _load_piece_dict(jp)
        messages.append({"role":"user","content":[{"type":"image","image": ip},{"type":"text","text":"Solve this image and return the single-piece JSON in the above schema."}]})
        messages.append({"role":"assistant","content":[{"type":"text","text": json.dumps(ans_piece, ensure_ascii=False)}]})

base_context = list(messages)
running_context = list(base_context)
rows = []

def _piece_to_vertices(piece, anchor=None):
    """Return the transformed polygon of a single tangram piece in grid units.

    The base polygon is selected from the piece type, after normalising it with
    ``_canon_type``, the same canonicaliser used for prompts and predictions in
    this script. It is then placed with ``_apply_transform`` using unit pixel
    scale, zero offset and no y inversion, so the result stays in the 0..10 grid
    frame.

    Args:
        piece: piece dict. ``pos`` is required. ``type``/``size``/``angle``/
            ``flip``/``scale`` are optional and default to "triangle"/"na"/0.0/
            False/1.0, matching the defaults used for ``given_meta``.
        anchor: anchor convention that ``pos`` refers to. ``None`` (default)
            keeps the previous behaviour of using ``ARGS.anchor``, falling back
            to 'corner'. Pass 'centroid' for POS-ONLY pieces.

    Returns:
        The vertex sequence produced by ``_apply_transform``.

    Raises:
        KeyError: if ``piece`` has no ``pos`` entry.
    """
    # Canonicalise first so non-canonical type labels do not silently fall
    # through to the parallelogram branch below.
    ptype = _canon_type(piece.get("type", "triangle"))
    size = piece.get("size", "na")
    if ptype == "triangle":
        # Unknown / "na" sizes render as the small triangle.
        base = _poly_triangle(size if size in ("large", "medium", "small") else "small")
    elif ptype == "square":
        base = _poly_square()
    else:
        base = _poly_parallelogram()

    if anchor is None:
        anchor = getattr(ARGS, 'anchor', 'corner')

    return _apply_transform(
        base,
        piece["pos"],
        float(piece.get("angle", 0.0)),
        bool(piece.get("flip", False)),
        float(piece.get("scale", 1.0)),
        px=1.0, offx=0.0, offy=0.0, anchor=anchor, invert_y=False,
    )

# ---------- Test loop ----------
# How many test samples to run: every pair by default. In iterative mode with
# an explicit --iters, use that count, clamped to the number of pairs.
_num_rounds = len(TEST_PAIRS)
if ARGS.iterative and ARGS.iters is not None:
    _num_rounds = min(ARGS.iters, _num_rounds)
for idx in range(1, _num_rounds+1):
    ip, jp = TEST_PAIRS[idx-1]
    print(f"\n[TEST {idx}] start -> {os.path.basename(ip)}")
    print(f"[TEST {idx}] GT json={os.path.basename(jp)}")

    # Load image and silhouette early for prealignment & diagnostics
    rgb = np.array(Image.open(ip).convert("RGB"))
    H, W = rgb.shape[:2]
    gt_mask_png = load_silhouette(ip)
    # ---- Establish evaluation frame EARLY (used by IoU/overlays below) ----
    anchor_used = 'centroid' if ARGS.pos_only else getattr(ARGS, 'anchor', 'corner')
    if getattr(ARGS, 'map_to_axes', False):
        bbox_used = detect_axes_bbox(rgb)
        frame_used = "axes"
    else:
        if bool(ARGS.map_to_bbox):
            bbox_used = foreground_bbox(gt_mask_png)
            frame_used = "bbox"
        else:
            bbox_used = None
            frame_used = "full"
    bbox_used_tuple = bbox_used  # keep compatibility with helpers that expect this name

    test_messages = list(running_context)
    model_raw_pos = None
    frame_used = None
    bbox_used_tuple = None
    model_pos_used = None
    # anchor_used will be set below
    if ARGS.pos_only:
        gt_meta_piece = _load_piece_raw(jp)
        piece_used_meta = dict(gt_meta_piece)
        if ARGS.prealign_gt:
            use_bbox_native, anchor_native, _iou_nat, bbox_native = _pick_alignment_fixed_size(
                gt_meta_piece, gt_mask_png, W, H,
                size_mode=("fold" if getattr(ARGS, "size_into_scale", "keep") == "fold" else "keep")
            )
            anchor_target = 'centroid'
            if getattr(ARGS, 'map_to_axes', False):
                axes_bbox = detect_axes_bbox(rgb)
                use_bbox_target = True
                bbox_target = axes_bbox
            else:
                use_bbox_target = bool(ARGS.map_to_bbox)
                bbox_target = foreground_bbox(gt_mask_png) if use_bbox_target else None
            pos_t = transform_pos_between_frames(
                gt_meta_piece, gt_meta_piece["pos"], W, H,
                from_use_bbox=use_bbox_native, from_anchor=anchor_native,
                to_use_bbox=use_bbox_target, to_anchor=anchor_target,
                bbox_from=bbox_native, bbox_to=bbox_target
            )
            piece_used_meta["pos"] = pos_t
            print(f"[TEST {idx}] PREALIGN native=({use_bbox_native},{anchor_native}) -> target=({use_bbox_target},{anchor_target}); upper_bound(native)={_iou_nat:.4f}")
        given_meta = {
            "type": _canon_type(piece_used_meta.get("type","triangle")),
            "size": piece_used_meta.get("size","na"),
            "angle": float(piece_used_meta.get("angle",0)),
            "flip": bool(piece_used_meta.get("flip", False)),
            "scale": float(piece_used_meta.get("scale", 1.0)),
        }
        test_messages.insert(0, {"role":"system","content":[{"type":"text","text": INSTR_POS_ONLY}]})
        user_txt = (
            "Given meta: " + json.dumps(given_meta, ensure_ascii=False) + "\n"
            'Predict ONLY the position as {"pos":[x,y]} on 0..10 grid.'
        )
        test_messages.append({"role":"user","content":[{"type":"image","image": ip},{"type":"text","text": user_txt}]})
    else:
        test_messages.append({"role":"user","content":[{"type":"image","image": ip},{"type":"text","text":"Now solve this one. Return ONLY the single-piece JSON with fields: type,size,pos,angle,flip,scale."}]})

    chat_text = processor.apply_chat_template(test_messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(test_messages)
    inputs = processor(text=[chat_text], images=image_inputs, videos=video_inputs, return_tensors="pt", padding=True)
    inputs = {k: (v.to("cpu") if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}
    print(f"[TEST {idx}] tokens_in={int(inputs['input_ids'].shape[-1])} images_in={len(image_inputs) if image_inputs is not None else 0}")

    print(f"[TEST {idx}] generating ...")
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=ARGS.max_new,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
        )
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
    output  = processor.batch_decode(trimmed, skip_special_tokens=True)[0].strip()
    print(f"[TEST {idx}] output_len={len(output)} preview=\n{output[:200]}{'...' if len(output)>200 else ''}")

    # Parse prediction
    m = re.search(r"\{.*\}", output, re.S)
    pred_piece = None
    pred_piece_fulljson = None
    if m:
        try:
            raw = json.loads(m.group(0))
            if ARGS.pos_only:
                gt_meta_piece = _load_piece_raw(jp)
                pos = raw.get("pos", raw.get("position"))
                if isinstance(pos, (list, tuple)) and len(pos)==2:
                    model_raw_pos = [float(pos[0]), float(pos[1])]
                    pos = [float(pos[0]), float(pos[1])]
                    pred_piece = {
                        "type": gt_meta_piece.get("type","triangle"),
                        "size": gt_meta_piece.get("size","na"),
                        "pos": pos,
                        "angle": float(gt_meta_piece.get("angle",0)),
                        "flip": bool(gt_meta_piece.get("flip", False)),
                        "scale": float(gt_meta_piece.get("scale", 1.0)),
                    }
                    # Canonicalize type for downstream renderer
                    pred_piece["type"] = _canon_type(pred_piece.get("type","triangle"))
                    # do not fold size into scale in pos_only mode
                    # do not coerce in pos_only mode; preserve angle/flip/scale
                    try:
                        _gt_meta_dbg = _load_piece_raw(jp)
                        meta_given = {
                            "type": _canon_type(_gt_meta_dbg.get("type","triangle")),
                            "size": _gt_meta_dbg.get("size","na"),
                            "angle": float(_gt_meta_dbg.get("angle",0)),
                            "flip": bool(_gt_meta_dbg.get("flip", False)),
                            "scale": float(_gt_meta_dbg.get("scale", 1.0)),
                        }
                        meta_used = {
                            "type": _canon_type(pred_piece.get("type","triangle")),
                            "size": pred_piece.get("size","na"),
                            "angle": float(pred_piece.get("angle",0)),
                            "flip": bool(pred_piece.get("flip", False)),
                            "scale": float(pred_piece.get("scale", 1.0)),
                        }
                        print(f"[TEST {idx}] POS-ONLY meta_given={meta_given} meta_used_for_render={meta_used}")
                        # ensure the renderer will see canonicalized types going forward
                        pred_piece["type"] = meta_used["type"]
                    except Exception as _e:
                        print(f"[TEST {idx}] [WARN] meta consistency check failed: {_e}")
                else:
                    pred_piece = None
            else:
                pred_piece = _coerce_to_schema(raw)
                pred_piece_fulljson = dict(pred_piece) if pred_piece is not None else None
        except Exception:
            pred_piece = None
        print(f"[TEST {idx}] json_ok={pred_piece is not None} normalized={pred_piece}")
        # --- POS-ONLY hard lock: force meta (type/size/angle/flip/scale) from GT JSON ---
        if ARGS.pos_only and pred_piece is not None:
            try:
                _gt_lock = _load_piece_raw(jp)  # raw GT without coercion/snap
                pred_piece["type"]  = _canon_type(_gt_lock.get("type", "triangle"))
                pred_piece["size"]  = _gt_lock.get("size", "na")
                pred_piece["angle"] = float(_gt_lock.get("angle", 0.0))
                pred_piece["flip"]  = bool(_gt_lock.get("flip", False))
                pred_piece["scale"] = float(_gt_lock.get("scale", 1.0))
                print(f"[TEST {idx}] POS-ONLY lock applied: type={pred_piece['type']} size={pred_piece['size']} angle={pred_piece['angle']} flip={pred_piece['flip']} scale={pred_piece['scale']}")
            except Exception as _e:
                print(f"[TEST {idx}] [WARN] POS-ONLY lock failed: {_e}")

    # Use the original PNG stem for output filenames
    stem = os.path.splitext(os.path.basename(ip))[0]
    pred_json_path = os.path.join(SAVE_DIR, f"{stem}_pred.json")
    # --- NEW IoU computation block (JSON vs JSON default) ---
    if pred_piece is not None:
        # Prepare paths for overlays using PNG stem
        size_mode_eval = (ARGS.size_into_scale if getattr(ARGS, "size_into_scale", "keep") in ("keep","fold") else "keep")
        overlay_axes_path = os.path.join(SAVE_DIR, f"{stem}_overlay_axes.png")           # grid overlay (preferred)
        overlay_pngspace_path = os.path.join(SAVE_DIR, f"{stem}_overlay_pngspace.png")  # pixel-space overlay (secondary)
        # Save prediction JSON for this test
        save_pred_json(pred_piece, pred_json_path)
        print(f"[TEST {idx}] saved pred JSON -> {os.path.basename(pred_json_path)}")

        # Build GT piece in the evaluation frame if needed
        try:
            gt_for_iou = _load_piece_raw(jp)
        except Exception as _e:
            print(f"[TEST {idx}] [WARN] load GT raw failed for IoU: {_e}")
            gt_for_iou = {"type":"triangle","size":"na","pos":[5.0,5.0],"angle":0.0,"flip":False,"scale":1.0}

        # If user开启了 prealign_gt，则把 GT 的 pos 从其 native 框架变换到当前评测框架 (bbox_used, anchor_used)
        try:
            if ARGS.prealign_gt:
                use_bbox_native, anchor_native, _iou_nat2, bbox_native = _pick_alignment_fixed_size(
                    gt_for_iou, gt_mask_png, W, H, size_mode="keep"
                )
                pos_t = transform_pos_between_frames(
                    gt_for_iou, gt_for_iou["pos"], W, H,
                    from_use_bbox=use_bbox_native, from_anchor=anchor_native,
                    to_use_bbox=bool(bbox_used is not None), to_anchor=anchor_used,
                    bbox_from=bbox_native, bbox_to=bbox_used
                )
                gt_for_iou = dict(gt_for_iou); gt_for_iou["pos"] = pos_t
        except Exception as _e:
            print(f"[TEST {idx}] [WARN] GT transform for IoU failed: {_e}")

        # Apply size policy (keep/fold)
        gt_eff   = _apply_size_policy(gt_for_iou, size_mode_eval)
        pred_eff = _apply_size_policy(pred_piece, size_mode_eval)

        # Render masks in the SAME evaluation frame as overlays (bbox_used, anchor_used)
        pred_mask_eval = render_mask(pred_eff, W, H, bbox=bbox_used, anchor=anchor_used)

        if ARGS.iou_mode == "json":
            gt_mask_eval   = render_mask(gt_eff,   W, H, bbox=bbox_used, anchor=anchor_used)
            iou, overflow, undercov = geometry_metrics(pred_mask_eval, gt_mask_eval)
            print(f"[TEST {idx}] IoU(JSON vs JSON, frame={('axes' if getattr(ARGS,'map_to_axes',False) else ('bbox' if bbox_used else 'full'))}, anchor={anchor_used}, size_mode={size_mode_eval}) = {iou:.4f}")

            # Overlays consistent with this IoU definition
            try:
                # 1) Save the preferred grid overlay to overlay_{idx}.png
                save_overlay_axes_with_grid(pred_eff, gt_eff, overlay_axes_path, title=os.path.basename(pred_json_path))
                print(f"[TEST {idx}] saved 0..10-grid overlay -> {os.path.basename(overlay_axes_path)}")
                # 2) Optionally save a pixel-space overlay (secondary, off by default).
                if os.environ.get("SAVE_PNGSPACE_OVERLAY", "0") == "1":
                    save_overlay_geom_pred_vs_gt(pred_eff, gt_eff, W, H, bbox_used, anchor_used, overlay_pngspace_path)
                    print(f"[TEST {idx}] saved PNG-space overlay -> {os.path.basename(overlay_pngspace_path)}")
            except Exception as _e:
                print(f"[TEST {idx}] [WARN] saving overlays failed: {_e}")
        else:
            # Legacy mode: Pred-JSON vs PNG silhouette
            iou, overflow, undercov = geometry_metrics(pred_mask_eval, gt_mask_png)
            print(f"[TEST {idx}] IoU(JSON vs PNG silhouette) = {iou:.4f}")
            try:
                # 1) Save the preferred grid overlay to overlay_{idx}.png
                save_overlay_axes_with_grid(pred_eff, gt_eff, overlay_axes_path, title=os.path.basename(pred_json_path))
                print(f"[TEST {idx}] saved 0..10-grid overlay -> {os.path.basename(overlay_axes_path)}")
                # 2) Optionally save a pixel-space overlay (secondary, off by default).
                if os.environ.get("SAVE_PNGSPACE_OVERLAY", "0") == "1":
                    save_overlay_geom_pred_vs_gt(pred_eff, gt_eff, W, H, bbox_used, anchor_used, overlay_pngspace_path)
                    print(f"[TEST {idx}] saved PNG-space overlay -> {os.path.basename(overlay_pngspace_path)}")
            except Exception as _e:
                print(f"[TEST {idx}] [WARN] saving overlays failed: {_e}")

    # --- OLD BLOCK DISABLED (kept for reference; not executed) ---
    # 直接调用外部 geometry.py，产出三张图（GT、Pred、Overlay）
    run_geometry_render(jp, pred_json_path, SAVE_DIR)

    # IoU vs PNG silhouette
    # --- alignment policy (deterministic) ---
    size_mode_used = getattr(ARGS, "size_into_scale", "auto")
    if ARGS.pos_only:
        # POS-ONLY: hard lock to centroid anchor; never auto-switch, but allow calibration if requested
        anchor_used = 'centroid'
        if getattr(ARGS, 'map_to_axes', False):
            axes_bbox = detect_axes_bbox(rgb)
            use_bbox = True
            bbox_used = axes_bbox
        else:
            use_bbox = bool(ARGS.map_to_bbox)
            bbox_used = foreground_bbox(gt_mask_png) if use_bbox else None
        # start with user's preference; allow calibration to override if --calibrate or if user asked for auto
        size_pref = getattr(ARGS, "size_into_scale", "keep")
        size_mode_used = "keep" if size_pref not in ("keep", "fold", "auto") else size_pref
        try:
            gt_piece_dbg_raw = _load_piece_raw(jp)
            if ARGS.prealign_gt:
                # Use the prealigned position for the upper-bound check in the target frame
                use_bbox_native, anchor_native, _iou_nat, bbox_native = _pick_alignment_fixed_size(gt_piece_dbg_raw, gt_mask_png, W, H, size_mode="keep")
                pos_t = transform_pos_between_frames(
                    gt_piece_dbg_raw, gt_piece_dbg_raw["pos"], W, H,
                    from_use_bbox=use_bbox_native, from_anchor=anchor_native,
                    to_use_bbox=use_bbox, to_anchor=anchor_used,
                    bbox_from=bbox_native, bbox_to=bbox_used
                )
                gt_piece_dbg_raw = dict(gt_piece_dbg_raw); gt_piece_dbg_raw["pos"] = pos_t
            # per-sample calibration (optional)
            if getattr(ARGS, "calibrate", False) or size_mode_used == "auto":
                tri_idx, sz_mode, ub = _calibrate_eval_shapes(
                    gt_piece_dbg_raw, gt_mask_png, W, H, bbox_used, anchor_used,
                    try_variants=getattr(ARGS, "tri_variants", 2),
                    respect_user_size_mode=None if size_mode_used == "auto" else size_mode_used
                )
                size_mode_used = sz_mode
                if getattr(ARGS, "verbose_calib", False):
                    print(f"[CALIB] choose tri_variant={tri_idx} size_mode={size_mode_used}  ub={ub:.4f}")
                iou_upper = ub
            else:
                # no calibration: just compute UB with the chosen size policy
                iou_upper = iou_score(
                    render_mask(_apply_size_policy(gt_piece_dbg_raw, size_mode_used if size_mode_used != "auto" else "keep"), W, H,
                                bbox=bbox_used, anchor=anchor_used),
                    gt_mask_png,
                )
            frame_str = "axes" if getattr(ARGS, 'map_to_axes', False) else ("bbox" if use_bbox else "full")
            print(f"[TEST {idx}] POS-ONLY fixed-frame: frame={frame_str}, anchor={anchor_used}, size_mode={size_mode_used}; upper_bound={iou_upper:.4f}")
        except Exception as _e:
            print(f"[TEST {idx}] [WARN] upper-bound IoU check failed: {_e}")
            iou_upper = None
    elif getattr(ARGS, "auto_align", False):
        gt_piece_dbg = _load_piece_raw(jp)
        use_bbox, anchor_used, iou_upper, bbox_used, size_mode_used = _auto_pick_alignment(gt_piece_dbg, gt_mask_png, W, H)
        print(f"[TEST {idx}] AUTO-ALIGN pick -> map_to_bbox={use_bbox}  anchor={anchor_used}  size_mode={size_mode_used}  upper_bound={iou_upper:.4f}")
    else:
        anchor_used = getattr(ARGS, 'anchor', 'corner')
        if getattr(ARGS, 'map_to_axes', False):
            axes_bbox = detect_axes_bbox(rgb)
            use_bbox = True
            bbox_used = axes_bbox
            try:
                gt_piece_dbg_raw = _load_piece_raw(jp)
                iou_upper_full = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, size_mode_used if size_mode_used != "auto" else "keep"), W, H, bbox=None, anchor=anchor_used), gt_mask_png)
                iou_upper_axes = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, size_mode_used if size_mode_used != "auto" else "keep"), W, H, bbox=axes_bbox, anchor=anchor_used), gt_mask_png)
                iou_upper = max(iou_upper_full, iou_upper_axes)
                if size_mode_used == "auto":
                    try:
                        iou_keep = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, "keep"), W, H, bbox=bbox_used, anchor=anchor_used), gt_mask_png)
                        iou_fold = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, "fold"), W, H, bbox=bbox_used, anchor=anchor_used), gt_mask_png)
                        size_mode_used = "fold" if iou_fold > iou_keep else "keep"
                        print(f"[TEST {idx}] size_mode(auto) -> keep={iou_keep:.4f} fold={iou_fold:.4f} | choose {size_mode_used}")
                    except Exception as _e:
                        print(f"[TEST {idx}] [WARN] size-mode auto selection failed: {_e}")
                        size_mode_used = "keep"
                print(f"[TEST {idx}] Upper-bound IoU (JSON→full) = {iou_upper_full:.4f}  (JSON→axes) = {iou_upper_axes:.4f}  size_mode={size_mode_used}")
            except Exception as _e:
                print(f"[TEST {idx}] [WARN] upper-bound IoU check failed: {_e}")
                iou_upper = None
        else:
            use_bbox = ARGS.map_to_bbox
            bbox_auto = foreground_bbox(gt_mask_png) if use_bbox else None
            bbox_used = bbox_auto
            try:
                gt_piece_dbg_raw = _load_piece_raw(jp)
                iou_upper_full = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, size_mode_used if size_mode_used != "auto" else "keep"), W, H, bbox=None, anchor=anchor_used), gt_mask_png)
                if bbox_auto is not None:
                    iou_upper_bbox = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, size_mode_used if size_mode_used != "auto" else "keep"), W, H, bbox=bbox_auto, anchor=anchor_used), gt_mask_png)
                else:
                    iou_upper_bbox = 0.0
                iou_upper = max(iou_upper_full, iou_upper_bbox)
                if use_bbox and iou_upper_full > iou_upper_bbox:
                    use_bbox = False
                    bbox_used = None
                if size_mode_used == "auto":
                    try:
                        iou_keep = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, "keep"), W, H, bbox=bbox_used, anchor=anchor_used), gt_mask_png)
                        iou_fold = iou_score(render_mask(_apply_size_policy(gt_piece_dbg_raw, "fold"), W, H, bbox=bbox_used, anchor=anchor_used), gt_mask_png)
                        size_mode_used = "fold" if iou_fold > iou_keep else "keep"
                        print(f"[TEST {idx}] size_mode(auto) -> keep={iou_keep:.4f} fold={iou_fold:.4f} | choose {size_mode_used}")
                    except Exception as _e:
                        print(f"[TEST {idx}] [WARN] size-mode auto selection failed: {_e}")
                        size_mode_used = "keep"
                print(f"[TEST {idx}] Upper-bound IoU (JSON→full) = {iou_upper_full:.4f}  (JSON→bbox) = {iou_upper_bbox:.4f}  size_mode={size_mode_used}")
            except Exception as _e:
                print(f"[TEST {idx}] [WARN] upper-bound IoU check failed: {_e}")
                iou_upper = None

    # --- Compute frame_used and bbox_used_tuple for diagnostics ---
    frame_used = 'axes' if getattr(ARGS,'map_to_axes',False) else ('bbox' if use_bbox else 'full')
    bbox_used_tuple = bbox_used if bbox_used is not None else (None)

    # --- Normalize predicted 'pos' into the chosen frame/anchor so size/angle visually match ---
    model_use_bbox = bool(ARGS.map_to_bbox)
    model_anchor   = getattr(ARGS, 'anchor', 'corner')
    if pred_piece is not None:
        if ARGS.pos_only:
            # Hard lock: in POS-ONLY we do not change the coordinate frame
            pass
        elif (model_use_bbox != use_bbox or model_anchor != anchor_used):
            try:
                old_pos = list(pred_piece["pos"])
                new_pos = transform_pos_between_frames(
                    pred_piece, old_pos, W, H,
                    from_use_bbox=model_use_bbox, from_anchor=model_anchor,
                    to_use_bbox=use_bbox,       to_anchor=anchor_used,
                    bbox=bbox_used if bbox_used is not None else foreground_bbox(gt_mask_png)
                )
                pred_piece["pos"] = new_pos
                print(f"[TEST {idx}] POS transformed {old_pos} ({'bbox' if model_use_bbox else 'full'},{model_anchor}) → {new_pos} ({'bbox' if use_bbox else 'full'},{anchor_used})")
            except Exception as _e:
                print(f"[TEST {idx}] [WARN] failed to transform pos across frames/anchors: {_e}")

    # --- Save model_pos_used for diagnostics ---
    if ARGS.pos_only and pred_piece is not None:
        model_pos_used = list(pred_piece['pos'])
        print(f"[DIAG] pos(raw {model_raw_pos}) -> pos(used {model_pos_used}) | frame_used={frame_used} (locked), anchor_used={anchor_used} (locked)")

    iou = 0.0
    overflow = 0.0
    undercov = 0.0
    overlay_path = os.path.join(SAVE_DIR, f"overlay_{idx}.png")
    pred_render_path = os.path.join(SAVE_DIR, f"pred_render_{idx}.png")
    overlay_json_path = os.path.join(SAVE_DIR, f"overlay_json_vs_pred_{idx}.png")
    iou_modelpos_gtmeta = None
    iou_gtjson_used = None
    iou_model_fulljson = None
    if pred_piece is not None:
        # Always keep size/scale as-is for JSON→PNG rendering to match geometry.py
        policy_mode = "keep"
        pred_piece_eff = _apply_size_policy(pred_piece, policy_mode)
        pred_mask = render_mask(pred_piece_eff, W, H, bbox=bbox_used, anchor=anchor_used)
        # IoU + error breakdown
        iou, overflow, undercov = geometry_metrics(pred_mask, gt_mask_png)
        print(f"[TEST {idx}] IoU={iou:.4f}  overflow(extra%)={overflow*100:.2f}  undercov(missing%)={undercov*100:.2f}")

        # === Overlays ===
        # (A) Pixel-accurate overlay using geometry backend (matches IoU space)
        try:
            gt_piece_for_json_overlay = _load_piece_raw(jp)
            # For pixel-space overlay, respect the current evaluation frame but ALWAYS keep size as-is
            gt_piece_pix = _apply_size_policy(gt_piece_for_json_overlay, "keep")
            pred_piece_pix = _apply_size_policy(pred_piece_eff, "keep")
            save_overlay_geom_pred_vs_gt(
                pred_piece_pix, gt_piece_pix,
                W, H,
                bbox_used,  # same frame as evaluation
                anchor_used,
                overlay_path,
            )
            print(f"[TEST {idx}] saved PNG-space overlay -> {os.path.basename(overlay_path)}")
        except Exception as _e:
            print(f"[TEST {idx}] [WARN] failed to save PNG-space overlay: {_e}")

        # (B) Grid (0..10) overlay: NO frame transform, NO y inversion, for human-facing inspection
        try:
            gt_piece_grid  = _load_piece_raw(jp)
            pred_piece_grid = dict(pred_piece)  # use model JSON as-is
            # keep size/scale exactly
            gt_piece_grid  = _apply_size_policy(gt_piece_grid,  "keep")
            pred_piece_grid = _apply_size_policy(pred_piece_grid, "keep")
            save_overlay_axes_with_grid(
                pred_piece_grid,
                gt_piece_grid,
                os.path.join(SAVE_DIR, f"overlay_axes_{idx}.png"),
                title="pred_{}.json".format(idx)
            )
            print(f"[TEST {idx}] saved 0..10-grid overlay -> overlay_axes_{idx}.png")
        except Exception as _e:
            print(f"[TEST {idx}] [WARN] failed to save 0..10-grid overlay: {_e}")