import os, re, json, math, argparse, random, textwrap
import numpy as np
from PIL import Image, ImageDraw, ImageOps
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import matplotlib.pyplot as plt
import csv
import os.path as osp
from copy import deepcopy

# ----------------- Helpers (missing in this file) -----------------
def must_exist(p: str):
    """Fail fast with ``FileNotFoundError(p)`` when the path ``p`` is absent.

    Returns ``None`` when the path exists (file or directory).
    """
    if os.path.exists(p):
        return
    raise FileNotFoundError(p)

def round_floats(o):
    """Return a copy of ``o`` with every float rounded to 6 decimal places.

    Dicts and lists are walked recursively (and rebuilt as plain ``dict`` /
    ``list``); any other value, including tuples, is returned unchanged.
    """
    if isinstance(o, float):
        return round(o, 6)
    if isinstance(o, list):
        return list(map(round_floats, o))
    if isinstance(o, dict):
        return dict((key, round_floats(val)) for key, val in o.items())
    return o

def load_json_str(path: str) -> str:
    """Load a JSON file and re-serialize it as a compact one-line string.

    Used to feed ground-truth JSON to the model as few-shot text. The
    structure is kept as-is, but floats are passed through ``round_floats``
    before dumping; non-ASCII characters are written unescaped.
    """
    with open(path, "r", encoding="utf-8") as fh:
        payload = round_floats(json.load(fh))
    return json.dumps(payload, ensure_ascii=False)

def strict_find_json(text: str) -> str | None:
    """
    从模型原始输出中提取最后一个平衡的JSON对象/数组片段。
    允许前后夹杂说明性文字；返回None表示未找到。
    """
    stack = []
    start = -1
    in_string = False
    escape = False
    last = None
    for i, ch in enumerate(text):
        if escape:
            escape = False
            continue
        if ch == '\\':
            escape = True
            continue
        if ch == '"':
            in_string = not in_string
            continue
        if in_string:
            continue
        if ch in '{[':
            if not stack:
                start = i
            stack.append(ch)
        elif ch in '}]':
            if stack:
                stack.pop()
                if not stack and start != -1:
                    last = text[start:i+1]
    # 去掉围栏代码块 ```...```
    if last and last.strip().startswith("```"):
        import re as _re
        m = _re.search(r"```(?:json)?\\s*([\\s\\S]*?)```", last)
        if m:
            return m.group(1).strip()
    return last.strip() if last else None

# ----------------- Prompts -----------------
# Instruction text that run_one places in every user turn (few-shot examples
# and the final query). It fixes the output schema: 7 objects, the field
# names, the allowed angle values, and "pure JSON only".
INSTR = textwrap.dedent("""
You are now a tangram puzzle solver.
Task: Output STRICT JSON, exactly 7 objects. Each object MUST have the following fields:
  name: one of [T1,T2,T3,T4,T5,P1,S1]
  type: one of [triangle,square,parallelogram]
  size: one of [large,medium,small,na]
  pos: [x,y] two floats (0..10 canvas grid)
  angle: float, MUST be chosen from {-135,-90,-45,0,45,90,135,180}
  flip: boolean
  scale: float
Return PURE JSON only, no extra text, no Markdown.

Constraints:
- Pieces must NOT share identical pos; avoid stacking at the same coordinates.
- Prefer non-zero angles when needed to match slanted edges.
- Pieces should not significantly overlap each other; use the full canvas area to cover the silhouette.
""").strip()

# Reminder of the fixed name -> (type, size) roster. run_one appends it to the
# final prompt only when its `teach` flag is true.
CHEATSHEET = """
Tangram 7 fixed pieces:
- T1: triangle large
- T2: triangle large
- T3: triangle medium
- T4: triangle small
- T5: triangle small
- P1: parallelogram na
- S1: square na
"""

# ----------------- Geometry utils (intended to mirror geometry.py, not visible here) -----------------
# Convention: model output lives on a 0..10 grid; rendering and evaluation both use a 512x512 pixel canvas.
CANVAS_SIZE = 512  # fixed canvas edge length (pixels) used for geometric evaluation

