#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
polys_to_params.py
读取 svg_to_polys.py 产出的 7 多边形 JSON（统一到 [0,10]×[0,10]），
为每块拟合 geometry 模板的 {type, pos, angle, flip, scale}，输出到 params 目录。

依赖：numpy
可选：shapely（用于 IoU 评估）
"""

import argparse, json, os, math
from pathlib import Path
import numpy as np

# ------------------------
# 1) Tangram templates (must stay consistent with geometry.py)
# ------------------------
SQRT2 = math.sqrt(2.0)

def _right_isosceles(leg):
    """Return a right isosceles triangle: right angle at the origin, legs along +x and +y."""
    return np.array([[0, 0], [leg, 0], [0, leg]], dtype=float)

# Vertex templates: each is an (N, 2) float array whose vertex 0 is the origin,
# listed counter-clockwise in a y-up frame. Areas are 16 / 8 / 4 / 8 / 8
# (up to float rounding), i.e. 64 in total for a full 7-piece set.
TEMPLATES = {
    "big_triangle": _right_isosceles(4 * SQRT2),
    "medium_triangle": _right_isosceles(4),
    "small_triangle": _right_isosceles(2 * SQRT2),
    "square": np.array(
        [[0, 0], [2 * SQRT2, 0], [2 * SQRT2, 2 * SQRT2], [0, 2 * SQRT2]],
        dtype=float,
    ),
    "parallelogram": np.array([[0, 0], [4, 0], [6, 2], [2, 2]], dtype=float),
}

# Piece types of one complete set (big and small triangles each occur twice);
# also the fill order used by classify_types' fallback.
TYPES_ORDER = [
    "big_triangle", "big_triangle",
    "medium_triangle",
    "small_triangle", "small_triangle",
    "square",
    "parallelogram",
]

# ------------------------
# 2) 实用函数
# ------------------------
def polygon_area(pts):
    """Return the unsigned area of a simple polygon (shoelace formula).

    pts: (N, 2) array of vertices in order, either orientation. The polygon is
    closed implicitly: the last vertex connects back to the first.
    """
    xs, ys = pts[:, 0], pts[:, 1]
    twice_signed = np.dot(xs, np.roll(ys, -1)) - np.dot(ys, np.roll(xs, -1))
    return 0.5 * abs(twice_signed)

def kabsch_similarity(U, V):
    """Least-squares 2-D similarity fit (Kabsch/Umeyama): find s, R, t with s*R@u + t ~= v.

    Rows are assumed to be already in one-to-one correspondence (row i of U
    matches row i of V). R is constrained to a proper rotation (det = +1).
    When the best orthogonal alignment would be a reflection, the second
    singular direction is negated (D = diag(1, -1)). The scale then uses
    Umeyama's tr(D S) numerator, so it stays the least-squares optimum for the
    rotation actually returned.

    Args:
        U: (N, 2) source points.
        V: (N, 2) target points.

    Returns:
        (s, angle_deg, t, mse):
            s:         scale factor;
            angle_deg: rotation angle of R in degrees, atan2(R[1,0], R[0,0]);
            t:         (2,) translation, t = mean(V) - s * R @ mean(U);
            mse:       mean of squared coordinate differences between s*R@U + t and V.
    """
    U = np.asarray(U, float)
    V = np.asarray(V, float)
    mu_u = U.mean(axis=0)
    mu_v = V.mean(axis=0)
    Uc = U - mu_u
    Vc = V - mu_v
    H = Uc.T @ Vc
    U_svd, S, Vt = np.linalg.svd(H)
    # Reflection guard: D = diag(1, d) forces det(R) = +1.
    d = -1.0 if np.linalg.det(Vt.T @ U_svd.T) < 0 else 1.0
    D = np.array([1.0, d])
    R = Vt.T @ np.diag(D) @ U_svd.T
    # Optimal scale for the constrained rotation is tr(D S) / sum|Uc|^2.
    # Using sum(S) here would overestimate it whenever the reflection guard fires.
    s = (S * D).sum() / (Uc ** 2).sum()
    t = mu_v - s * (R @ mu_u)
    pred = (s * (R @ U.T)).T + t
    mse = float(((pred - V) ** 2).mean())
    ang = math.degrees(math.atan2(R[1, 0], R[0, 0]))
    return s, ang, t, mse

def all_cyclic_orders(pts):
    """Enumerate candidate vertex correspondences for a polygon with len(pts) vertices.

    Returns 2*n integer index arrays: the n cyclic shifts of the forward order
    0..n-1, followed by the n cyclic shifts of the reversed order. Only len(pts)
    is used. Including the reversed orders lets a template match a polygon whose
    vertices are listed in the opposite orientation.
    """
    n = len(pts)
    forward = np.arange(n)
    backward = forward[::-1]
    return [np.roll(seq, -shift) for seq in (forward, backward) for shift in range(n)]

def try_fit(template, target):
    """Fit a template to a target polygon with a similarity transform, optionally mirrored.

    Tries flip=False/True (flip mirrors the template via x -> -x) and every
    vertex correspondence produced by all_cyclic_orders(target). Keeps the
    candidate with the smallest mse; on ties, the first candidate found wins.

    Args:
        template: (N, 2) template vertices.
        target:   (N, 2) target polygon vertices.

    Returns:
        dict(scale, angle, flip, pos, mse): angle is in degrees, and pos is the
        target centroid as [x, y]. Returns None when template and target have
        different vertex counts (no vertex-to-vertex correspondence exists);
        callers must handle None.
    """
    template = np.asarray(template, float)
    target = np.asarray(target, float)
    # The count check does not depend on flip or vertex order, so test it once.
    if target.shape[0] != template.shape[0]:
        return None
    orders = all_cyclic_orders(target)
    best = None
    for flip in (False, True):
        U = template.copy()
        if flip:
            U[:, 0] *= -1  # mirror about the y-axis
        for ord_idx in orders:
            V = target[ord_idx]
            s, ang, t, mse = kabsch_similarity(U, V)
            if best is None or mse < best["mse"]:
                best = dict(scale=float(s), angle=float(ang), flip=bool(flip),
                            pos=V.mean(axis=0).tolist(), mse=mse)
    return best

def snap_angle(a_deg, typ):
    """Snap an angle in degrees to the nearest multiple of the piece's step.

    "square" snaps to multiples of 90; every other type snaps to 45. Python's
    round() is used, so exact half-steps go to the even multiple. Returns a float.
    """
    step = {"square": 90}.get(typ, 45)
    return float(step * round(a_deg / step))

def classify_types(polys):
    """Map each of the 7 polygon indices to a tangram piece type.

    Area-based heuristic:
      * the 2 smallest polygons by area -> "small_triangle";
      * the 2 largest polygons by area  -> "big_triangle";
      * among the 3 remaining: the first 4-vertex polygon passing the square
        test -> "square", the first 3-vertex one -> "medium_triangle", and the
        first other 4-vertex one -> "parallelogram".
    Any index still unassigned is then filled, in ascending index order, with
    the TYPES_ORDER entries not yet present in the mapping. This fallback does
    not check vertex counts.

    Args:
        polys: sequence of 7 vertex arrays of shape (k, 2).

    Returns:
        dict {polygon index: type name}.
    """
    ranked = np.argsort([polygon_area(poly) for poly in polys])  # ascending area
    small_idx = ranked[:2].tolist()
    big_idx = ranked[-2:].tolist()
    by_size = small_idx + big_idx
    rest = [i for i in range(7) if i not in by_size]

    def is_square(poly, tol=0.1):
        """Edge lengths within relative spread tol, and every corner has |cos| < 0.2."""
        edges = np.roll(poly, -1, axis=0) - poly
        lengths = np.linalg.norm(edges, axis=1)
        spread = (lengths.max() - lengths.min()) / max(lengths.max(), 1e-6)
        if spread > tol:
            return False
        following = np.roll(edges, -1, axis=0)
        cosines = (edges * following).sum(axis=1) / (
            np.linalg.norm(edges, axis=1) * np.linalg.norm(following, axis=1) + 1e-9
        )
        return np.all(np.abs(cosines) < 0.2)

    sq_idx = next((i for i in rest if polys[i].shape[0] == 4 and is_square(polys[i])), None)
    leftovers = [i for i in rest if i != sq_idx]
    mid_idx = next((i for i in leftovers if polys[i].shape[0] == 3), None)
    para_idx = next((i for i in leftovers if polys[i].shape[0] == 4), None)

    mapping = {i: "big_triangle" for i in big_idx}
    for i in small_idx:
        mapping[i] = "small_triangle"
    for idx, name in ((mid_idx, "medium_triangle"), (sq_idx, "square"), (para_idx, "parallelogram")):
        if idx is not None:
            mapping[idx] = name

    # Fallback: give leftover indices the types that are still absent.
    unassigned = [i for i in range(7) if i not in mapping]
    for name in TYPES_ORDER:
        if unassigned and name not in mapping.values():
            mapping[unassigned.pop(0)] = name
    return mapping  # index -> type

def add_instances(params_list):
    """Tag the two big and the two small triangles with instance labels, in place.

    Within each type, pieces are ordered by their pos as (x, y). Python's sort
    is stable, so exact ties keep list order. The first gets "L1"/"S1" and the
    second "L2"/"S2". A type is left untouched unless exactly two pieces of it
    are present. Returns None.
    """
    label_table = (
        ("big_triangle", ("L1", "L2")),
        ("small_triangle", ("S1", "S2")),
    )
    for piece_type, labels in label_table:
        members = [i for i, p in enumerate(params_list) if p["type"] == piece_type]
        members.sort(key=lambda i: (params_list[i]["pos"][0], params_list[i]["pos"][1]))
        if len(members) != 2:
            continue
        for i, label in zip(members, labels):
            params_list[i]["instance"] = label

# ------------------------
# 3) 主流程
# ------------------------
def _fitted_vertices(p):
    """Rebuild the fitted polygon of one params entry as an (N, 2) vertex array.

    kabsch_similarity aligns the template centroid with the target centroid,
    and the stored "pos" is the target polygon's centroid. The matching
    reconstruction therefore proceeds in four steps:
      1. mirror the template (x -> -x) if "flip" is set;
      2. centre it on its own centroid;
      3. rotate it by "angle" degrees and scale it by "scale";
      4. translate it to "pos".
    Translating the un-centred template (vertex 0 at the origin) to "pos"
    would shift the piece by scale * R @ centroid(template).
    """
    U = TEMPLATES[p["type"]].copy()
    if p.get("flip", False):
        U[:, 0] *= -1
    U -= U.mean(axis=0)
    a = math.radians(p["angle"])
    R = np.array([[math.cos(a), -math.sin(a)], [math.sin(a), math.cos(a)]])
    return (U @ R.T) * p["scale"] + np.asarray(p["pos"], float)

def _fit_pieces(polys, snap):
    """Classify every polygon, fit its template, and return the params list.

    Returns None if any polygon cannot be fitted, i.e. try_fit returned None
    because its vertex count differs from the template chosen for it.
    """
    idx2type = classify_types(polys)
    params = []
    for i, P in enumerate(polys):
        typ = idx2type[i]
        fit = try_fit(TEMPLATES[typ], P)
        if fit is None:
            return None
        if snap:
            fit["angle"] = snap_angle(fit["angle"], typ)
        params.append({
            "type": typ,
            "pos": [float(fit["pos"][0]), float(fit["pos"][1])],
            "angle": float(fit["angle"]),
            "flip": bool(fit["flip"]),
            "scale": float(fit["scale"]),
            "mse_fit": float(fit["mse"]),
        })
    add_instances(params)
    return params

def main():
    """CLI entry point: fit tangram params for every 7-polygon JSON in --in_dir.

    Processes each "*.json" file in --in_dir, in sorted order:
      * skips files whose polygon list does not have exactly 7 entries;
      * skips files in which some piece cannot be fitted;
      * optionally (--min_iou > 0 with shapely importable) drops files whose
        IoU between the union of fitted pieces and the union of source
        polygons is below the threshold;
      * writes the params list to --out_dir/<stem>.json.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in_dir", required=True, help="svg_to_polys.py 的输出目录")
    ap.add_argument("--out_dir", required=True, help="拟合后的 params 输出目录")
    ap.add_argument("--snap", action="store_true", help="角度吸附到 45°/90°")
    ap.add_argument("--min_iou", type=float, default=0.0, help="若安装 shapely，可计算联合 IoU 并筛选")
    args = ap.parse_args()

    IN = Path(args.in_dir)
    OUT = Path(args.out_dir)
    OUT.mkdir(parents=True, exist_ok=True)

    # Optional IoU support. OSError is caught too, because some shapely builds
    # reportedly raise it when the underlying GEOS library cannot be loaded.
    try:
        from shapely.geometry import Polygon as SP
        from shapely.ops import unary_union
        HAS_SHAPELY = True
    except (ImportError, OSError):
        HAS_SHAPELY = False
    if args.min_iou > 0 and not HAS_SHAPELY:
        print("[WARN] shapely not available; --min_iou filtering is disabled")

    kept = 0
    for jp in sorted(IN.glob("*.json")):
        items = json.loads(jp.read_text(encoding="utf-8"))
        polys = [np.array(item["polygon"], float) for item in items]
        if len(polys) != 7:
            continue

        params = _fit_pieces(polys, args.snap)
        if params is None:
            print(f"[SKIP] {jp.name}: a polygon's vertex count does not match its assigned template")
            continue

        # Optional filter: IoU between the union of fitted pieces and the union of source pieces.
        if HAS_SHAPELY and args.min_iou > 0:
            union_fit = unary_union([SP(_fitted_vertices(p)) for p in params])
            union_src = unary_union([SP(P) for P in polys])
            inter = union_fit.intersection(union_src).area
            union = union_fit.union(union_src).area
            iou = inter / (union + 1e-9)
            if iou < args.min_iou:
                continue

        outp = OUT / f"{jp.stem}.json"
        outp.write_text(json.dumps(params, ensure_ascii=False, indent=2), encoding="utf-8")
        kept += 1

    print(f"[DONE] wrote {kept} param file(s) to: {OUT}")

if __name__ == "__main__":
    main()