#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
qwen_geom_param_demo.py  (replaces qwen_angle_demo.py)
- 通用评测脚本：固定 GT 中除一个目标参数外的所有字段，只让 Qwen 预测该参数
- 目前支持：angle（角度，度） 与 scale（缩放因子）
- 预测后把该字段写入 pred.json（其余字段直接继承 GT）
- 用 geometry.py 渲染 GT / Pred / Overlay
- 通过渲染图阈值化计算 IoU，在终端逐样本与汇总打印
"""

import os
import re
import sys
import cv2
import json
import time
import copy 
import glob
import math
import argparse
import subprocess
from dataclasses import dataclass
from typing import Dict, Any, Tuple, List

import numpy as np
from PIL import Image
import torch
import csv

# Keep the tokenizers library quiet (avoids its fork-related warning).
# setdefault leaves any value the user already exported untouched.
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# ===== 模型加载（Qwen2.5-VL 本地权重） =====
from transformers import AutoModelForVision2Seq, AutoModelForCausalLM, AutoTokenizer, AutoProcessor

# Absolute directory of this script; helper scripts are resolved relative to it.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Path of the geometry.py renderer that run_geometry_render() launches as a subprocess.
GEOMETRY_PY = os.path.join(SCRIPT_DIR, "geometry.py")


@dataclass
class ModelBundle:
    """Groups the tokenizer, processor and model that load_qwen() returns together."""
    tok: Any   # tokenizer; used to build the chat prompt and to decode generated ids
    proc: Any  # processor; turns prompt text + image into model input tensors
    mdl: Any   # the model itself; load_qwen() puts it in eval mode before bundling


def _resolve_qwen_snapshot(base_dir: str) -> str:
    """
    Hugging Face cache layout:
    ~/.cache/huggingface/hub/models--ORG--NAME/{blobs,refs,snapshots/<hash>}
    We need the real snapshot folder that contains config.json, model*.bin, etc.
    """
    snaps = os.path.join(base_dir, "snapshots")
    if os.path.isdir(snaps):
        # pick the most recent snapshot by mtime
        cands = [os.path.join(snaps, d) for d in os.listdir(snaps) if os.path.isdir(os.path.join(snaps, d))]
        if cands:
            cands.sort(key=lambda p: os.path.getmtime(p), reverse=True)
            return cands[0]
    # fall back to base_dir (in case of symlinked/flattened cache)
    return base_dir


def load_qwen(model_name: str = None, device: str = "cpu") -> ModelBundle:
    """
    Load the locally cached Qwen2.5-VL-3B-Instruct checkpoint without downloading anything.

    Args:
        model_name: Accepted for interface compatibility only. It is NOT used: the
            checkpoint is always resolved from the local Hugging Face cache directory,
            or from the ``QWEN_LOCAL_DIR`` environment variable when that is set.
        device: Torch device the model is moved to (case-insensitive), e.g. "cpu" or "mps".

    Returns:
        ModelBundle holding the tokenizer, the processor and the model (in eval mode).

    Raises:
        Whatever the AutoModelForCausalLM fallback raises when both loaders fail.
    """
    print("[INFO] Loading local Qwen2.5-VL-3B-Instruct (no download)...")
    default_dir = os.path.expanduser("~/.cache/huggingface/hub/models--Qwen--Qwen2.5-VL-3B-Instruct")
    # Allow override via env var if user moved the cache
    base_dir = os.environ.get("QWEN_LOCAL_DIR", default_dir)
    local_dir = _resolve_qwen_snapshot(base_dir)
    print(f"[INFO] Using snapshot: {local_dir}")

    # Honor the requested device instead of mapping every non-cpu value to "mps".
    target_device = str(device).lower()
    if target_device == "mps":
        os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

    tok = AutoTokenizer.from_pretrained(local_dir, use_fast=True, trust_remote_code=True, local_files_only=True)
    proc = AutoProcessor.from_pretrained(local_dir, trust_remote_code=True, local_files_only=True)

    # Both loaders take exactly the same options; keep them in one place.
    load_kwargs = dict(
        device_map=None,
        torch_dtype=torch.float32,
        trust_remote_code=True,
        local_files_only=True,
    )

    # Try Vision2Seq first (newer Qwen2.5-VL uses this); fall back to CausalLM if needed.
    # The device move stays inside the try so a failed move also triggers the fallback.
    try:
        mdl = AutoModelForVision2Seq.from_pretrained(local_dir, **load_kwargs)
        mdl.to(target_device)
        print("[INFO] Loaded with AutoModelForVision2Seq (local)")
    except Exception as e1:
        print(f"[WARN] Vision2Seq loader failed: {e1}\n[INFO] Falling back to AutoModelForCausalLM ...")
        mdl = AutoModelForCausalLM.from_pretrained(local_dir, **load_kwargs)
        mdl.to(target_device)
        print("[INFO] Loaded with AutoModelForCausalLM (local fallback)")
    mdl.eval()
    return ModelBundle(tok=tok, proc=proc, mdl=mdl)


# =========================
#     Unified inference helpers
# =========================

# System/user prompts for the rotation-angle task. The user prompt asks for a reply of the
# form {"angle": <float>}; run_qwen_angle() passes "angle" as the key to extract.
ANGLE_SYS = "You are a vision assistant for geometry evaluation."
ANGLE_USER = (
    "You are given an image of a single geometric piece inside a 0..10 grid.\n"
    "Your job: RETURN ONLY a compact JSON with the rotation angle of the piece (in degrees).\n"
    "Convention: positive=counter-clockwise, negative=clockwise; range [-180, 180]; 0 means upright.\n"
    "Respond EXACTLY like: {\"angle\": <float>}"
)

# System/user prompts for the scale task. The user prompt asks for a reply of the
# form {"scale": <float>}; run_qwen_scale() passes "scale" as the key to extract.
SCALE_SYS = "You are a vision assistant for geometry evaluation."
SCALE_USER = (
    "You are given an image of a single geometric piece inside a 0..10 grid.\n"
    "Your job: RETURN ONLY a compact JSON with the size scale factor of the piece.\n"
    "Interpret 'scale' as the multiplicative factor applied to a canonical template to match the piece silhouette.\n"
    "Respond EXACTLY like: {\"scale\": <float>}  (no extra keys or text)\n"
    "Typical values are around 0.5~1.2; do not include units."
)


def _chat_generate_json(mbundle: ModelBundle, image_path: str, sys_prompt: str, user_prompt: str,
                        key_name: str, max_new_tokens: int = 64) -> float:
    """
    Ask Qwen for ``{"<key_name>": number}`` about one image and return that number.

    Args:
        mbundle: Tokenizer / processor / model bundle from load_qwen().
        image_path: Path of the image shown to the model.
        sys_prompt: System message text.
        user_prompt: User message text sent together with the image.
        key_name: JSON key whose numeric value is extracted from the reply.
        max_new_tokens: Generation budget for the reply.

    Returns:
        The parsed value as a float. When ``key_name == "angle"`` the value is wrapped
        into [-180, 180). If no ``{"<key_name>": number}`` object is found in the reply,
        a warning is printed and 0.0 is returned.
    """
    messages = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": [
            {"type": "image", "image": image_path},
            {"type": "text", "text": user_prompt},
        ]},
    ]
    conversation = mbundle.tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    # Close the file handle deterministically; convert() loads the pixels and returns an
    # independent in-memory image that stays valid after the file is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = mbundle.proc(
        text=[conversation],
        images=[image],
        return_tensors="pt"
    )
    dev = next(mbundle.mdl.parameters()).device
    inputs = inputs.to(dev)
    # Greedy decoding: `temperature` is only used when sampling, so it is not passed.
    gen = mbundle.mdl.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        repetition_penalty=1.0,
    )
    # generate() returns prompt + continuation; decode only the newly generated tokens so
    # the regex below can never match text echoed from the prompt.
    prompt_len = inputs["input_ids"].shape[1]
    out_text = mbundle.tok.decode(gen[0][prompt_len:], skip_special_tokens=True)

    # Capture {"key": number}; accepts 45, 45., .5, -12.3 and 1e-1 style literals.
    number = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
    pattern = rf'\{{\s*"{re.escape(key_name)}"\s*:\s*({number})\s*\}}'
    m = re.search(pattern, out_text)
    if not m:
        print(f"[WARN] Qwen {key_name} JSON not found, fallback 0.")
        return 0.0
    val = float(m.group(1))
    if key_name == "angle":
        # Wrap the angle into [-180, 180).
        val = ((val + 180.0) % 360.0) - 180.0
    return val


def run_qwen_angle(mbundle: ModelBundle, image_path: str) -> float:
    """Query the model with the angle prompts and return the value parsed for the "angle" key."""
    return _chat_generate_json(
        mbundle,
        image_path,
        sys_prompt=ANGLE_SYS,
        user_prompt=ANGLE_USER,
        key_name="angle",
    )


def run_qwen_scale(mbundle: ModelBundle, image_path: str) -> float:
    """Query the model with the scale prompts and return the value parsed for the "scale" key."""
    return _chat_generate_json(
        mbundle,
        image_path,
        sys_prompt=SCALE_SYS,
        user_prompt=SCALE_USER,
        key_name="scale",
    )


# ========== 数据与工具 ==========

def list_pairs(img_dir: str, json_dir: str, exts=(".png", ".jpg", ".jpeg")) -> List[Tuple[str, str]]:
    """
    Pair every image in *img_dir* with the same-named ``.json`` file in *json_dir*.

    Args:
        img_dir: Directory scanned for images (non-recursive; hidden files are skipped
            because the listing goes through ``glob`` with a ``*`` pattern).
        json_dir: Directory expected to contain ``<image stem>.json``.
        exts: Accepted image suffixes, matched case-insensitively. May be a single
            string or any iterable of strings.

    Returns:
        ``(image_path, json_path)`` tuples sorted by image path. Images without a
        matching JSON file are omitted.
    """
    if isinstance(exts, str):
        exts = (exts,)
    # str.endswith needs a tuple; lowercase so ".PNG" in exts behaves like ".png".
    suffixes = tuple(ext.lower() for ext in exts)

    # Escape the directory so characters such as "[" or "*" in its name are taken
    # literally instead of being interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(img_dir), "*")

    pairs = []
    for img_path in sorted(glob.glob(pattern)):
        if not img_path.lower().endswith(suffixes):
            continue
        stem = os.path.splitext(os.path.basename(img_path))[0]
        json_path = os.path.join(json_dir, f"{stem}.json")
        if os.path.exists(json_path):
            pairs.append((img_path, json_path))
    return pairs


def must_exist(path: str):
    """Do nothing when *path* exists; raise FileNotFoundError otherwise."""
    if os.path.exists(path):
        return
    raise FileNotFoundError(f"[PATH ERROR] Not found: {path}")


def load_json(path: str) -> Dict[str, Any]:
    """Read the UTF-8 encoded JSON file at *path* and return its decoded content."""
    with open(path, encoding="utf-8") as handle:
        parsed = json.load(handle)
    return parsed


def save_json(obj: Dict[str, Any], path: str):
    """Write *obj* to *path* as UTF-8 JSON, 2-space indented, non-ASCII characters kept as-is."""
    with open(path, mode="w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2, ensure_ascii=False)


def ensure_dir(d: str):
    """Create directory *d* together with any missing parents; an existing directory is fine."""
    os.makedirs(name=d, exist_ok=True)


def angle_delta_deg(a: float, b: float) -> float:
    """Return the smallest absolute difference between two angles in degrees, in [0, 180]."""
    around = abs(a - b) % 360.0
    # Going the other way round the circle may be shorter.
    return min(around, 360.0 - around)


# ========== Rendering & IoU ==========
# geometry.py is invoked to produce:
#   - gt_render.png
#   - pred_render.png
#   - overlay_json_vs_pred.png
# iou_from_pair() can then threshold the GT / Pred renders to compute an IoU.

def run_geometry_render(gt_json_path: str, pred_json_path: str, out_dir: str):
    """
    Render GT, prediction and overlay images by running geometry.py three times.

    Each run is a subprocess using the current interpreter; a non-zero exit status
    raises subprocess.CalledProcessError (check=True).

    Returns:
        (gt_render.png, pred_render.png, overlay_json_vs_pred.png) paths inside *out_dir*.
    """
    ensure_dir(out_dir)

    gt_png = os.path.join(out_dir, "gt_render.png")
    pred_png = os.path.join(out_dir, "pred_render.png")
    overlay_png = os.path.join(out_dir, "overlay_json_vs_pred.png")

    # (input flags, output image) per job, listed in execution order: GT, Pred, Overlay.
    jobs = (
        (["--ann", gt_json_path], gt_png),
        (["--ann", pred_json_path], pred_png),
        (["--overlay_gt_json", gt_json_path, "--overlay_pred_json", pred_json_path], overlay_png),
    )
    for input_flags, png_path in jobs:
        command = [sys.executable, GEOMETRY_PY, *input_flags, "--save_png", png_path, "--no_show"]
        subprocess.run(command, check=True)

    return gt_png, pred_png, overlay_png


def _to_mask(img: Image.Image) -> np.ndarray:
    """
    Convert a geometry.py render into a 0/1 uint8 mask of its largest bright region.

    Pixels whose HSV value channel exceeds 0.6 (after scaling to 0..1) are treated as
    foreground, then only the largest connected component is kept to drop grid lines
    and text. The threshold is empirical: it presumes the filled polygon is brighter
    than the background, so adjust it if the render colour scheme changes.

    Returns:
        Array with the image's height x width, dtype uint8, values 0 or 1. If the
        thresholded image has no foreground component, the all-zero mask is returned.
    """
    rgb = np.array(img.convert("RGB"))
    hsv = cv2.cvtColor(rgb, cv2.COLOR_RGB2HSV).astype(np.float32) / 255.0
    mask = (hsv[..., 2] > 0.6).astype(np.uint8)

    # Label 0 is the background; labels 1..num-1 are foreground components.
    num, lbl = cv2.connectedComponents(mask)
    if num <= 1:
        return mask

    # One pass over the label image gives every component's area, instead of
    # rescanning the whole image once per component.
    areas = np.bincount(lbl.ravel(), minlength=num)
    # argmax returns the first maximum, i.e. the lowest label wins a tie.
    best = 1 + int(np.argmax(areas[1:]))
    return (lbl == best).astype(np.uint8)


def iou_from_pair(gt_png: str, pred_png: str) -> float:
    """
    Compute the IoU between the foreground masks of two rendered images.

    Args:
        gt_png: Path of the ground-truth render.
        pred_png: Path of the prediction render; resized (nearest neighbour) to the
            ground-truth size when the two sizes differ.

    Returns:
        Intersection over union of the two _to_mask() masks as a plain Python float;
        0.0 when both masks are empty.
    """
    # Context managers close both image files once the masks have been extracted.
    with Image.open(gt_png) as gt_img, Image.open(pred_png) as pred_img:
        if gt_img.size != pred_img.size:
            pred_view = pred_img.resize(gt_img.size, Image.NEAREST)
        else:
            pred_view = pred_img
        m1 = _to_mask(gt_img)
        m2 = _to_mask(pred_view)

    inter = np.logical_and(m1, m2).sum()
    union = np.logical_or(m1, m2).sum()
    if union == 0:
        return 0.0
    # Convert both operands so the result is a built-in float, as annotated.
    return float(inter) / float(union)


def _read_iou_summary(summary_csv_path: str) -> Dict[str, float]:
    """Best-effort parse of the evaluator's summary CSV into {sample stem -> IoU}; {} on failure."""
    # Column names may vary slightly between evaluator versions; match case-insensitively.
    sample_names = ("sample", "name", "id", "file", "basename")
    iou_names = ("iou", "iou_score", "mean_iou")

    iou_map: Dict[str, float] = {}
    try:
        # newline="" is what the csv module expects for files it reads.
        with open(summary_csv_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.DictReader(f)
            fields = reader.fieldnames or []
            sample_key = next((k for k in fields if k.lower() in sample_names), None)
            iou_key = next((k for k in fields if k.lower() in iou_names), None)
            if sample_key is None or iou_key is None:
                print(f"[WARN] IoU summary CSV has no recognizable sample/IoU columns: {fields}")
                return iou_map
            for row in reader:
                key = row.get(sample_key)
                iou_str = row.get(iou_key)
                # Short rows are padded with None by DictReader; skip them.
                if key is None or iou_str is None:
                    continue
                try:
                    iou = float(iou_str)
                except (TypeError, ValueError):
                    # Non-numeric cell (e.g. empty): leave this sample without an IoU.
                    continue
                iou_map[os.path.splitext(os.path.basename(key))[0]] = iou
    except (OSError, UnicodeError, csv.Error) as e:
        print(f"[WARN] Failed to parse IoU summary CSV: {e}")
    return iou_map


def run_external_iouEvaluator(batch_run_dir: str, gt_dir: str) -> Tuple[str, Dict[str, float]]:
    """
    Run iou_from_overlay.py with the unified flags and return (csv_path, {sample->IoU}).

    Args:
        batch_run_dir: Run directory passed as ``--batch_my_run``; the summary CSV
            ``iou_summary.csv`` is written inside it.
        gt_dir: Ground-truth directory passed as ``--gt_dir``.

    Returns:
        The summary CSV path and a map from sample name (file stem) to IoU. The map is
        empty when the CSV cannot be read or has no recognizable columns.

    Raises:
        subprocess.CalledProcessError: if the evaluator exits with a non-zero status.
    """
    summary_csv_path = os.path.join(batch_run_dir, "iou_summary.csv")
    cmd = [
        sys.executable,
        os.path.join(SCRIPT_DIR, "iou_from_overlay.py"),
        "--batch_my_run", batch_run_dir,
        "--gt_dir", gt_dir,
        "--map_to_axes_geom",
        "--gt_anchor", "centroid",
        "--pred_anchor", "centroid",
        "--size_mode_geom", "keep",
        "--dilate", "2",
        "--write_iou_txt",
        "--summary_csv", summary_csv_path,
        "--save_debug",
        "--verbose",
    ]
    print("[EVAL] Running overlay IoU evaluator...\n ", " ".join(cmd))
    subprocess.run(cmd, check=True)
    print(f"[EVAL] Overlay-based IoU summary written to: {summary_csv_path}")

    return summary_csv_path, _read_iou_summary(summary_csv_path)


# ========== 主流程 ==========

def main():
    """
    CLI entry point: predict one parameter per sample with Qwen and evaluate it.

    For every (image, GT JSON) pair the selected field ("angle" or "scale") is predicted,
    written into a copy of the GT as ``pred.json``, and rendered with geometry.py. After
    all samples are processed the external overlay evaluator supplies per-sample IoU,
    a summary is printed, and ``results_table.csv`` is written to the run directory.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--img_dir", type=str, required=True, help="图片目录（与 JSON 同名）")
    ap.add_argument("--json_dir", type=str, required=True, help="GT JSON 目录")
    ap.add_argument("--n", type=int, default=8, help="最多多少个样本（--all 覆盖）")
    # NOTE: --seed is parsed for CLI compatibility but nothing in this script reads it.
    ap.add_argument("--seed", type=int, default=2025)
    ap.add_argument("--all", action="store_true", help="处理目录下所有样本")
    ap.add_argument("--out_dir", type=str, default="./runs/param_eval")
    # NOTE: --model is forwarded to load_qwen(), which does not use its model_name argument.
    ap.add_argument("--model", type=str, default="Qwen/Qwen2.5-VL-7B-Instruct")
    ap.add_argument("--device", type=str, default="cpu", choices=["cpu", "mps"], help="force device; default cpu to avoid MPS matmul bugs")

    # Key option: which field is predicted.
    ap.add_argument("--predict", type=str, default="angle", choices=["angle", "scale", "size"],
                    help="只预测该字段，其余字段固定为 GT。'size' 为 'scale' 同义词。")

    args = ap.parse_args()
    if args.predict == "size":
        args.predict = "scale"

    must_exist(args.img_dir)
    must_exist(args.json_dir)
    ensure_dir(args.out_dir)

    pairs = list_pairs(args.img_dir, args.json_dir)
    if not pairs:
        raise FileNotFoundError(f"No pairs in {args.img_dir} vs {args.json_dir}")
    if not args.all:
        pairs = pairs[: args.n]
        if not pairs:
            # e.g. --n 0: fail early instead of crashing after the evaluator has run.
            ap.error(f"--n={args.n} selects no samples; pass a positive value or --all")

    # Per-task settings, resolved once so they are defined even outside the sample loop.
    if args.predict == "angle":
        predictor, gt_default, err_name = run_qwen_angle, 0.0, "ANG_ERR(°)"
    elif args.predict == "scale":
        predictor, gt_default, err_name = run_qwen_scale, 1.0, "SCALE_ERR"
    else:
        raise ValueError("unsupported predict field")

    ts = time.strftime("%Y%m%d_%H%M%S")
    root = os.path.join(args.out_dir, ts)
    ensure_dir(root)
    print(f"[OUT] results -> {root}")
    print(f"[TASK] predict = {args.predict}  (others fixed from GT)")

    # Load the model.
    mb = load_qwen(args.model, device=args.device)

    err_list: List[float] = []
    scale_ratio_list: List[float] = []
    rows = []  # (name, gt_val, pred_val, err, abs_val_diff, scale_ratio)

    for idx, (img_path, gt_json_path) in enumerate(pairs, 1):
        name = os.path.splitext(os.path.basename(img_path))[0]
        work = os.path.join(root, name)
        ensure_dir(work)

        gt = load_json(gt_json_path)

        # 1) Qwen predicts only the target field.
        pred_val = float(predictor(mb, img_path))
        gt_val = float(gt.get(args.predict, gt_default))

        # 2) Build the prediction: every other field is inherited from the GT.
        pred = copy.deepcopy(gt)
        pred[args.predict] = pred_val

        # 3) Save pred.json.
        pred_json_path = os.path.join(work, "pred.json")
        save_json(pred, pred_json_path)

        # 4) Render GT / Pred / Overlay; only the overlay path is reported below.
        _, _, overlay_png = run_geometry_render(gt_json_path, pred_json_path, work)

        # 5) Error: wrapped angular distance for angles, plain absolute difference for scale.
        abs_val_diff = abs(pred_val - gt_val)
        if args.predict == "angle":
            err = angle_delta_deg(pred_val, gt_val)
        else:
            err = abs_val_diff

        # Extra metric, scale task only: SCALE_RATIO = PRED / GT (undefined when GT is 0).
        scale_ratio = None
        if args.predict == "scale" and gt_val != 0:
            scale_ratio = pred_val / gt_val
            scale_ratio_list.append(scale_ratio)

        err_list.append(err)
        rows.append((name, gt_val, pred_val, err, abs_val_diff, scale_ratio))

        line = (
            f"[DONE {idx:02d}] {name} | GT_{args.predict}={gt_val:.4f}  "
            f"PRED_{args.predict}={pred_val:.4f}  {err_name}={err:.4f}  ABS_DIFF={abs_val_diff:.4f}  "
        )
        if scale_ratio is not None:
            line += f"SCALE_RATIO(P/G)={scale_ratio:.4f}  "
        print(line + f"-> {overlay_png}")

    # Use unified overlay-based evaluator for IoU
    summary_csv_path, iou_map = run_external_iouEvaluator(batch_run_dir=root, gt_dir=args.json_dir)

    # Merge IoU into rows; samples missing from the evaluator output keep IoU = None.
    final_rows = []
    iou_vals = []
    for (name, gt_val, pred_val, err, abs_val_diff, scale_ratio) in rows:
        iou = iou_map.get(name)
        if iou is not None:
            iou_vals.append(iou)
        final_rows.append((name, gt_val, pred_val, err, abs_val_diff, iou, scale_ratio))

    if iou_vals:
        iou_part = f"mean IoU={np.mean(iou_vals):.4f}  median IoU={np.median(iou_vals):.4f}"
    else:
        iou_part = "(IoU not available from overlay evaluator)"
    print("=" * 60)
    print(f"[SUMMARY] N={len(final_rows)}  {iou_part}  "
          f"mean {err_name}={np.mean(err_list):.4f}  median {err_name}={np.median(err_list):.4f}")
    print("=" * 60)

    csv_path = os.path.join(root, "results_table.csv")
    with open(csv_path, "w", newline='', encoding="utf-8") as csvfile:
        writer = csv.writer(csvfile)
        header = ["sample", f"GT_{args.predict}", f"PRED_{args.predict}", err_name, "ABS_DIFF", "IoU"]
        if args.predict == "scale":
            header.append("SCALE_RATIO")
        writer.writerow(header)
        for row in final_rows:
            # row = (name, gt_val, pred_val, err, abs_val_diff, iou, scale_ratio)
            if args.predict == "scale":
                writer.writerow(row)
            else:
                writer.writerow(row[:-1])  # drop ratio column when not scale
    print(f"[INFO] Results table saved to: {csv_path}")


if __name__ == "__main__":
    main()