def _poly_triangle(size):
    """Return the vertices of a right isosceles triangle with its right angle at the origin.

    ``size`` selects the leg length: ``"large"`` -> 1.0, ``"medium"`` ->
    sqrt(2)/2, ``"small"`` -> 0.5. Any other value raises ``KeyError``.
    """
    leg_by_size = {
        "large": 1.0,
        "medium": math.sqrt(2) / 2.0,
        "small": 0.5,
    }
    leg = leg_by_size[size]
    return [(0.0, 0.0), (leg, 0.0), (0.0, leg)]

def _poly_square():
    """Return the unit square's corners in the order (0,0), (1,0), (1,1), (0,1)."""
    side = 1.0
    return [(0.0, 0.0), (side, 0.0), (side, side), (0.0, side)]

def _poly_parallelogram():
    """Return a parallelogram with base 1.0 and height 0.5, its top edge shifted right by 0.5."""
    base, shear, height = 1.0, 0.5, 0.5
    return [(0.0, 0.0), (base, 0.0), (base + shear, height), (shear, height)]

def _apply_transform(poly, pos, angle_deg, flip, scale, px):
    """Place a template polygon on the canvas and return its pixel vertices.

    Each vertex goes through, in this order:
      1. optional flip (x -> -x),
      2. uniform scaling by ``scale``,
      3. rotation by ``angle_deg`` degrees about the origin,
      4. translation by ``pos`` (grid units),
      5. multiplication by ``px`` to convert grid units to pixels.

    Returns a list of ``(X, Y)`` tuples, one per input vertex.
    """
    theta = math.radians(float(angle_deg))
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)

    def _place(vertex):
        vx, vy = vertex
        if flip:
            vx = -vx
        vx = vx * float(scale)
        vy = vy * float(scale)
        rx = vx * cos_t - vy * sin_t
        ry = vx * sin_t + vy * cos_t
        return ((rx + float(pos[0])) * px, (ry + float(pos[1])) * px)

    return [_place(vertex) for vertex in poly]

def _pieces_to_polys_in_pixels(pieces, w, h):
    """Convert piece dicts into pixel-space polygons.

    The grid-to-pixel factor is ``max(w, h) / 10`` so both axes share one
    scale. A piece whose ``type`` is neither ``"triangle"`` nor ``"square"``
    is drawn as the parallelogram template.
    """
    grid_to_px = max(w, h) / 10.0
    placed = []
    for piece in pieces:
        kind = piece["type"]
        if kind == "triangle":
            template = _poly_triangle(piece["size"])
        elif kind == "square":
            template = _poly_square()
        else:
            template = _poly_parallelogram()
        placed.append(
            _apply_transform(
                template,
                piece["pos"],
                piece["angle"],
                piece["flip"],
                piece["scale"],
                grid_to_px,
            )
        )
    return placed

def render_mask(pieces, w=CANVAS_SIZE, h=CANVAS_SIZE):
    """Rasterize ``pieces`` into a ``uint8`` mask built from a ``w`` x ``h`` image.

    Pixels covered by at least one piece are 1, the rest 0.
    """
    canvas = Image.new("L", (w, h), 0)
    pen = ImageDraw.Draw(canvas)
    outlines = _pieces_to_polys_in_pixels(pieces, w, h)
    for outline in outlines:
        pen.polygon(outline, fill=1)
    return np.array(canvas, dtype=np.uint8)

def load_silhouette(path, size=CANVAS_SIZE, thresh=128):
    """Load an outline PNG as a binary foreground mask of shape ``(size, size)``.

    The image is converted to grayscale and resized with nearest-neighbour
    sampling; pixels darker than ``thresh`` are treated as foreground (1),
    i.e. the silhouette is assumed to be drawn in black.
    """
    gray = Image.open(path).convert("L")
    gray = gray.resize((size, size), Image.NEAREST)
    pixels = np.array(gray, dtype=np.uint8)
    foreground = pixels < thresh
    return foreground.astype(np.uint8)

