#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
make_outlines_from_full.py
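Generate outline PNGs (black pieces on a white background) from full.json.
With --square-only, only samples whose figure has an approximately square
bounding box are kept.

Dependencies: pip install pillow numpy

Example: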
python make_outlines_from_full.py \
  --input /Users/zongyikun/Desktop/qwen_vl_demo/kilogram-main/dataset/full.json \
  --out_dir /Users/zongyikun/Desktop/qwen_vl_demo/kilogram-main/dataset/outlines \
  --flip-y --square-only --square-tol 0.02 --snap-angles
"""
import json, math, argparse, sys
from pathlib import Path
import numpy as np
from PIL import Image, ImageDraw

SQRT2 = math.sqrt(2.0)

# Piece templates (the original note says these must match geometry.py).
# Each value is a float array with one (x, y) row per polygon vertex, given in
# the piece's local frame with the first vertex at the origin; apply() mirrors,
# scales, rotates and translates them into place.
# Template areas: big triangle 16, medium triangle 8, small triangle 4,
# square 8, parallelogram 8.
TEMPLATE = {
    name: np.array(vertices, dtype=float)
    for name, vertices in (
        ("big_triangle", [[0, 0], [4 * SQRT2, 0], [0, 4 * SQRT2]]),
        ("medium_triangle", [[0, 0], [4, 0], [0, 4]]),
        ("small_triangle", [[0, 0], [2 * SQRT2, 0], [0, 2 * SQRT2]]),
        ("square", [[0, 0], [2 * SQRT2, 0], [2 * SQRT2, 2 * SQRT2], [0, 2 * SQRT2]]),
        ("parallelogram", [[0, 0], [4, 0], [6, 2], [2, 2]]),
    )
}

def rot(deg: float) -> np.ndarray:
    """Return the 2x2 rotation matrix for ``deg`` degrees.

    The matrix is ``[[cos, -sin], [sin, cos]]``, i.e. a counter-clockwise
    rotation in a y-up frame. Row-vector points are rotated with
    ``pts @ rot(deg).T``.
    """
    theta = math.radians(deg)
    cos_t = math.cos(theta)
    sin_t = math.sin(theta)
    return np.array([[cos_t, -sin_t], [sin_t, cos_t]], dtype=float)

def apply(poly, pos, angle, flip, scale):
    """Place a template polygon into the scene and return the new vertex array.

    Transform order: optional mirror (negate x) when ``flip`` is truthy,
    uniform scaling by ``scale``, rotation by ``angle`` degrees about the
    local origin, then translation by ``pos``. ``poly`` is copied, never
    modified in place.
    """
    local = poly.copy()
    if flip:
        local[:, 0] = -local[:, 0]
    local *= float(scale)
    offset = np.array(pos, float)
    return local @ rot(angle).T + offset

def snap_angle(a, typ):
    """Snap angle ``a`` (degrees) to the nearest multiple of a per-piece step.

    Squares use a 90° step; every other piece type uses a 45° step. The
    result is a float. Python's ``round`` rounds exact halves to the even
    integer, so 22.5° snaps to 0° while 67.5° snaps to 90°.

    NOTE(review): with the 90° step, a square stored at 45° (diamond
    orientation) snaps to 0°, i.e. it is rendered axis-aligned; confirm
    this is intended.
    """
    step = 90.0 if typ == "square" else 45.0
    return round(a / step) * step

def to_int_pts(P, size, extent=10.0):
    """Map canvas-unit vertices to integer pixel coordinates for drawing.

    The square [0, extent] x [0, extent] (y pointing up) is mapped onto a
    ``size`` x ``size`` image (y pointing down), so y is flipped here.
    ``extent`` defaults to 10, the canvas size this module has always assumed.

    Args:
        P: array of shape (N, 2) with one (x, y) vertex per row.
        size: output image side length in pixels.
        extent: side length of the source canvas in model units.

    Returns:
        List of ``(x, y)`` tuples of plain Python ints, suitable for
        ``ImageDraw.polygon``.

    Raises:
        ValueError: if ``extent`` is not positive.
    """
    span = float(extent)
    if span <= 0:
        raise ValueError(f"extent must be positive, got {extent!r}")
    last = size - 1
    px = (P[:, 0] / span) * last
    py = (1.0 - P[:, 1] / span) * last  # y axis points down in image space
    pixels = np.stack([px, py], axis=1).round().astype(int)
    # .tolist() yields native ints rather than numpy scalars.
    return [tuple(pt) for pt in pixels.tolist()]

def _yield_from_container(obj):
    """Yield (id, pieces) from various container shapes."""
    # Case: the entire object is directly the 7-piece list
    if isinstance(obj, list) and obj and isinstance(obj[0], dict) and len(obj) == 7 and all('type' in d for d in obj):
        yield ("sample#0", obj)
        return

    if isinstance(obj, dict):
        for k, v in obj.items():
            pcs = None
            if isinstance(v, dict):
                if "pieces" in v and isinstance(v["pieces"], list):
                    pcs = v["pieces"]
                elif "annotations" in v and isinstance(v["annotations"], list):
                    a0 = v["annotations"][0] if v["annotations"] else None
                    if isinstance(a0, list):
                        pcs = a0
                    elif isinstance(a0, dict) and "pieces" in a0:
                        pcs = a0["pieces"]
            elif isinstance(v, list) and len(v) == 7 and v and isinstance(v[0], dict) and 'type' in v[0]:
                pcs = v
            if pcs:
                yield (k, pcs)
    elif isinstance(obj, list):
        for i, v in enumerate(obj):
            pcs = None
            if isinstance(v, dict):
                if "pieces" in v and isinstance(v["pieces"], list):
                    pcs = v["pieces"]
                elif "annotations" in v and isinstance(v["annotations"], list):
                    a0 = v["annotations"][0] if v["annotations"] else None
                    if isinstance(a0, list):
                        pcs = a0
                    elif isinstance(a0, dict) and "pieces" in a0:
                        pcs = a0["pieces"]
            elif isinstance(v, list) and len(v) == 7 and v and isinstance(v[0], dict) and 'type' in v[0]:
                pcs = v
            if pcs:
                kid = v.get("id", f"sample#{i}") if isinstance(v, dict) else f"sample#{i}"
                yield (kid, pcs)


def extract_pieces(obj_or_path):
    """Yield ``(id, pieces)`` pairs from a JSON file, a directory, or a parsed object.

    * Directory: every ``*.json`` file (sorted by name) is parsed. Unreadable
      or malformed files are skipped with a warning on stderr. Ids equal to
      ``"sample#0"`` are replaced by the file's stem. A file holding a bare
      list of 7 dicts that the container parser did not accept is still
      yielded under the file stem.
    * File: parsed and passed to the container parser as-is.
    * Anything that is not a ``str``/``Path`` is treated as an already-parsed
      JSON object.

    Raises:
        FileNotFoundError: if a ``str``/``Path`` argument is neither an
            existing file nor an existing directory. Previously such input
            silently produced no samples.
    """
    if isinstance(obj_or_path, (str, Path)):
        path = Path(obj_or_path)
        if path.is_dir():
            for fp in sorted(path.glob("*.json")):
                try:
                    data = json.loads(fp.read_text(encoding="utf-8"))
                except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
                    print(f"[WARN] skip {fp.name}: {e}", file=sys.stderr)
                    continue
                yielded = False
                for sid, pcs in _yield_from_container(data):
                    yield (fp.stem if sid == "sample#0" else sid, pcs)
                    yielded = True
                if not yielded and isinstance(data, list) and len(data) == 7 and all(isinstance(d, dict) for d in data):
                    yield (fp.stem, data)
            return
        if path.is_file():
            data = json.loads(path.read_text(encoding="utf-8"))
            yield from _yield_from_container(data)
            return
        raise FileNotFoundError(f"input path is not an existing file or directory: {path}")

    # Already a parsed JSON object.
    yield from _yield_from_container(obj_or_path)

def reconstruct_polys(pieces):
    """Rebuild the placed polygon of every recognised piece.

    For each piece dict the template is looked up by ``"type"``; pieces with
    an unknown or missing type are skipped. The placement comes from
    ``"pos"``, falling back to ``"position"``, then ``[0, 0]`` (any falsy
    value falls through). ``"angle"`` defaults to 0.0, ``"flip"`` to False
    and ``"scale"`` to 1.0.

    Returns:
        List of vertex arrays, in input order, for the recognised pieces.
    """
    placed = []
    for piece in pieces:
        template = TEMPLATE.get(piece.get("type"))
        if template is None:
            continue
        anchor = piece.get("pos") or piece.get("position") or [0, 0]
        placed.append(
            apply(
                template,
                anchor,
                float(piece.get("angle", 0.0)),
                bool(piece.get("flip", False)),
                float(piece.get("scale", 1.0)),
            )
        )
    return placed

def bbox(polys):
    """Return ``(min_corner, max_corner)`` over all vertices of ``polys``.

    ``polys`` is a non-empty sequence of (N, 2) vertex arrays; an empty
    sequence makes ``np.vstack`` raise ``ValueError``.
    """
    corners = np.vstack(polys)
    lower = corners.min(axis=0)
    upper = corners.max(axis=0)
    return lower, upper

def _to_canvas(polys, origin, factor, target, flip_y):
    """Map raw polygons into the [0, target] canvas with one similarity transform.

    Every vertex is translated by ``-origin`` and scaled by ``factor``, so
    pieces keep both their relative positions and their sizes. With
    ``flip_y`` the whole figure is mirrored vertically (y -> target - y).
    """
    placed = []
    for poly in polys:
        q = (np.asarray(poly, dtype=float) - origin) * factor
        if flip_y:
            q[:, 1] = target - q[:, 1]
        placed.append(q)
    return placed


def _render_outline(polys, size, target):
    """Fill polygons given in [0, target] canvas units black on a white size x size image."""
    img = Image.new("L", (size, size), 255)  # white background
    draw = ImageDraw.Draw(img)
    to_unit = 10.0 / target  # to_int_pts() maps a [0, 10] canvas onto the image
    for poly in polys:
        draw.polygon(to_int_pts(poly * to_unit, size), fill=0)
    return img


def _safe_stem(sid):
    """Turn a sample id into a file-name stem with path separators replaced by '_'."""
    return str(sid).replace("/", "_").replace("\\", "_")


def main():
    """Command-line entry point: write one outline PNG per sample found in --input.

    For each sample:
      1. Rebuild its polygons and compute the bounding box.
      2. Optionally skip samples whose bounding-box aspect ratio lies outside
         1 ± square_tol.
      3. Optionally snap piece angles.
      4. Map the whole figure into [0, target] so the longer side spans
         ``target``, optionally mirroring it vertically.
      5. Save a white-background PNG with the pieces filled black.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--input", required=True, help="full.json 路径")
    ap.add_argument("--out_dir", required=True, help="输出 PNG 目录")
    ap.add_argument("--size", type=int, default=1024, help="输出 PNG 边长像素（默认1024）")
    ap.add_argument("--target", type=float, default=10.0, help="统一画布尺寸（默认 10）")
    ap.add_argument("--flip-y", action="store_true", help="把坐标上下翻转到屏幕坐标")
    ap.add_argument("--square-only", action="store_true", help="只保留近似正方形的样本")
    ap.add_argument("--square-tol", type=float, default=0.02, help="方形宽高比容差（±，默认0.02）")
    ap.add_argument("--snap-angles", action="store_true", help="角度离散化（square 90°，其他45°）")
    ap.add_argument("--verbose", action="store_true", help="显示详细信息")
    args = ap.parse_args()
    if args.size < 1:
        ap.error("--size must be a positive integer")
    if args.target <= 0:
        ap.error("--target must be positive")

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    kept = 0
    total = 0
    square_pass = 0
    for sid, pcs in extract_pieces(args.input):
        total += 1
        # 1) Rebuild in the original frame to get the bounding box and squareness.
        polys_raw = reconstruct_polys(pcs)
        if not polys_raw:
            continue
        mn, mx = bbox(polys_raw)
        span = np.maximum(mx - mn, 1e-9)  # avoid division by zero for degenerate figures
        r = float(span[0] / span[1])
        is_square = 1 - args.square_tol <= r <= 1 + args.square_tol
        if args.square_only and not is_square:
            if args.verbose:  # honours argparse abbreviations such as --verb
                print(f"[SKIP not-square] {sid} r={r:.4f}")
            continue
        if is_square:
            square_pass += 1

        # 2) Optional angle snapping, then rebuild with the snapped angles.
        if args.snap_angles:
            pcs = [
                {**p, "angle": float(snap_angle(float(p.get("angle", 0.0)), p.get("type", "")))}
                for p in pcs
            ]
            polys_raw = reconstruct_polys(pcs)

        # 3) Transform the vertices themselves, not just the piece positions.
        #    Moving only the anchors, as before, left piece sizes unscaled and
        #    unmirrored, so pieces drifted apart or went off-canvas whenever
        #    s_canvas != 1 or --flip-y was set.
        s_canvas = args.target / float(max(span))
        polys = _to_canvas(polys_raw, mn, s_canvas, args.target, args.flip_y)

        img = _render_outline(polys, args.size, args.target)
        img.save(out_dir / f"{_safe_stem(sid)}.png")
        kept += 1

    print(f"[OK] Wrote {kept} outline(s) to: {out_dir}  | scanned={total}, square_pass={square_pass}")

if __name__ == "__main__":
    main()