#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
svg_to_polys.py
从 /tangrams-svg 解析每个 SVG 的 7 个拼块多边形，统一到 [0,10]×[0,10] 坐标。
支持 <polygon>；若遇到 <path> 则用 svgpathtools 采样离散。

依赖：pip install lxml numpy svgpathtools
"""
import argparse
import gzip
import json
import re
import sys
from io import BytesIO
from pathlib import Path

import numpy as np
from lxml import etree

# Optional: sample <path> elements into discrete points
def path_to_points(d: str, n_samples: int = 128):
    """Sample an SVG path-data string into an (n_samples, 2) float array.

    ``d`` is parsed with ``svgpathtools.parse_path`` and evaluated at
    ``n_samples`` evenly spaced parameter values in [0, 1); the endpoint
    t=1 is excluded. Column 0 holds the real part (x) and column 1 the
    imaginary part (y) of each sampled point.

    A path whose length is zero yields an empty array of shape (0, 2).
    """
    # Imported lazily so svgpathtools is only needed once this is called.
    from svgpathtools import parse_path

    parsed = parse_path(d)
    if parsed.length() == 0:
        return np.zeros((0, 2))

    params = np.linspace(0, 1, n_samples, endpoint=False)
    samples = np.array([parsed.point(t) for t in params], dtype=complex)
    return np.column_stack([samples.real, samples.imag]).astype(float)

def parse_points_attr(s: str):
    """Parse an SVG ``points`` attribute into an (N, 2) float array.

    Numbers are tokenized following the SVG number grammar, so coordinates
    may be separated by whitespace, by commas, or by nothing at all when the
    next number starts with a sign (``"10-5"`` means 10 and -5). Splitting
    only on whitespace/commas would raise ``ValueError`` on such input.

    A trailing unpaired coordinate is dropped. Empty or number-free input
    returns an array of shape (0, 2), so callers can always rely on
    ``shape[0]`` and on column indexing.
    """
    # Optional sign, then digits with optional fraction (or a bare
    # fraction such as ".5"), then an optional exponent.
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    values = [float(tok) for tok in re.findall(number, s or "")]

    # Keep only complete (x, y) pairs.
    n_pairs = len(values) // 2
    return np.array(values[: 2 * n_pairs], dtype=float).reshape(n_pairs, 2)

def read_svg_polys(svg_path: Path):
    """Read the shape outlines of one SVG file as a list of point arrays.

    Both plain ``.svg`` and gzip-compressed ``.svgz`` files are accepted.
    Every ``<polygon>`` is collected first (from its ``points`` attribute),
    then every ``<path>`` (sampled from its ``d`` attribute with 256 points).
    Shapes with fewer than 3 points are ignored, and a path whose sampling
    raises any exception is skipped.

    Returns:
        If shapes with ids "1" through "7" all exist, those seven in id
        order. Otherwise the first seven collected shapes, or all of them
        when fewer than seven were found.
    """
    xlink_id = "{http://www.w3.org/1999/xlink}id"
    location = str(svg_path)

    # .svgz is gzip-compressed XML; decompress before parsing.
    if location.lower().endswith(".svgz"):
        with gzip.open(location, "rb") as fh:
            root = etree.fromstring(fh.read())
    else:
        root = etree.parse(location).getroot()

    # Element tags are qualified with the document's default namespace, if any.
    default_ns = root.nsmap.get(None, "")
    prefix = f"{{{default_ns}}}" if default_ns else ""

    shapes = []
    shapes_by_id = {}

    def remember(element, pts):
        # Record a shape (in collection order and by id) if it has >= 3 points.
        if pts.shape[0] < 3:
            return
        shapes.append(pts)
        key = element.get("id") or element.get(xlink_id)
        if key is not None:
            shapes_by_id[key] = pts

    for element in root.findall(f".//{prefix}polygon"):
        remember(element, parse_points_attr(element.get("points", "")))

    for element in root.findall(f".//{prefix}path"):
        path_data = element.get("d", "")
        if not path_data:
            continue
        try:
            sampled = path_to_points(path_data, n_samples=256)
        except Exception:
            # Best effort: a path that cannot be sampled is skipped.
            continue
        remember(element, sampled)

    wanted_ids = [str(i) for i in range(1, 8)]
    if all(key in shapes_by_id for key in wanted_ids):
        return [shapes_by_id[key] for key in wanted_ids]
    if len(shapes) >= 7:
        return shapes[:7]
    return shapes

def bbox(polys):
    """Return the axis-aligned bounding box of a collection of polygons.

    All vertex arrays in ``polys`` are stacked into one array, and the
    per-column minimum and maximum are returned as ``(lower, upper)``.
    """
    stacked = np.vstack(polys)
    lower = np.min(stacked, axis=0)
    upper = np.max(stacked, axis=0)
    return lower, upper

def unify(polys, target=10.0, flip_y=True):
    """Translate and uniformly scale polygons into a ``target``-sized frame.

    The joint bounding box is moved so its minimum corner sits at the
    origin, then everything is scaled by one factor so the longer bounding
    box side equals ``target`` (aspect ratio is preserved). When ``flip_y``
    is true, each y coordinate is mirrored as ``target - y``.

    Returns:
        ``(polys, scale, mn, mx)``: the transformed polygons (new arrays),
        the scale factor, and the original bounding-box min and max corners.
    """
    lo, hi = bbox(polys)
    # Clamp the extent so a degenerate bounding box cannot divide by zero.
    extent = np.maximum(hi - lo, 1e-9)
    scale = target / float(max(extent))

    def transform(poly):
        moved = (poly - lo) * scale
        if flip_y:
            moved[:, 1] = target - moved[:, 1]
        return moved

    return [transform(poly) for poly in polys], scale, lo, hi

def is_square(polys, tol=0.02):
    """Check whether the joint bounding box of ``polys`` is nearly square.

    Returns:
        ``(ok, ratio)`` where ``ratio`` is bounding-box width divided by
        height, and ``ok`` is true when ``ratio`` lies within
        ``[1 - tol, 1 + tol]``.
    """
    lo, hi = bbox(polys)
    width = float(hi[0] - lo[0])
    height = float(hi[1] - lo[1])
    # Clamp the height so a flat bounding box cannot divide by zero.
    ratio = width / max(height, 1e-9)
    above_lower = ratio >= (1 - tol)
    below_upper = ratio <= (1 + tol)
    return above_lower and below_upper, ratio

def _find_svg_files(in_dir, recursive):
    """Return the sorted, de-duplicated .svg/.svgz files under ``in_dir``."""
    candidates = in_dir.rglob("*") if recursive else in_dir.glob("*")
    # Match the extension case-insensitively in a single pass. Globbing
    # separately for "*.svg" and "*.SVG" returns each file twice where glob
    # matching is case-insensitive, and misses mixed-case names like ".Svg".
    return sorted(
        {
            p
            for p in candidates
            if p.name.lower().endswith((".svg", ".svgz")) and p.is_file()
        }
    )


def main():
    """Command-line entry point: convert tangram SVGs to normalized JSON.

    Scans ``--in_dir`` (optionally recursively) for .svg/.svgz files, reads
    the shapes of each one, and skips files yielding fewer than 7 shapes.
    With ``--square-only``, files whose bounding box is not square within
    ``--square-tol`` are skipped too. Each kept file is normalized to a
    ``--target``-sized frame and written to ``--out_dir/<stem>.json`` as a
    list of ``{"id": ..., "polygon": ...}`` entries with ids starting at 1.
    A summary line is printed at the end.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--in_dir", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--target", type=float, default=10.0)
    ap.add_argument("--flip-y", action="store_true")
    ap.add_argument("--square-only", action="store_true")
    ap.add_argument("--square-tol", type=float, default=0.02)
    ap.add_argument("--recursive", action="store_true")
    ap.add_argument("--verbose", action="store_true")
    args = ap.parse_args()

    in_dir = Path(args.in_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    files = _find_svg_files(in_dir, args.recursive)

    if args.verbose:
        print(f"[INFO] Found {len(files)} file(s) under {in_dir}", file=sys.stderr)
        # Show a couple for sanity
        for p in files[:3]:
            print(f"       - {p}", file=sys.stderr)
        if not files:
            print(f"[WARN] No SVG files found under {in_dir}", file=sys.stderr)

    kept = scanned = sq = 0
    for svg in files:
        polys = read_svg_polys(svg)
        scanned += 1
        if len(polys) < 7:
            if args.verbose:
                print(f"[SKIP <7 shapes] {svg.name} got={len(polys)}", file=sys.stderr)
            continue

        ok, r = is_square(polys, args.square_tol)
        if args.square_only and not ok:
            if args.verbose:
                print(f"[SKIP not square] {svg.name} r={r:.4f}", file=sys.stderr)
            continue
        if ok:
            sq += 1

        polys_u, _scale, _mn, _mx = unify(polys, args.target, args.flip_y)
        data = [
            {"id": i + 1, "polygon": P.round(6).tolist()}
            for i, P in enumerate(polys_u)
        ]
        # NOTE: outputs are keyed by file stem only, so inputs that share a
        # stem (e.g. in different subdirectories with --recursive) overwrite
        # each other.
        (out_dir / f"{svg.stem}.json").write_text(
            json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8"
        )
        kept += 1
        if args.verbose:
            print(f"[KEEP] {svg.stem} r={r:.4f}")

    print(f"[DONE] kept={kept}, scanned={scanned}, square_like={sq}, out={out_dir}")

# Run the converter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()