def geometry_metrics(pred_mask, gt_mask):
    """Compare a predicted mask with a ground-truth mask.

    Returns ``(iou, overflow, undercov)`` where
      * ``iou`` is intersection / union (the union is clamped to at least 1,
        so two empty masks give 0),
      * ``overflow`` is the fraction of all pixels predicted but not in GT,
      * ``undercov`` is the fraction of all pixels in GT but not predicted.
    """
    pred_b = pred_mask.astype(bool)
    gt_b = gt_mask.astype(bool)
    total = pred_b.size

    overlap = np.logical_and(pred_b, gt_b).sum()
    covered = np.logical_or(pred_b, gt_b).sum()
    spill = np.logical_and(pred_b, np.logical_not(gt_b)).sum()
    missed = np.logical_and(gt_b, np.logical_not(pred_b)).sum()

    return overlap / max(1, covered), spill / total, missed / total

def save_overlay_on_outline(outline_png_path: str, pieces: list, save_path: str, size=CANVAS_SIZE):
    """Draw the predicted pieces semi-transparently over the outline image and save it.

    The axes use image coordinates (origin top-left, y pointing down), the
    same pixel coordinates produced by ``_pieces_to_polys_in_pixels``.
    """
    # Load the outline and bring it to the common square size.
    background = Image.open(outline_png_path).convert("RGB")
    background = background.resize((size, size), Image.NEAREST)
    W, H = background.size

    # Matplotlib is used because it gives translucent fills with an edge line.
    plt.figure(figsize=(5, 5))
    plt.imshow(background)
    axes = plt.gca()

    # x grows to the right; the y limits are reversed so y grows downward.
    axes.set_xlim(0, W)
    axes.set_ylim(H, 0)
    axes.axis("off")

    for outline in _pieces_to_polys_in_pixels(pieces, W, H):
        # Repeat the first vertex so the drawn ring is explicitly closed.
        ring = list(outline) + [outline[0]]
        xs = [pt[0] for pt in ring]
        ys = [pt[1] for pt in ring]
        axes.fill(xs, ys, alpha=0.30, linewidth=1.5, edgecolor="black")

    plt.tight_layout(pad=0)
    plt.savefig(save_path, dpi=160)
    plt.close()

# ----------------- Prediction JSON sanitization (force a strict 7-piece set with valid fields) -----------------
# Canonical roster: piece name -> (type, size). sanitize_pred requires every
# name to be present and overwrites the model's type/size with these values.
ROSTER = {
    "T1": ("triangle", "large"),
    "T2": ("triangle", "large"),
    "T3": ("triangle", "medium"),
    "T4": ("triangle", "small"),
    "T5": ("triangle", "small"),
    "P1": ("parallelogram", "na"),
    "S1": ("square", "na"),
}
# Allowed rotation angles in degrees (multiples of 45 covering the full circle once).
ANG_SET = [-135, -90, -45, 0, 45, 90, 135, 180]

def _closest_angle(a):
    """Snap an arbitrary angle to the nearest entry of ``ANG_SET``.

    Distance is measured on the circle, so equivalent angles outside
    (-180, 180] snap to the right representative (270 -> -90, -180 -> 180,
    360 -> 0) instead of the numerically nearest list value.

    Args:
        a: Anything ``float()`` accepts (number or numeric string).

    Returns:
        An element of ``ANG_SET``; ``0.0`` when ``a`` is not a finite number.
        On an exact tie the earlier ``ANG_SET`` entry wins.
    """
    try:
        af = float(a)
    except (TypeError, ValueError, OverflowError):
        return 0.0
    if not math.isfinite(af):
        # NaN/inf have no meaningful nearest angle; use the same fallback as
        # unparsable input.
        return 0.0

    def _circular_gap(candidate):
        gap = abs(candidate - af) % 360.0
        return min(gap, 360.0 - gap)

    return min(ANG_SET, key=_circular_gap)

def _sanitize_finite_float(value):
    """Return ``value`` as a finite float, or ``None`` if it cannot be one."""
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        return None
    return number if math.isfinite(number) else None


def _sanitize_flag(value):
    """Interpret a JSON-ish boolean; a string such as "false" must not count as true."""
    if isinstance(value, str):
        return value.strip().lower() in {"true", "1", "yes", "y", "t"}
    return bool(value)


