#!/usr/bin/env python3
# -*- coding: utf-8 -*-
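"""
qwen_geom_param_demo.py  (replaces qwen_angle_demo.py)
- Generic evaluation script: every field of the GT annotation is kept fixed except one
  target parameter, and Qwen is asked to predict only that parameter.
- Supported targets (--predict): angle (degrees), scale (scale factor; "size" is an alias),
  pos (centroid) and angle+pos. For two-piece annotations (a "pieces" list of length 2),
  angle, pos and angle+pos are predicted per piece.
- The predicted field is written into pred.json; all other fields are copied from the GT.
- geometry.py renders GT / Pred / Overlay images.
- IoU is computed by the external iou_from_overlay.py evaluator (iou_summary.csv) and merged
  into the per-sample results; per-sample parameter errors are printed in the terminal.
"""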

import os
import re
import sys
import cv2
import json
import time
import copy 
import glob
import math
import argparse
import subprocess
from dataclasses import dataclass
from typing import Dict, Any, Tuple, List

import numpy as np
from PIL import Image
import torch
import csv

# Keep HuggingFace tokenizers quiet (avoids the fork/parallelism warning).
# setdefault leaves any value the user already exported untouched.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# ===== 模型加载（Qwen2.5-VL 本地权重） =====
from transformers import AutoModelForVision2Seq, AutoModelForCausalLM, AutoTokenizer, AutoProcessor

# Directory of this script; helper scripts (geometry.py, iou_from_overlay.py) are resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Renderer invoked as a subprocess by run_geometry_render().
GEOMETRY_PY = os.path.join(SCRIPT_DIR, "geometry.py")


@dataclass
class ModelBundle:
    """Holds the three objects produced by load_qwen() and consumed by the inference helpers."""

    tok: Any  # tokenizer: used for apply_chat_template() and decode()
    proc: Any  # processor: turns the chat text + image into model inputs
    mdl: Any  # model: moved to its device and switched to eval() by load_qwen()


def _resolve_qwen_snapshot(base_dir: str) -> str:
    """
    Hugging Face cache layout:
    ~/.cache/huggingface/hub/models--ORG--NAME/{blobs,refs,snapshots/<hash>}
    We need the real snapshot folder that contains config.json, model*.bin, etc.
    """
    snaps = os.path.join(base_dir, "snapshots")
    if os.path.isdir(snaps):
        # pick the most recent snapshot by mtime
        cands = [os.path.join(snaps, d) for d in os.listdir(snaps) if os.path.isdir(os.path.join(snaps, d))]
        if cands:
            cands.sort(key=lambda p: os.path.getmtime(p), reverse=True)
            return cands[0]
    # fall back to base_dir (in case of symlinked/flattened cache)
    return base_dir


def load_qwen(model_name: str = None, device: str = "cpu") -> ModelBundle:
    """
    Load Qwen2.5-VL (tokenizer, processor, model) from local files only; nothing is downloaded.

    Checkpoint directory, in order of precedence:
      1. ``model_name`` when it names an existing local directory (a plain model
         folder or a ``models--ORG--NAME`` hub-cache entry);
      2. ``$QWEN_LOCAL_DIR`` when set;
      3. the default hub-cache entry for Qwen2.5-VL-3B-Instruct.
    Hub ids such as "Qwen/Qwen2.5-VL-7B-Instruct" are not directories, so they fall
    through to (2)/(3) with a warning.

    Args:
        model_name: Optional local checkpoint directory (see above).
        device: Torch device string ("cpu", "mps", ...). Weights are loaded as float32.

    Returns:
        ModelBundle(tok, proc, mdl) with ``mdl`` moved to ``device`` and in eval mode.
    """
    local_model = os.path.expanduser(model_name) if model_name else None
    if local_model and os.path.isdir(local_model):
        print(f"[INFO] Loading local model from {local_model} (no download)...")
        base_dir = local_model
    else:
        if model_name:
            print(f"[WARN] model_name {model_name!r} is not a local directory; using the default local snapshot.")
        print("[INFO] Loading local Qwen2.5-VL-3B-Instruct (no download)...")
        base_dir = os.path.expanduser("~/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-3B-Instruct")
        # Allow override via env var if user moved the cache
        base_dir = os.environ.get("QWEN_LOCAL_DIR", base_dir)
    local_dir = _resolve_qwen_snapshot(base_dir)
    print(f"[INFO] Using snapshot: {local_dir}")

    target = str(device).lower()
    if target == "mps":
        # NOTE(review): PyTorch may read this flag when torch is first imported; if so,
        # setting it here (after `import torch`) has no effect - confirm.
        os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

    tok = AutoTokenizer.from_pretrained(local_dir, use_fast=True, trust_remote_code=True, local_files_only=True)
    proc = AutoProcessor.from_pretrained(local_dir, trust_remote_code=True, local_files_only=True)

    def _load_with(model_cls):
        # Load on CPU in float32 (device_map=None), then move explicitly to the target device.
        model = model_cls.from_pretrained(
            local_dir,
            device_map=None,
            torch_dtype=torch.float32,
            trust_remote_code=True,
            local_files_only=True,
        )
        model.to(target)
        return model

    # Try Vision2Seq first (newer Qwen2.5-VL uses this); fall back to CausalLM if needed.
    try:
        mdl = _load_with(AutoModelForVision2Seq)
        print("[INFO] Loaded with AutoModelForVision2Seq (local)")
    except Exception as e1:
        print(f"[WARN] Vision2Seq loader failed: {e1}\n[INFO] Falling back to AutoModelForCausalLM ...")
        mdl = _load_with(AutoModelForCausalLM)
        print("[INFO] Loaded with AutoModelForCausalLM (local fallback)")
    mdl.eval()
    return ModelBundle(tok=tok, proc=proc, mdl=mdl)


# =========================
#     统一的推理函数
# =========================

# ---- Prompt templates -------------------------------------------------------
# Each *_SYS / *_USER pair is passed to _chat_generate_json() together with a
# key_name that selects how the reply is parsed (see the run_qwen_* wrappers).

# Single piece, angle only (parsed with key_name="angle").
ANGLE_SYS = "You are a vision assistant for geometry evaluation."
ANGLE_USER = (
    "You are given an image of a single geometric piece inside a 0..10 grid.\n"
    "Your job: RETURN ONLY a compact JSON with the rotation angle of the piece (in degrees).\n"
    "Convention: positive=counter-clockwise, negative=clockwise; range [-180, 180]; 0 means upright.\n"
    "Respond EXACTLY like: {\"angle\": <float>}"
)

# Two pieces, angles only (parsed with key_name="angle2").
# NOTE(review): the model never sees the GT "pieces" array, so "SAME ORDER as the
# ground-truth pieces array" gives it no usable ordering rule; a visual rule
# (e.g. left-to-right) matching how the GT is ordered may help - TODO confirm.
ANGLE2_SYS = "You are a vision assistant for geometry evaluation."
ANGLE2_USER = (
    "You are given an image that contains TWO geometric pieces inside a 0..10 grid.\n"
    "Return ONLY a compact JSON with the rotation angles (in degrees) for BOTH pieces, in the SAME ORDER as the ground-truth pieces array (piece[0] then piece[1]).\n"
    "Convention: positive=CCW, negative=CW; range [-180, 180].\n"
    "Respond EXACTLY like: {\"angle\": [<float>, <float>]}"
)

# Single piece, scale only (parsed with key_name="scale").
SCALE_SYS = "You are a vision assistant for geometry evaluation."
SCALE_USER = (
    "You are given an image of a single geometric piece inside a 0..10 grid.\n"
    "Your job: RETURN ONLY a compact JSON with the size scale factor of the piece.\n"
    "Interpret 'scale' as the multiplicative factor applied to a canonical template to match the piece silhouette.\n"
    "Respond EXACTLY like: {\"scale\": <float>}  (no extra keys or text)\n"
    "Typical values are around 0.5~1.2; do not include units."
)



# Pre-compiled patterns for locating / salvaging the model's answer.
_QWEN_FENCED_JSON_RE = re.compile(r"```json\s*(\{[\s\S]*?\})\s*```", flags=re.IGNORECASE)
_QWEN_BRACE_BLOCK_RE = re.compile(r"\{[\s\S]*?\}")
_QWEN_PAIR_RE = re.compile(r"\[\s*([\-\.\deE]+)\s*,\s*([\-\.\deE]+)\s*\]")
_QWEN_NUMBER_RE = re.compile(r"[-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?")


def _qwen_generate_text(mbundle, image_path, sys_prompt, user_prompt, max_new_tokens):
    """Run one greedy image+text chat turn and return ONLY the newly generated text."""
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": user_prompt},
        ]},
    ]
    conversation = mbundle.tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    # Release the file handle as soon as the pixels are decoded.
    with Image.open(image_path) as im:
        image = im.convert("RGB")
    inputs = mbundle.proc(
        text=[conversation],
        images=[image],
        return_tensors="pt"
    )
    dev = next(mbundle.mdl.parameters()).device
    inputs = inputs.to(dev)
    gen = mbundle.mdl.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=0.0,
        repetition_penalty=1.0,
    )
    out_ids = gen[0]
    prompt_ids = inputs["input_ids"][0]
    n_prompt = prompt_ids.shape[0]
    # generate() on a decoder-only model returns prompt + continuation. Decoding the
    # prompt too lets its own text ("0..10 grid", "[-180, 180]", the JSON templates)
    # leak into the regex fallbacks, so strip it when it is echoed back.
    if out_ids.shape[0] >= n_prompt and torch.equal(out_ids[:n_prompt], prompt_ids):
        out_ids = out_ids[n_prompt:]
    return mbundle.tok.decode(out_ids, skip_special_tokens=True)


def _qwen_dump_raw(out_text):
    """Append the raw reply to the file named by $QWEN_DUMP_RAW, if set (best effort)."""
    dump_path = os.environ.get("QWEN_DUMP_RAW")
    if not dump_path:
        return
    try:
        with open(dump_path, "a", encoding="utf-8") as f:
            f.write("\n==== NEW SAMPLE ====\n" + out_text + "\n")
    except OSError:
        pass  # debugging aid only; never let it break inference


def _qwen_extract_json_str(text):
    """Return a fenced ```json block if present, else the LAST {...} block, else None."""
    fence = _QWEN_FENCED_JSON_RE.search(text)
    if fence:
        return fence.group(1)
    blocks = _QWEN_BRACE_BLOCK_RE.findall(text)
    return blocks[-1] if blocks else None


def _qwen_as_two_pairs(val):
    """Coerce [[x0, y0], [x1, y1]] to floats; None if the shape is wrong (bad numbers raise)."""
    if isinstance(val, list) and len(val) == 2 and all(isinstance(p, list) and len(p) == 2 for p in val):
        return [[float(val[0][0]), float(val[0][1])], [float(val[1][0]), float(val[1][1])]]
    return None


def _qwen_regex_two_pairs(text):
    """Scrape the first two "[a, b]" numeric pairs from free text; None if unavailable."""
    coords = _QWEN_PAIR_RE.findall(text)
    if len(coords) < 2:
        return None
    try:
        return [[float(coords[0][0]), float(coords[0][1])], [float(coords[1][0]), float(coords[1][1])]]
    except ValueError:
        return None


def _qwen_regex_two_numbers(text):
    """Scrape the first two numbers from free text; None if fewer than two are present."""
    nums = _QWEN_NUMBER_RE.findall(text)
    if len(nums) < 2:
        return None
    return [float(nums[0]), float(nums[1])]


def _qwen_coerce_angle_pos(js):
    """Copy of js whose "angle" is a float and "pos" an [x, y] float list (zero defaults)."""
    out = dict(js)
    try:
        out["angle"] = float(js.get("angle", 0.0))
    except (TypeError, ValueError):
        out["angle"] = 0.0
    pos = js.get("pos")
    try:
        if isinstance(pos, list) and len(pos) == 2:
            out["pos"] = [float(pos[0]), float(pos[1])]
        else:
            out["pos"] = [0.0, 0.0]
    except (TypeError, ValueError):
        out["pos"] = [0.0, 0.0]
    return out


def _qwen_fallback(key_name, text):
    """Value returned when the reply has no usable JSON; same shape as the success path."""
    if key_name == "pos":
        return [0.0, 0.0]
    if key_name == "pos2":
        return _qwen_regex_two_pairs(text) or [[0.0, 0.0], [0.0, 0.0]]
    if key_name == "angle+pos2":
        return {"angle": [0.0, 0.0], "pos": [[0.0, 0.0], [0.0, 0.0]]}
    if key_name == "angle+pos":
        return {"angle": 0.0, "pos": [0.0, 0.0]}
    if key_name == "angle2":
        return [0.0, 0.0]
    return 0.0


def _qwen_parse_fields(js, key_name, text):
    """Turn the decoded JSON object into the value shape expected for key_name."""
    if key_name == "pos":
        val = js.get("pos")
        if isinstance(val, list) and len(val) == 2:
            return [float(val[0]), float(val[1])]
        print("[WARN] Qwen pos list not found or invalid, fallback [0,0].")
        return [0.0, 0.0]
    if key_name == "pos2":
        pairs = _qwen_as_two_pairs(js.get("pos")) or _qwen_regex_two_pairs(text)
        if pairs:
            return pairs
        print("[WARN] Qwen pos2 list not found or invalid, fallback [[0,0],[0,0]].")
        return [[0.0, 0.0], [0.0, 0.0]]
    if key_name == "angle+pos2":
        angle = js.get("angle")
        if isinstance(angle, list) and len(angle) == 2:
            angle_list = [float(angle[0]), float(angle[1])]
        elif isinstance(angle, (int, float, str)):
            # A single scalar is taken as piece[0]'s angle.
            try:
                angle_list = [float(angle), 0.0]
            except ValueError:
                angle_list = [0.0, 0.0]
        else:
            angle_list = [0.0, 0.0]
        # NOTE: the regex fallback takes the first two [a, b] pairs in the reply,
        # which may be the angle list itself when "pos" is malformed.
        pos_list = (_qwen_as_two_pairs(js.get("pos")) or _qwen_regex_two_pairs(text)
                    or [[0.0, 0.0], [0.0, 0.0]])
        return {"angle": angle_list, "pos": pos_list}
    if key_name == "angle+pos":
        return _qwen_coerce_angle_pos(js)
    if "+" in key_name:
        # Other combined keys: hand back the whole object unvalidated.
        return js
    if key_name == "angle2":
        val = js.get("angle")
        if isinstance(val, list) and len(val) == 2:
            return [float(val[0]), float(val[1])]
        # fallback: take first two numbers from the (generated) text
        nums = _qwen_regex_two_numbers(text)
        if nums:
            return nums
        print("[WARN] Qwen angle2 list not found or invalid, fallback [0.0, 0.0].")
        return [0.0, 0.0]
    val = js.get(key_name)
    if val is None:
        print(f"[WARN] Qwen {key_name} field missing, fallback 0.")
        return 0.0
    val = float(val)
    if key_name == "angle":
        # Normalise the angle to [-180, 180).
        val = ((val + 180.0) % 360.0) - 180.0
    return val


def _chat_generate_json(mbundle: ModelBundle, image_path: str, sys_prompt: str, user_prompt: str,
                        key_name: str, max_new_tokens: int = 64):
    """
    Ask Qwen one image+text question and parse its JSON reply for ``key_name``.

    Return value by ``key_name``:
      * "angle" / "scale" / other plain keys -> float ("angle" normalised to [-180, 180))
      * "angle2"     -> [a0, a1]
      * "pos"        -> [x, y]
      * "pos2"       -> [[x0, y0], [x1, y1]]
      * "angle+pos"  -> dict with float "angle" and [x, y] "pos" (other keys kept)
      * "angle+pos2" -> {"angle": [a0, a1], "pos": [[x0, y0], [x1, y1]]}
      * other "a+b"  -> the decoded JSON object as-is (0.0 on parse failure)

    Only the newly generated tokens are parsed. A fenced ```json block is preferred,
    otherwise the last {...} block. When nothing usable is found, a warning is printed and
    a zero-filled value of the same shape is returned, so callers can index the result
    without type checks. Set $QWEN_DUMP_RAW to a file path to log raw replies.
    """
    out_text = _qwen_generate_text(mbundle, image_path, sys_prompt, user_prompt, max_new_tokens)
    _qwen_dump_raw(out_text)
    try:
        json_str = _qwen_extract_json_str(out_text)
        if not json_str:
            print(f"[WARN] Qwen {key_name} JSON not found, fallback 0.")
            return _qwen_fallback(key_name, out_text)
        js = json.loads(json_str)
        return _qwen_parse_fields(js, key_name, out_text)
    except Exception as e:  # any malformed reply -> shaped default; never abort the batch
        print(f"[WARN] Failed to parse Qwen output for {key_name}: {e}")
        return _qwen_fallback(key_name, out_text)
def run_qwen_angle2(mbundle: ModelBundle, image_path: str):
    """Ask Qwen for both pieces' rotation angles; parsed with the "angle2" rules of _chat_generate_json."""
    return _chat_generate_json(
        mbundle,
        image_path,
        sys_prompt=ANGLE2_SYS,
        user_prompt=ANGLE2_USER,
        key_name="angle2",
    )
# Single piece, centroid only (parsed with key_name="pos").
# NOTE(review): the prompts below do not state the grid's origin or y-axis direction,
# so the model must infer them from the rendered axes - confirm against geometry.py.
POS_SYS = "You are a vision assistant for geometry evaluation."
POS_USER = (
    "You are given an image of a single geometric piece inside a 0..10 grid.\n"
    "Your job: RETURN ONLY a compact JSON with the centroid position of the piece.\n"
    "Respond EXACTLY like: {\"pos\": [<float>, <float>]}"
)

# Two pieces, centroids only (parsed with key_name="pos2").
# NOTE(review): as with ANGLE2_USER, "SAME ORDER as ground-truth" is not observable by the model.
POS2_SYS = "You are a vision assistant for geometry evaluation."
POS2_USER = (
    "You are given an image that contains TWO geometric pieces inside a 0..10 grid.\n"
    "Return ONLY a compact JSON with the centroid positions of BOTH pieces, in the SAME ORDER as ground-truth pieces array (piece[0] then piece[1]).\n"
    "Respond EXACTLY like: {\"pos\": [[<float>, <float>], [<float>, <float>]]}"
)

# Single piece, angle + centroid (parsed with key_name="angle+pos").
ANGLE_POS_SYS = "You are a vision assistant for geometry evaluation."
ANGLE_POS_USER = (
    "You are given an image of a single geometric piece inside a 0..10 grid.\n"
    "Your job: RETURN ONLY a compact JSON with BOTH the rotation angle (in degrees) and centroid position of the piece.\n"
    "Convention: positive=counter-clockwise, negative=clockwise; range [-180, 180]; 0 means upright.\n"
    "Respond EXACTLY like: {\"angle\": <float>, \"pos\": [<float>, <float>]}"
)

# Two pieces, angles + centroids (parsed with key_name="angle+pos2").
ANGLE_POS2_SYS = "You are a vision assistant for geometry evaluation."
ANGLE_POS2_USER = (
    "You are given an image that contains TWO geometric pieces inside a 0..10 grid.\n"
    "Return ONLY a compact JSON with BOTH the rotation angles and centroid positions for the two pieces, in the SAME ORDER as ground-truth pieces array (piece[0] then piece[1]).\n"
    "Angle convention: positive=CCW, negative=CW; range [-180,180].\n"
    "Respond EXACTLY like: {\"angle\": [<float>, <float>], \"pos\": [[<float>, <float>], [<float>, <float>]]}"
)


def run_qwen_pos(mbundle: ModelBundle, image_path: str) -> List[float]:
    """
    Ask Qwen for the centroid of the single piece in ``image_path``.

    Returns:
        ``[x, y]`` as a two-element list of floats (``[0.0, 0.0]`` when the reply
        cannot be parsed). The annotation reflects the list actually returned by
        the "pos" parser, not a tuple.
    """
    return _chat_generate_json(mbundle, image_path, POS_SYS, POS_USER, "pos")

def run_qwen_pos2(mbundle: ModelBundle, image_path: str):
    """Ask Qwen for both pieces' centroids; parsed with the "pos2" rules of _chat_generate_json."""
    return _chat_generate_json(
        mbundle,
        image_path,
        sys_prompt=POS2_SYS,
        user_prompt=POS2_USER,
        key_name="pos2",
    )

def run_qwen_angle_pos(mbundle: ModelBundle, image_path: str):
    """
    Ask Qwen for the rotation angle and centroid of the single piece in ``image_path``.

    Returns:
        A dict with at least ``"angle"`` (float) and ``"pos"`` (``[x, y]`` floats).
        Missing or malformed fields become ``0.0`` / ``[0.0, 0.0]`` so callers can
        use ``result["angle"]`` and ``result["pos"][0]`` without type checks.
    """
    val = _chat_generate_json(mbundle, image_path, ANGLE_POS_SYS, ANGLE_POS_USER, "angle+pos")
    # Defensive: the parser's generic "+" path hands back the raw JSON object with
    # unvalidated field types, and its failure fallback is not guaranteed to be a
    # dict; normalise to the documented shape here.
    out = dict(val) if isinstance(val, dict) else {}
    try:
        out["angle"] = float(out.get("angle", 0.0))
    except (TypeError, ValueError):
        out["angle"] = 0.0
    pos = out.get("pos")
    try:
        if isinstance(pos, (list, tuple)) and len(pos) == 2:
            out["pos"] = [float(pos[0]), float(pos[1])]
        else:
            out["pos"] = [0.0, 0.0]
    except (TypeError, ValueError):
        out["pos"] = [0.0, 0.0]
    return out

def run_qwen_angle_pos2(mb, image_path):
    """Ask Qwen for both pieces' angles and centroids; parsed with the "angle+pos2" rules."""
    prompts = (ANGLE_POS2_SYS, ANGLE_POS2_USER)
    return _chat_generate_json(mb, image_path, *prompts, "angle+pos2")


def run_qwen_angle(mbundle: ModelBundle, image_path: str) -> float:
    """Ask Qwen for the single piece's rotation angle in degrees (parser key "angle")."""
    angle = _chat_generate_json(
        mbundle,
        image_path,
        sys_prompt=ANGLE_SYS,
        user_prompt=ANGLE_USER,
        key_name="angle",
    )
    return angle


def run_qwen_scale(mbundle: ModelBundle, image_path: str) -> float:
    """Ask Qwen for the single piece's scale factor (parser key "scale")."""
    scale = _chat_generate_json(
        mbundle,
        image_path,
        sys_prompt=SCALE_SYS,
        user_prompt=SCALE_USER,
        key_name="scale",
    )
    return scale


# ========== 数据与工具 ==========

def list_pairs(img_dir: str, json_dir: str, exts=(".png", ".jpg", ".jpeg")) -> List[Tuple[str, str]]:
    """
    Pair every image in ``img_dir`` with the same-stem JSON file in ``json_dir``.

    Only entries whose lower-cased name ends with one of ``exts`` are considered
    (glob "*" skips dot-files); images lacking ``<stem>.json`` are skipped.

    Returns:
        ``(image_path, json_path)`` tuples, ordered by image path.
    """
    candidates = sorted(glob.glob(os.path.join(img_dir, "*")))
    images = [c for c in candidates if c.lower().endswith(exts)]
    result: List[Tuple[str, str]] = []
    for img in images:
        stem, _ext = os.path.splitext(os.path.basename(img))
        json_path = os.path.join(json_dir, stem + ".json")
        if os.path.exists(json_path):
            result.append((img, json_path))
    return result


def must_exist(path: str):
    """Raise FileNotFoundError unless ``path`` exists (file or directory)."""
    if os.path.exists(path):
        return
    raise FileNotFoundError(f"[PATH ERROR] Not found: {path}")


def load_json(path: str) -> Dict[str, Any]:
    """Read and return the JSON document stored at ``path`` (decoded as UTF-8)."""
    with open(path, encoding="utf-8") as handle:
        data = json.load(handle)
    return data


def save_json(obj: Dict[str, Any], path: str):
    """
    Serialise ``obj`` to ``path`` as UTF-8 JSON, overwriting any existing file.

    Output is indented by two spaces and non-ASCII characters are written
    verbatim (``ensure_ascii=False``).
    """
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2, ensure_ascii=False)


def ensure_dir(d: str):
    """Create directory ``d`` together with any missing parents; a no-op if it already exists."""
    os.makedirs(name=d, exist_ok=True)


def angle_delta_deg(a: float, b: float) -> float:
    """Smallest absolute angular difference between ``a`` and ``b`` in degrees, in [0, 180]."""
    wrapped = abs(a - b) % 360.0
    return min(wrapped, 360.0 - wrapped)


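# ========== Rendering & IoU ==========
# geometry.py produces, per sample:
#   - gt_render.png
#   - pred_render.png
#   - overlay_json_vs_pred.png
# iou_from_pair() thresholds the gt/pred renders to compute a local IoU;
# main() instead uses the external overlay-based evaluator (run_external_iouEvaluator).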

def run_geometry_render(gt_json_path: str, pred_json_path: str, out_dir: str):
    """
    Render the GT, the prediction and a GT-vs-pred overlay with geometry.py.

    geometry.py is run three times as a subprocess, in that order, with
    ``check=True`` (a non-zero exit raises subprocess.CalledProcessError). The
    images are written into ``out_dir``: gt_render.png, pred_render.png and
    overlay_json_vs_pred.png.

    Returns:
        ``(gt_png, pred_png, overlay_png)`` paths.
    """
    ensure_dir(out_dir)
    gt_png, pred_png, overlay_png = (
        os.path.join(out_dir, fname)
        for fname in ("gt_render.png", "pred_render.png", "overlay_json_vs_pred.png")
    )
    jobs = (
        (["--ann", gt_json_path], gt_png),
        (["--ann", pred_json_path], pred_png),
        (["--overlay_gt_json", gt_json_path, "--overlay_pred_json", pred_json_path], overlay_png),
    )
    for extra_args, target in jobs:
        subprocess.run(
            [sys.executable, GEOMETRY_PY, *extra_args, "--save_png", target, "--no_show"],
            check=True,
        )
    return gt_png, pred_png, overlay_png


def _to_mask(img: Image.Image) -> np.ndarray:
    """
    Convert a geometry.py render into a 0/1 ``uint8`` mask of its largest bright blob.

    Pixels whose HSV value channel exceeds 0.6 (V > 153 on OpenCV's 0..255 scale)
    count as foreground. Only the largest 8-connected component (OpenCV default
    connectivity) is kept, to drop grid lines and text. If there is no foreground
    at all, the all-zero mask is returned.

    NOTE(review): this assumes the filled polygon is brighter than the background
    and the grid/text; a bright (e.g. white) background would be selected instead.
    Because only one component survives, a two-piece render keeps a single piece.
    Adjust the threshold if geometry.py's colours change.
    """
    rgb = np.array(img.convert("RGB"))
    value = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV)[..., 2].astype(np.float32) / 255.0
    mask = (value > 0.6).astype(np.uint8)
    num, labels = cv2.connectedComponents(mask)
    if num <= 1:
        return mask
    # Area of every label in one pass (label 0 is background and is skipped).
    # The previous per-label (labels == i).sum() scan cost O(num_labels * pixels).
    # np.argmax returns the first maximum, matching a strict '>' scan on ties.
    areas = np.bincount(labels.ravel(), minlength=num)
    best = 1 + int(np.argmax(areas[1:]))
    return (labels == best).astype(np.uint8)


def iou_from_pair(gt_png: str, pred_png: str) -> float:
    """
    IoU between the foreground masks (see _to_mask) of two rendered PNG files.

    The prediction image is resized with nearest-neighbour sampling to the GT size
    when their sizes differ. Returns 0.0 when both masks are empty.
    """
    # Context managers release both file handles deterministically instead of
    # relying on Pillow's load-time closing or garbage collection.
    with Image.open(gt_png) as gt, Image.open(pred_png) as pr:
        if gt.size != pr.size:
            pr = pr.resize(gt.size, Image.NEAREST)
        m1 = _to_mask(gt)
        m2 = _to_mask(pr)

    inter = np.logical_and(m1, m2).sum()
    union = np.logical_or(m1, m2).sum()
    if union == 0:
        return 0.0
    return float(inter) / float(union)


# Header names (compared case-insensitively, whitespace-stripped) that may hold the
# sample identifier / the IoU value in the evaluator's summary CSV.
_IOU_SAMPLE_COLUMNS = ("sample", "name", "id", "file", "basename")
_IOU_VALUE_COLUMNS = ("iou", "iou_score", "mean_iou")


def _read_iou_summary(summary_csv_path: str) -> Dict[str, float]:
    """Parse an IoU summary CSV into {sample stem: IoU}; returns {} (with a warning) if unreadable."""
    iou_map: Dict[str, float] = {}
    try:
        # newline="" is what the csv module requires; "utf-8-sig" also strips a leading
        # BOM, which would otherwise be glued onto the first column name.
        with open(summary_csv_path, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.DictReader(f)
            fields = reader.fieldnames or []
            # First matching header (in file order) wins.
            sample_key = next((k for k in fields if k.strip().lower() in _IOU_SAMPLE_COLUMNS), None)
            iou_key = next((k for k in fields if k.strip().lower() in _IOU_VALUE_COLUMNS), None)
            for row in reader:
                key = row.get(sample_key) if sample_key else row.get("sample") or row.get("name")
                iou_str = row.get(iou_key) if iou_key else row.get("IoU") or row.get("iou")
                if not key or iou_str is None:
                    continue
                try:
                    iou_map[os.path.splitext(os.path.basename(key.strip()))[0]] = float(iou_str)
                except ValueError:
                    pass  # non-numeric cell (e.g. empty); skip this row
    except (OSError, csv.Error, UnicodeError) as e:
        print(f"[WARN] Failed to parse IoU summary CSV: {e}")
    return iou_map


def run_external_iouEvaluator(batch_run_dir: str, gt_dir: str) -> Tuple[str, Dict[str, float]]:
    """
    Run iou_from_overlay.py over a batch run directory and collect per-sample IoU.

    Args:
        batch_run_dir: Run root (main() writes one sub-folder per sample here);
            the summary CSV is written to ``<batch_run_dir>/iou_summary.csv``.
        gt_dir: Directory of GT JSON files, forwarded as ``--gt_dir``.

    Returns:
        ``(csv_path, iou_map)`` where ``iou_map`` maps sample stem -> IoU. The map is
        empty if the CSV cannot be read or its columns are not recognised.

    Raises:
        subprocess.CalledProcessError: if the evaluator exits with a non-zero status.
    """
    summary_csv_path = os.path.join(batch_run_dir, "iou_summary.csv")
    cmd = [
        sys.executable,
        os.path.join(SCRIPT_DIR, "iou_from_overlay.py"),
        "--batch_my_run", batch_run_dir,
        "--gt_dir", gt_dir,
        "--map_to_axes_geom",
        "--gt_anchor", "centroid",
        "--pred_anchor", "centroid",
        "--size_mode_geom", "keep",
        "--dilate", "2",
        "--write_iou_txt",
        "--summary_csv", summary_csv_path,
        "--save_debug",
        "--verbose",
    ]
    print("[EVAL] Running overlay IoU evaluator...\n ", " ".join(cmd))
    subprocess.run(cmd, check=True)
    print(f"[EVAL] Overlay-based IoU summary written to: {summary_csv_path}")

    # Column names may vary slightly between evaluator versions; the helper tries common variants.
    iou_map = _read_iou_summary(summary_csv_path)
    return summary_csv_path, iou_map


# ========== 主流程 ==========

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--img_dir", type=str, required=True, help="图片目录（与 JSON 同名）")
    ap.add_argument("--json_dir", type=str, required=True, help="GT JSON 目录")
    ap.add_argument("--n", type=int, default=8, help="最多多少个样本（--all 覆盖）")
    ap.add_argument("--seed", type=int, default=2025)
    ap.add_argument("--all", action="store_true", help="处理目录下所有样本")
    ap.add_argument("--out_dir", type=str, default="./runs/param_eval")
    ap.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct")
    ap.add_argument("--device", type=str, default="cpu", choices=["cpu", "mps"], help="force device; default cpu to avoid MPS matmul bugs")

    # 关键：预测哪个字段
    ap.add_argument("--predict", type=str, default="angle",
                    choices=["angle", "scale", "size", "pos", "angle+pos"],
                    help="只预测该字段，其余字段固定为 GT。'size' 为 'scale' 同义词。")

    args = ap.parse_args()
    if args.predict == "size":
        args.predict = "scale"

    must_exist(args.img_dir)
    must_exist(args.json_dir)
    ensure_dir(args.out_dir)

    pairs = list_pairs(args.img_dir, args.json_dir)
    if not pairs:
        raise FileNotFoundError(f"No pairs in {args.img_dir} vs {args.json_dir}")
    if not args.all:
        pairs = pairs[: args.n]

    ts = time.strftime("%Y%m%d_%H%M%S")
    root = os.path.join(args.out_dir, ts)
    ensure_dir(root)
    print(f"[OUT] results -> {root}")
    print(f"[TASK] predict = {args.predict}  (others fixed from GT)")

    # 载入模型
    mb = load_qwen(args.model, device=args.device)

    err_list: List[float] = []
    rows = []
    scale_ratio_list: List[float] = []

    for idx, (img_path, gt_json_path) in enumerate(pairs, 1):
        name = os.path.splitext(os.path.basename(img_path))[0]
        work = os.path.join(root, name)
        ensure_dir(work)
        gt = load_json(gt_json_path)
        # Detect two-piece schema
        is_two = isinstance(gt.get("pieces"), list) and len(gt["pieces"]) == 2

        # 1) 让 Qwen 只预测目标字段
        if args.predict == "angle":
            if is_two:
                pred_val = run_qwen_angle2(mb, img_path)  # [a0,a1]
                gt_val = [float(gt["pieces"][0].get("angle",0.0)), float(gt["pieces"][1].get("angle",0.0))]
                pred = copy.deepcopy(gt)
                pred["pieces"][0]["angle"] = float(pred_val[0])
                pred["pieces"][1]["angle"] = float(pred_val[1])
            else:
                pred_val = run_qwen_angle(mb, img_path)
                gt_val = float(gt.get("angle", 0.0))
                pred = copy.deepcopy(gt)
                pred["angle"] = float(pred_val)
        elif args.predict == "scale":
            pred_val = run_qwen_scale(mb, img_path)
            gt_val = float(gt.get("scale", 1.0))
            pred = copy.deepcopy(gt)
            pred["scale"] = float(pred_val)
        elif args.predict == "pos":
            if is_two:
                pred_val = run_qwen_pos2(mb, img_path)
                gt_val = [gt["pieces"][0].get("pos",[0.0,0.0]), gt["pieces"][1].get("pos",[0.0,0.0])]
                pred = copy.deepcopy(gt)
                pred["pieces"][0]["pos"] = [float(pred_val[0][0]), float(pred_val[0][1])]
                pred["pieces"][1]["pos"] = [float(pred_val[1][0]), float(pred_val[1][1])]
            else:
                pred_val = run_qwen_pos(mb, img_path)
                gt_val = gt.get("pos", [0.0, 0.0])
                pred = copy.deepcopy(gt)
                pred["pos"] = [float(pred_val[0]), float(pred_val[1])]
        elif args.predict == "angle+pos":
            if is_two:
                pred_val = run_qwen_angle_pos2(mb, img_path)
                gt_val_angle = [float(gt["pieces"][0].get("angle",0.0)), float(gt["pieces"][1].get("angle",0.0))]
                gt_val_pos   = [gt["pieces"][0].get("pos",[0.0,0.0]), gt["pieces"][1].get("pos",[0.0,0.0])]
                pred = copy.deepcopy(gt)
                # angles
                ang_list = pred_val.get("angle", [0.0, 0.0])
                pred["pieces"][0]["angle"] = float(ang_list[0])
                pred["pieces"][1]["angle"] = float(ang_list[1])
                # positions
                pos_list = pred_val.get("pos", [[0.0,0.0],[0.0,0.0]])
                pred["pieces"][0]["pos"] = [float(pos_list[0][0]), float(pos_list[0][1])]
                pred["pieces"][1]["pos"] = [float(pos_list[1][0]), float(pos_list[1][1])]
            else:
                pred_val = run_qwen_angle_pos(mb, img_path)
                gt_val_angle = float(gt.get("angle", 0.0))
                gt_val_pos = gt.get("pos", [0.0, 0.0])
                pred = copy.deepcopy(gt)
                pred["angle"] = float(pred_val.get("angle", 0.0))
                pos_val = pred_val.get("pos", [0.0, 0.0])
                pred["pos"] = [float(pos_val[0]), float(pos_val[1])]
        else:
            raise ValueError("unsupported predict field")

        # 3) 保存 pred.json
        pred_json_path = os.path.join(work, "pred.json")
        with open(pred_json_path, "w", encoding="utf-8") as f:
            json.dump(pred, f, ensure_ascii=False, indent=2)

        # 4) 渲染 & 计算 IoU（基于渲染图）
        gt_render, pred_render, overlay_png = run_geometry_render(gt_json_path, pred_json_path, work)
        # iou = iou_from_pair(gt_render, pred_render)  # removed per instruction

        # 5) 误差
        if args.predict == "angle":
            if is_two:
                def _delta(a,b):
                    d = abs(float(a)-float(b)) % 360.0
                    return 360.0 - d if d > 180.0 else d
                err0 = _delta(pred_val[0], gt_val[0])
                err1 = _delta(pred_val[1], gt_val[1])
                err_mean = np.mean([err0, err1])
                rows.append((name, gt_val[0], gt_val[1], float(pred_val[0]), float(pred_val[1]), err_mean, err0, err1, None, None))
                print(f"[DONE {idx:02d}] {name} | GT_angle=[{gt_val[0]:.4f},{gt_val[1]:.4f}]  "
                      f"PRED_angle=[{float(pred_val[0]):.4f},{float(pred_val[1]):.4f}]  ANG_ERR_mean={err_mean:.4f} "
                      f"(e0={err0:.4f}, e1={err1:.4f}) -> {overlay_png}")
                err_list.append(err_mean)
            else:
                def angle_delta_deg(a: float, b: float) -> float:
                    d = abs(a - b) % 360.0
                    if d > 180.0:
                        d = 360.0 - d
                    return d
                err = angle_delta_deg(pred_val, gt_val)
                err_name = "ANG_ERR(°)"
                abs_val_diff = abs(float(pred_val) - float(gt_val))
                scale_ratio = None
                rows.append((name, gt_val, pred_val, err, abs_val_diff, None, scale_ratio))
                print(f"[DONE {idx:02d}] {name} | GT_angle={gt_val:.4f}  PRED_angle={pred_val:.4f}  "
                      f"{err_name}={err:.4f}  ABS_DIFF={abs_val_diff:.4f}  -> {overlay_png}")
                err_list.append(err)
        elif args.predict == "scale":
            err = abs(float(pred_val) - float(gt_val))
            err_name = "SCALE_ERR"
            abs_val_diff = abs(float(pred_val) - float(gt_val))
            scale_ratio = None
            if gt_val != 0:
                scale_ratio = float(pred_val) / float(gt_val)
                scale_ratio_list.append(scale_ratio)
            rows.append((name, gt_val, pred_val, err, abs_val_diff, None, scale_ratio))
            if scale_ratio is not None:
                print(
                    f"[DONE {idx:02d}] {name} | GT_scale={gt_val:.4f}  "
                    f"PRED_scale={pred_val:.4f}  {err_name}={err:.4f}  ABS_DIFF={abs_val_diff:.4f}  "
                    f"SCALE_RATIO(P/G)={scale_ratio:.4f}  -> {overlay_png}"
                )
            else:
                print(
                    f"[DONE {idx:02d}] {name} | GT_scale={gt_val:.4f}  "
                    f"PRED_scale={pred_val:.4f}  {err_name}={err:.4f}  ABS_DIFF={abs_val_diff:.4f}  "
                    f"-> {overlay_png}"
                )
            err_list.append(err)
        elif args.predict == "pos":
            if is_two:
                pred_xy0 = pred_val[0]
                pred_xy1 = pred_val[1]
                gt_xy0 = gt_val[0]
                gt_xy1 = gt_val[1]
                err0 = math.sqrt((float(pred_xy0[0]) - float(gt_xy0[0])) ** 2 + (float(pred_xy0[1]) - float(gt_xy0[1])) ** 2)
                err1 = math.sqrt((float(pred_xy1[0]) - float(gt_xy1[0])) ** 2 + (float(pred_xy1[1]) - float(gt_xy1[1])) ** 2)
                err_mean = np.mean([err0, err1])
                rows.append((name, gt_xy0, gt_xy1, pred_xy0, pred_xy1, err_mean, err0, err1, None, None))
                print(
                    f"[DONE {idx:02d}] {name} | GT_pos=[{gt_xy0},{gt_xy1}]  PRED_pos=[{pred_xy0},{pred_xy1}]  POS_ERR_mean={err_mean:.4f} (e0={err0:.4f}, e1={err1:.4f}) -> {overlay_png}"
                )
                err_list.append(err_mean)
            else:
                pred_xy = pred_val
                gt_xy = gt_val
                err = math.sqrt((float(pred_xy[0]) - float(gt_xy[0])) ** 2 + (float(pred_xy[1]) - float(gt_xy[1])) ** 2)
                err_name = "POS_ERR(L2)"
                abs_val_diff = err
                scale_ratio = None
                rows.append((name, gt_xy, pred_xy, err, abs_val_diff, None, scale_ratio))
                print(
                    f"[DONE {idx:02d}] {name} | GT_pos={gt_xy}  "
                    f"PRED_pos={pred_xy}  {err_name}={err:.4f}  "
                    f"-> {overlay_png}"
                )
                err_list.append(err)
        elif args.predict == "angle+pos":
            if is_two:
                ang_pred = pred_val.get("angle", [0.0,0.0])
                pos_pred = pred_val.get("pos", [[0.0,0.0],[0.0,0.0]])
                ang_gt = gt_val_angle
                pos_gt = gt_val_pos
                # angle error per piece
                def angle_delta_deg(a: float, b: float) -> float:
                    d = abs(a - b) % 360.0
                    if d > 180.0:
                        d = 360.0 - d
                    return d
                ang_err0 = angle_delta_deg(float(ang_pred[0]), float(ang_gt[0]))
                ang_err1 = angle_delta_deg(float(ang_pred[1]), float(ang_gt[1]))
                pos_err0 = math.sqrt((float(pos_pred[0][0]) - float(pos_gt[0][0])) ** 2 + (float(pos_pred[0][1]) - float(pos_gt[0][1])) ** 2)
                pos_err1 = math.sqrt((float(pos_pred[1][0]) - float(pos_gt[1][0])) ** 2 + (float(pos_pred[1][1]) - float(pos_gt[1][1])) ** 2)
                rows.append((name, ang_gt, ang_pred, np.mean([ang_err0,ang_err1]),
                             pos_gt, pos_pred, np.mean([pos_err0,pos_err1]),
                             ang_err0, ang_err1, pos_err0, pos_err1))
                print(
                    f"[DONE {idx:02d}] {name} | GT_angle={ang_gt}  PRED_angle={ang_pred}  ANG_ERR_mean={np.mean([ang_err0,ang_err1]):.4f} (e0={ang_err0:.4f}, e1={ang_err1:.4f}) | "
                    f"GT_pos={pos_gt}  PRED_pos={pos_pred}  POS_ERR_mean={np.mean([pos_err0,pos_err1]):.4f} (e0={pos_err0:.4f}, e1={pos_err1:.4f}) -> {overlay_png}"
                )
                err_list.append(np.mean([ang_err0, ang_err1]))
            else:
                pred_angle = float(pred_val.get("angle", 0.0))
                pred_xy = pred_val.get("pos", [0.0, 0.0])
                gt_angle = float(gt.get("angle", 0.0))
                gt_xy = gt.get("pos", [0.0, 0.0])
                # angle error
                def angle_delta_deg(a: float, b: float) -> float:
                    d = abs(a - b) % 360.0
                    if d > 180.0:
                        d = 360.0 - d
                    return d
                angle_err = angle_delta_deg(pred_angle, gt_angle)
                pos_err = math.sqrt((float(pred_xy[0]) - float(gt_xy[0])) ** 2 + (float(pred_xy[1]) - float(gt_xy[1])) ** 2)
                err_name = "ANG_ERR(°)"
                pos_err_name = "POS_ERR(L2)"
                abs_val_diff = None
                scale_ratio = None
                rows.append((name, gt_angle, pred_angle, angle_err, gt_xy, pred_xy, pos_err, None, scale_ratio))
                print(
                    f"[DONE {idx:02d}] {name} | GT_angle={gt_angle:.4f}  PRED_angle={pred_angle:.4f}  {err_name}={angle_err:.4f} | "
                    f"GT_pos={gt_xy}  PRED_pos={pred_xy}  {pos_err_name}={pos_err:.4f}  -> {overlay_png}"
                )
                err_list.append(angle_err)
        else:
            raise ValueError("unsupported predict field")

    # Use unified overlay-based evaluator for IoU
    summary_csv_path, iou_map = run_external_iouEvaluator(batch_run_dir=root, gt_dir=args.json_dir)

    # Merge IoU into rows
    final_rows = []
    iou_vals = []
    # Determine if two-piece for csv output
    any_two = False
    for r in rows:
        if (args.predict == "pos" and len(r) >= 10) or (args.predict == "angle" and len(r) >= 9) or (args.predict == "angle+pos" and len(r) >= 11):
            any_two = True
            break

    # Define err_name for summary printing
    if args.predict == "angle":
        err_name = "ANG_ERR(°)"
    elif args.predict == "scale":
        err_name = "SCALE_ERR"
    elif args.predict == "pos":
        err_name = "POS_ERR_mean" if any_two else "POS_ERR(L2)"
    elif args.predict == "angle+pos":
        err_name = "ANG_ERR(°)"  # not used in two-piece branch but defined for safety
    else:
        err_name = "ERR"

    if args.predict == "angle+pos":
        if any_two:
            for (name, gt_angle, pred_angle, ang_err_mean, gt_pos, pred_pos, pos_err_mean, ang_err0, ang_err1, pos_err0, pos_err1) in rows:
                iou = iou_map.get(name)
                if iou is not None:
                    iou_vals.append(iou)
                final_rows.append((name, gt_angle, pred_angle, ang_err_mean, gt_pos, pred_pos, pos_err_mean, ang_err0, ang_err1, pos_err0, pos_err1, iou))
        else:
            for (name, gt_angle, pred_angle, angle_err, gt_xy, pred_xy, pos_err, _, scale_ratio) in rows:
                iou = iou_map.get(name)
                if iou is not None:
                    iou_vals.append(iou)
                final_rows.append((name, gt_angle, pred_angle, angle_err, gt_xy, pred_xy, pos_err, iou))
    elif args.predict == "angle":
        if any_two:
            for (name, gt0, gt1, pred0, pred1, err_mean, err0, err1, _, _) in rows:
                iou = iou_map.get(name)
                if iou is not None:
                    iou_vals.append(iou)
                final_rows.append((name, gt0, gt1, pred0, pred1, err_mean, err0, err1, iou))
        else:
            for (name, gt_val, pred_val, err, abs_val_diff, _, scale_ratio) in rows:
                iou = iou_map.get(name)
                if iou is not None:
                    iou_vals.append(iou)
                final_rows.append((name, gt_val, pred_val, err, abs_val_diff, iou, scale_ratio))
    else:
        for (name, gt_val, pred_val, err, abs_val_diff, _, scale_ratio) in rows:
            iou = iou_map.get(name)
            if iou is not None:
                iou_vals.append(iou)
            final_rows.append((name, gt_val, pred_val, err, abs_val_diff, iou, scale_ratio))

    print("=" * 60)
    if iou_vals:
        if args.predict == "angle+pos":
            mean_angle_err = np.mean([r[3] for r in final_rows])
            median_angle_err = np.median([r[3] for r in final_rows])
            mean_pos_err = np.mean([r[6] for r in final_rows])
            median_pos_err = np.median([r[6] for r in final_rows])
            print(f"[SUMMARY] N={len(final_rows)}  mean IoU={np.mean(iou_vals):.4f}  median IoU={np.median(iou_vals):.4f}  "
                  f"mean ANG_ERR={mean_angle_err:.4f}  median ANG_ERR={median_angle_err:.4f}  "
                  f"mean POS_ERR={mean_pos_err:.4f}  median POS_ERR={median_pos_err:.4f}")
        else:
            print(f"[SUMMARY] N={len(final_rows)}  mean IoU={np.mean(iou_vals):.4f}  median IoU={np.median(iou_vals):.4f}  "
                  f"mean {err_name}={np.mean(err_list):.4f}  median {err_name}={np.median(err_list):.4f}")
    else:
        if args.predict == "angle+pos":
            mean_angle_err = np.mean([r[3] for r in final_rows])
            median_angle_err = np.median([r[3] for r in final_rows])
            mean_pos_err = np.mean([r[6] for r in final_rows])
            median_pos_err = np.median([r[6] for r in final_rows])
            print(f"[SUMMARY] N={len(final_rows)}  (IoU not available from overlay evaluator)  "
                  f"mean ANG_ERR={mean_angle_err:.4f}  median ANG_ERR={median_angle_err:.4f}  "
                  f"mean POS_ERR={mean_pos_err:.4f}  median POS_ERR={median_pos_err:.4f}")
        else:
            print(f"[SUMMARY] N={len(final_rows)}  (IoU not available from overlay evaluator)  "
                  f"mean {err_name}={np.mean(err_list):.4f}  median {err_name}={np.median(err_list):.4f}")
    print("=" * 60)

    csv_path = os.path.join(root, "results_table.csv")
    with open(csv_path, "w", newline='', encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        if args.predict == "angle" and any_two:
            header = ["sample","GT_angle[0]","GT_angle[1]","PRED_angle[0]","PRED_angle[1]","ANG_ERR_mean","ANG_ERR_0","ANG_ERR_1","IoU"]
            writer.writerow(header)
            for row in final_rows:
                outrow = [row[0],
                          f"{row[1]:.4f}", f"{row[2]:.4f}",
                          f"{row[3]:.4f}", f"{row[4]:.4f}",
                          f"{row[5]:.4f}", f"{row[6]:.4f}", f"{row[7]:.4f}",
                          f"{row[8]:.4f}" if row[8] is not None else ""]
                writer.writerow(outrow)
        elif args.predict == "angle+pos" and any_two:
            header = ["sample","GT_angle","PRED_angle","ANG_ERR_mean","GT_pos","PRED_pos","POS_ERR_mean","ANG_ERR_0","ANG_ERR_1","POS_ERR_0","POS_ERR_1","IoU"]
            writer.writerow(header)
            for row in final_rows:
                # row = (name, gt_angle, pred_angle, ang_err_mean, gt_pos, pred_pos, pos_err_mean, ang_err0, ang_err1, pos_err0, pos_err1, iou)
                outrow = [row[0],
                          str(row[1]), str(row[2]), f"{row[3]:.4f}",
                          str(row[4]), str(row[5]), f"{row[6]:.4f}",
                          f"{row[7]:.4f}", f"{row[8]:.4f}", f"{row[9]:.4f}", f"{row[10]:.4f}", f"{row[11]:.4f}" if row[11] is not None else ""]
                writer.writerow(outrow)
        elif args.predict == "angle+pos":
            header = ["sample", "GT_angle", "PRED_angle", "ANG_ERR(°)", "GT_pos", "PRED_pos", "POS_ERR(L2)", "IoU"]
            writer.writerow(header)
            for row in final_rows:
                writer.writerow(row)
        else:
            header = ["sample", f"GT_{args.predict}", f"PRED_{args.predict}", err_name, "ABS_DIFF", "IoU"]
            if args.predict == "scale":
                header.append("SCALE_RATIO")
            writer.writerow(header)
            for row in final_rows:
                # row = (name, gt_val, pred_val, err, abs_val_diff, iou, scale_ratio)
                if args.predict == "scale":
                    writer.writerow(row)
                else:
                    writer.writerow(row[:-1])  # drop ratio column when not scale
    print(f"[INFO] Results table saved to: {csv_path}")


if __name__ == "__main__":
    # Script entry point: run main() only when this file is executed directly,
    # not when it is imported as a module.
    main()