def sanitize_pred(obj):
    """Normalize arbitrary model output into the strict 7-piece list.

    Args:
        obj: Either a list of piece dicts or a dict with a ``"pieces"`` list.

    Returns:
        A list of 7 dicts in ``ROSTER`` order with exactly the keys
        ``name, type, size, pos, angle, flip, scale``. ``type``/``size`` are
        forced to the roster values and extra fields are dropped. Field
        fallbacks: an unusable ``pos`` becomes ``[0.0, 0.0]``, an unusable
        ``scale`` becomes ``1.0``, ``angle`` is snapped by ``_closest_angle``
        and ``flip`` is read leniently (string "false" is False).

    Raises:
        ValueError: if ``obj`` is not list-like as described, or a roster
            piece is missing.
    """
    items = obj["pieces"] if isinstance(obj, dict) and "pieces" in obj else obj
    if not isinstance(items, list):
        raise ValueError("prediction must be a list or {'pieces': [...]}")

    # Index by name; the first occurrence of each roster name wins.
    by_name = {}
    for p in items:
        if not isinstance(p, dict):
            continue
        nm = p.get("name")
        # The isinstance check keeps an unhashable name (e.g. a list) from
        # raising TypeError in the membership test.
        if isinstance(nm, str) and nm in ROSTER and nm not in by_name:
            by_name[nm] = p

    pieces = []
    for nm, (ty, sz) in ROSTER.items():
        if nm not in by_name:
            raise ValueError(f"missing piece {nm}")
        p = by_name[nm]

        raw_pos = p.get("pos")
        try:
            x = _sanitize_finite_float(raw_pos[0])
            y = _sanitize_finite_float(raw_pos[1])
        except (TypeError, KeyError, IndexError):
            x = y = None
        pos = [0.0, 0.0] if x is None or y is None else [x, y]

        # Same leniency as pos: one malformed scale should not discard the
        # whole prediction.
        scale = _sanitize_finite_float(p.get("scale", 1.0))
        if scale is None:
            scale = 1.0

        pieces.append({
            "name": nm,
            "type": ty,
            "size": sz,
            "pos": pos,
            "angle": _closest_angle(p.get("angle", 0)),
            "flip": _sanitize_flag(p.get("flip", False)),
            "scale": scale,
        })
    return pieces

# ----------------- Few-shot run -----------------
def run_one(model, processor, fewshot, test_outline_png, max_new=900, teach=True):
    """Run one few-shot inference and return the sanitized prediction.

    Args:
        model: Generation model exposing ``generate`` (loaded by the caller).
        processor: Matching processor (chat template, tokenization, decoding).
        fewshot: Iterable of ``(colored_image_path, gt_json_string)`` pairs;
            each becomes a user turn (image + instructions) followed by an
            assistant turn containing the JSON.
        test_outline_png: Path of the outline image to solve.
        max_new: Maximum number of newly generated tokens.
        teach: When true, ``CHEATSHEET`` is appended to the final prompt.

    Returns:
        ``(pieces, raw)`` where ``raw`` is the decoded model text and
        ``pieces`` is the sanitized 7-piece list, or ``None`` when no valid
        JSON could be extracted/sanitized.
    """
    messages = []
    # Few-shot turns: colored image + instructions, answered by the GT JSON.
    for img_path, json_str in fewshot:
        messages.append({
            "role": "user",
            "content": [{"type": "image", "image": img_path},
                        {"type": "text", "text": INSTR + "\nSolved example (JSON below):"}],
        })
        messages.append({"role": "assistant", "content": [{"type": "text", "text": json_str}]})

    final_prompt = INSTR + ("\n" + CHEATSHEET if teach else "") + "\nNow solve this outline and return JSON."
    messages.append({
        "role": "user",
        "content": [{"type": "image", "image": test_outline_png}, {"type": "text", "text": final_prompt}],
    })

    chat = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    imgs, vids = process_vision_info(messages)
    inputs = processor(text=[chat], images=imgs, videos=vids, return_tensors="pt", padding=True)

    # Follow the model's device instead of hard-coding "cpu" so the same code
    # works if the caller loads the model elsewhere; fall back to CPU when
    # the model object exposes no `device` attribute.
    device = getattr(model, "device", "cpu")
    inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

    with torch.no_grad():
        out_ids = model.generate(**inputs, max_new_tokens=max_new, do_sample=True, temperature=0.35, top_p=0.9)
    # Drop the prompt tokens so only the newly generated part is decoded.
    trimmed = [o[len(i):] for i, o in zip(inputs["input_ids"], out_ids)]
    raw = processor.batch_decode(trimmed, skip_special_tokens=True)[0].strip()

    js = strict_find_json(raw)
    if not js:
        return None, raw
    try:
        return sanitize_pred(json.loads(js)), raw
    except Exception as e:
        # Deliberately best-effort: any malformed prediction counts as a miss
        # for this sample, but the reason is reported instead of swallowed.
        print(f"[WARN] rejected model output: {type(e).__name__}: {e}")
        return None, raw

# ----------------- Main -----------------
def list_ids_from_dir(img_dir):
    """Return the sorted base names (extension removed) of the ``.png`` files in ``img_dir``.

    The suffix match is case-sensitive.
    """
    stems = []
    for entry in os.listdir(img_dir):
        if not entry.endswith(".png"):
            continue
        stem, _ext = os.path.splitext(entry)
        stems.append(stem)
    stems.sort()
    return stems

def main():
    """Command-line entry point: few-shot tangram evaluation.

    Steps: pick IDs that have a colored image, an outline and a GT JSON;
    split them into a few-shot pool and a test set; run the model on every
    test outline; score the rendered prediction against the outline mask;
    write per-sample artifacts, a CSV of metrics and an IoU bar chart.

    Raises:
        RuntimeError: when fewer than ``n_train + n_test`` usable IDs exist.
        FileNotFoundError: when an expected image/JSON file is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--params_dir", required=True)
    ap.add_argument("--outlines_dir", required=True)
    ap.add_argument("--input_dir", required=True)      # params_vis (colored)
    ap.add_argument("--n_train", type=int, default=20) # few-shot pool size
    ap.add_argument("--n_test",  type=int, default=15)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--save_csv", default="tangram_eval.csv")
    ap.add_argument("--save_png", default="tangram_eval.png")
    ap.add_argument("--teach", action="store_true")
    args = ap.parse_args()

    random.seed(args.seed)
    # Generation samples (do_sample=True in run_one), so seed torch as well;
    # otherwise --seed only fixes the train/test split.
    torch.manual_seed(args.seed)

    # Keep only IDs that have a colored image, an outline image and a JSON file.
    all_colored = set(list_ids_from_dir(args.input_dir))
    all_outlines = set(list_ids_from_dir(args.outlines_dir))
    all_json = {os.path.splitext(fn)[0] for fn in os.listdir(args.params_dir) if fn.endswith(".json")}
    ids = sorted(all_colored & all_outlines & all_json)
    if len(ids) < args.n_train + args.n_test:
        raise RuntimeError(f"Not enough IDs. Have {len(ids)}, need {args.n_train+args.n_test}.")

    random.shuffle(ids)
    train_ids = ids[:args.n_train]
    test_ids = ids[args.n_train:args.n_train+args.n_test]

    # dirname() is "" for a bare file name (the defaults), and
    # os.makedirs("") raises; fall back to the current directory.
    csv_dir = osp.dirname(args.save_csv) or "."
    png_dir = osp.dirname(args.save_png) or "."
    os.makedirs(csv_dir, exist_ok=True)
    os.makedirs(png_dir, exist_ok=True)
    viz_dir = osp.join(png_dir, "viz")
    os.makedirs(viz_dir, exist_ok=True)
    preds_dir = csv_dir  # per-sample prediction JSON and raw output live next to the CSV

    # Report the split: the train pool supplies colored images + GT JSON; test
    # samples are given as outlines only.
    print(f"[INFO] 选定训练集（彩色图+GT JSON）ID数量: {len(train_ids)}，示例: {train_ids[:5]}{' ...' if len(train_ids)>5 else ''}")
    print(f"[INFO] 选定测试集（仅轮廓图输入，GT JSON用于评估）ID数量: {len(test_ids)}，示例: {test_ids}")

    # Few-shot examples: (colored image path, GT JSON string).
    fewshot = []
    for _id in train_ids:
        img_path = os.path.join(args.input_dir, f"{_id}.png")
        gt_json = os.path.join(args.params_dir, f"{_id}.json")
        must_exist(img_path)
        must_exist(gt_json)
        fewshot.append((img_path, load_json_str(gt_json)))

    print(f"[INFO] 构造few-shot示例，数量: {len(fewshot)}")

    print("[INFO] 正在加载 Qwen2.5-VL 模型（仅加载一次）...")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", device_map="cpu").eval()
    processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

    rows = []
    for tid in test_ids:
        outline_png = os.path.join(args.outlines_dir, f"{tid}.png")
        must_exist(outline_png)
        # Inference: few-shot examples + the test outline -> (pieces or None, raw text).
        pred, raw = run_one(model, processor, fewshot, outline_png, max_new=900, teach=args.teach)

        # Always keep the raw model text; it matters most when parsing failed.
        pred_raw_path = osp.join(preds_dir, f"{tid}.raw.txt")
        with open(pred_raw_path, "w", encoding="utf-8") as f:
            f.write(raw)

        if pred is None:
            print(f"[WARN] {tid}: 未解析出有效JSON，IoU设为0.0")
            rows.append((tid, 0.0, 0.0, 1.0))
            continue

        # Score the rendered prediction against the outline mask on the shared canvas.
        gt = load_silhouette(outline_png, size=CANVAS_SIZE)
        pred_mask = render_mask(pred, CANVAS_SIZE, CANVAS_SIZE)
        iou, ov, un = geometry_metrics(pred_mask, gt)
        print(f"[EVAL] ID={tid} 几何评估结果 IoU={iou:.3f} / overflow={ov:.4f} / undercov={un:.4f}")

        # Persist the sanitized 7-piece prediction for later inspection.
        pred_json_path = osp.join(preds_dir, f"{tid}.pred.json")
        with open(pred_json_path, "w", encoding="utf-8") as f:
            json.dump({"pieces": pred}, f, ensure_ascii=False, indent=2)
        print(f"[SAVE] 预测JSON与原始输出已保存: {pred_json_path} | {pred_raw_path}")

        # Mask overlay: green = outline mask, red = prediction, yellow = overlap.
        overlay = np.zeros((CANVAS_SIZE, CANVAS_SIZE, 3), dtype=np.uint8)
        overlay[..., 1] = gt * 255                # G channel: outline mask
        overlay[..., 0] = pred_mask * 255         # R channel: prediction
        both = (gt & pred_mask).astype(np.uint8) * 255
        overlay[..., 0] = np.maximum(overlay[..., 0], both)
        overlay[..., 1] = np.maximum(overlay[..., 1], both)
        Image.fromarray(overlay).save(osp.join(viz_dir, f"{tid}_overlay.png"))

        # Polygon overlay drawn on the original outline image.
        save_overlay_on_outline(outline_png, pred, osp.join(viz_dir, f"{tid}_overlay_on_outline.png"))

        rows.append((tid, iou, ov, un))

    # Metrics table.
    with open(args.save_csv, "w", newline="", encoding="utf-8") as f:
        cw = csv.writer(f)
        cw.writerow(["id", "IoU", "overflow", "undercov"])
        cw.writerows(rows)
    print("[SAVE] 评估结果已保存至CSV文件:", args.save_csv)

    # Bar chart of per-sample IoU (x: sample ID, y: IoU).
    row_ids = [r[0] for r in rows]
    ious = [r[1] for r in rows]
    plt.figure(figsize=(max(8, 0.5*len(row_ids)), 4))
    plt.bar(range(len(row_ids)), ious)
    plt.xticks(range(len(row_ids)), row_ids, rotation=45, ha="right")
    plt.ylabel("IoU")
    plt.title(f"Tangram few-shot eval (train={len(train_ids)}, test={len(test_ids)})")
    plt.tight_layout()
    plt.savefig(args.save_png, dpi=160)
    plt.close()  # release the figure once it is on disk
    avg_iou = float(np.mean(ious)) if ious else 0.0
    print(f"[SUMMARY] 测试集平均 IoU = {avg_iou:.3f}")
    print("[SAVE] IoU柱状图已保存至PNG文件:", args.save_png)

# Run the evaluation only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()