import os
import json
import argparse
import itertools
import random
from typing import List, Dict, Any

# ================================
# Defaults (absolute paths to avoid CWD issues)
# ================================
# NOTE(review): these are absolute paths under one user's home directory;
# on any other machine pass --in_dir / --out_single / --out_pair instead.

# Default for --in_dir: directory holding the page-level params JSON files.
DEF_IN_DIR = "/Users/zongyikun/Desktop/qwen_vl_demo/kilogram-main/dataset/params_from_svg"
# Default for --out_single: where the one-piece-per-file JSONs are written.
DEF_OUT_SINGLE = "/Users/zongyikun/Desktop/qwen_vl_demo/kilogram-main/dataset/onepiece_from_svg"
# Default for --out_pair: where the two-pieces-per-file JSONs are written.
DEF_OUT_PAIR = "/Users/zongyikun/Desktop/qwen_vl_demo/kilogram-main/dataset/two_pieces_from_svg"

# Page base names (input file name without the ".json" extension) that both
# dataset builders skip entirely.
BLACKLIST = {
    "page-G",
    "page1-6",
    "page3-188",
    "page5-150",
    "page5-210",
    "page7-119",
    "page7-162",
}

# ================================
# Helpers
# ================================

def extract_pieces(obj: Any) -> List[Dict[str, Any]]:
    """Locate the list of pieces inside a decoded page JSON.

    Accepted layouts, checked in this order:
      * the document itself is a list -> returned as-is (same object);
      * a dict whose "pieces" value is a list;
      * a dict whose "params" value is a list.

    Anything else yields an empty list.
    """
    if isinstance(obj, list):
        return obj
    if not isinstance(obj, dict):
        return []

    # "pieces" takes precedence over "params" when both hold lists.
    for key in ("pieces", "params"):
        candidate = obj.get(key)
        if isinstance(candidate, list):
            return candidate
    return []


def ensure_dir(path: str) -> None:
    """Create directory ``path`` together with any missing parents.

    An already existing directory is not an error.
    """
    os.makedirs(name=path, exist_ok=True)


def safe_type_name(piece: Dict[str, Any], fallback: str) -> str:
    """Return a filename-friendly label for ``piece``.

    The label is the piece's "type" value, else its "name" value, else
    ``fallback`` (also used when ``piece`` is not a dict or both values are
    falsy). In the result every "/" becomes "-" and every space becomes "_".
    """
    label = fallback
    if isinstance(piece, dict):
        label = piece.get("type") or piece.get("name") or fallback

    # One pass over the text instead of chained replace() calls.
    return str(label).translate(str.maketrans({"/": "-", " ": "_"}))


# ================================
# Single-piece dataset (keeps original behavior)
# ================================

def build_single_piece_dataset(in_dir: str, out_dir: str) -> None:
    """Split each page-level JSON in ``in_dir`` into one JSON file per piece.

    Entries of ``in_dir`` are visited in sorted name order. Only names ending
    in ".json" are considered, and pages whose base name is in ``BLACKLIST``
    are skipped without being opened. Every piece returned by
    ``extract_pieces`` is written to
    ``{out_dir}/{base}_piece{i}_{type}.json`` where ``i`` is the 1-based
    position of the piece in the page and ``type`` comes from
    ``safe_type_name``.

    Read and write failures are reported with a ``[WARN]`` line and the run
    continues with the next file / piece (best effort).

    Args:
        in_dir: Directory containing the page-level JSON files.
        out_dir: Directory that receives the per-piece JSON files; created
            if it does not exist.
    """
    ensure_dir(out_dir)
    total_files = 0
    total_pieces = 0

    for fname in sorted(os.listdir(in_dir)):
        if not fname.endswith(".json"):
            continue

        base = os.path.splitext(fname)[0]
        # Decide on the blacklist before touching the file, so a blacklisted
        # page is never parsed and never produces a misleading read warning.
        if base in BLACKLIST:
            print(f"[SKIP] 黑名单页面，跳过: {fname}")
            continue

        path = os.path.join(in_dir, fname)
        try:
            # Explicit UTF-8: the outputs are written as UTF-8 with
            # ensure_ascii=False, so inputs must not depend on the locale.
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:  # best effort: report and move on
            print(f"[WARN] 读取失败: {path} -> {e}")
            continue

        pieces = extract_pieces(data)
        if not pieces:
            print(f"[SKIP] 未找到可拆分的 pieces: {fname}")
            continue

        written = 0
        for i, piece in enumerate(pieces, start=1):
            ptype = safe_type_name(piece, f"piece{i}")
            out_path = os.path.join(out_dir, f"{base}_piece{i}_{ptype}.json")
            try:
                with open(out_path, "w", encoding="utf-8") as fo:
                    json.dump(piece, fo, indent=2, ensure_ascii=False)
            except Exception as e:  # best effort: report and move on
                print(f"[WARN] 写入失败: {out_path} -> {e}")
                continue
            written += 1

        total_pieces += written
        total_files += 1
        # Report what was actually written, not how many pieces were attempted.
        print(f"[OK] {fname}: 拆分 {written} 个 -> {out_dir}")

    print(f"\n✅ Done (single). 处理文件数: {total_files}, 生成单块 JSON 数: {total_pieces}")


# ================================
# Two-piece dataset (NEW)
# ================================

def _select_pairs(
    n: int,
    strategy: str,
    pairs_per_page: int,
    rng: random.Random,
) -> List[tuple]:
    """Return the 1-based (i, j) index pairs to emit for a page of ``n`` pieces."""
    if strategy == "consecutive":
        # Neighbours only: (1, 2), (2, 3), ..., (n-1, n). No need to
        # enumerate every combination for this strategy.
        return [(i, i + 1) for i in range(1, n)]

    # All unique index pairs, in combinations() order (1-based for metadata).
    all_pairs = [(i + 1, j + 1) for i, j in itertools.combinations(range(n), 2)]
    if strategy == "all" or pairs_per_page >= len(all_pairs):
        return all_pairs
    # Any other strategy value is treated as "random".
    return rng.sample(all_pairs, k=pairs_per_page)


def build_two_piece_dataset(
    in_dir: str,
    out_dir: str,
    pairs_per_page: int = 10,
    strategy: str = "random",
    seed: int = 42,
) -> None:
    """
    Build a dataset where each JSON contains exactly two pieces from the same page.

    Entries of ``in_dir`` are visited in sorted name order. Only names ending
    in ".json" are considered, pages whose base name is in ``BLACKLIST`` are
    skipped without being opened, and pages with fewer than two pieces are
    skipped. Read and write failures are reported with a ``[WARN]`` line and
    the run continues (best effort).

    Output JSON schema:
    {
      "page": "page4-180",           # base name without .json
      "source_file": "page4-180.json",
      "piece_indices": [i, j],         # 1-based indices in the original list
      "pieces": [piece_i, piece_j]     # the two piece dicts
    }

    Filenames: {base}_pair{k}_{name_i}_{name_j}.json
    where name_i / name_j are derived from piece type or fallback to piece{i},
    and k is the 1-based position of the pair among the pairs chosen for
    that page.

    Args:
        in_dir: Directory containing the page-level JSON files.
        out_dir: Directory that receives the pair JSON files; created if it
            does not exist.
        pairs_per_page: Number of pairs sampled per page for the random
            strategy; if it is at least the number of possible pairs, all
            pairs are emitted. Ignored by "all" and "consecutive".
        strategy: "all" (every unique pair), "consecutive" (neighbouring
            pieces only) or anything else for random sampling.
        seed: Seed of the private random generator used for sampling, shared
            across all pages of one run.

    Raises:
        ValueError: If random sampling is selected and ``pairs_per_page`` is
            negative.
    """
    # Fail fast with a clear message instead of letting random.sample raise
    # part-way through the run, after output may already have been written.
    if strategy not in ("all", "consecutive") and pairs_per_page < 0:
        raise ValueError(f"pairs_per_page must be >= 0, got {pairs_per_page}")

    ensure_dir(out_dir)
    rng = random.Random(seed)

    total_files = 0
    total_pairs = 0

    for fname in sorted(os.listdir(in_dir)):
        if not fname.endswith(".json"):
            continue

        base = os.path.splitext(fname)[0]
        # Decide on the blacklist before touching the file, so a blacklisted
        # page is never parsed and never produces a misleading read warning.
        if base in BLACKLIST:
            print(f"[SKIP] 黑名单页面，跳过: {fname}")
            continue

        path = os.path.join(in_dir, fname)
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except Exception as e:  # best effort: report and move on
            print(f"[WARN] 读取失败: {path} -> {e}")
            continue

        pieces = extract_pieces(data)
        n = len(pieces)
        if n < 2:
            print(f"[SKIP] 少于2块: {fname}")
            continue

        chosen = _select_pairs(n, strategy, pairs_per_page, rng)

        written = 0
        # k numbers the chosen pairs, so it advances even when a write fails.
        for k, (i1, i2) in enumerate(chosen, start=1):
            p1 = pieces[i1 - 1]
            p2 = pieces[i2 - 1]
            name1 = safe_type_name(p1, f"piece{i1}")
            name2 = safe_type_name(p2, f"piece{i2}")

            out_obj = {
                "page": base,
                "source_file": fname,
                "piece_indices": [i1, i2],
                "pieces": [p1, p2],
            }

            out_name = f"{base}_pair{k}_{name1}_{name2}.json"
            out_path = os.path.join(out_dir, out_name)
            try:
                with open(out_path, "w", encoding="utf-8") as fo:
                    json.dump(out_obj, fo, indent=2, ensure_ascii=False)
            except Exception as e:  # best effort: report and move on
                print(f"[WARN] 写入失败: {out_path} -> {e}")
                continue

            written += 1

        total_pairs += written
        total_files += 1
        # Report what was actually written, not how many pairs were attempted.
        print(f"[OK] {fname}: 生成 {written} 个两块组合 -> {out_dir}")

    print(f"\n✅ Done (two-pieces). 处理文件数: {total_files}, 生成两块组合数: {total_pairs}")


# ================================
# CLI
# ================================

def parse_args() -> argparse.Namespace:
    """Declare the command-line options and parse the process arguments.

    Options: --mode (single | pair), --in_dir, --out_single, --out_pair,
    and the pair-only --pairs_per_page, --strategy and --seed. Directory
    defaults come from the module-level DEF_* constants.
    """
    parser = argparse.ArgumentParser(
        description="Build single-piece or two-piece Tangram datasets from page-level params JSONs.",
    )

    parser.add_argument(
        "--mode",
        choices=["single", "pair"],
        default="single",
        help="single: 输出单块数据集; pair: 输出两块组合数据集",
    )
    parser.add_argument(
        "--in_dir",
        default=DEF_IN_DIR,
        help="输入目录（page*.json，含所有 pieces）",
    )
    parser.add_argument(
        "--out_single",
        default=DEF_OUT_SINGLE,
        help="单块输出目录 (mode=single)",
    )
    parser.add_argument(
        "--out_pair",
        default=DEF_OUT_PAIR,
        help="两块组合输出目录 (mode=pair)",
    )

    # Options below only matter when --mode pair is selected.
    parser.add_argument(
        "--pairs_per_page",
        type=int,
        default=10,
        help="每个 page 随机采样的两块组合数量（strategy=random 时生效）",
    )
    parser.add_argument(
        "--strategy",
        choices=["random", "consecutive", "all"],
        default="random",
        help="pair 采样策略：random/随机，consecutive/相邻，all/所有组合",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="随机种子",
    )

    return parser.parse_args()


def main():
    """Parse the CLI options and run the dataset builder chosen by --mode."""
    opts = parse_args()

    # Anything other than "single" runs the two-piece builder.
    if opts.mode != "single":
        build_two_piece_dataset(
            in_dir=opts.in_dir,
            out_dir=opts.out_pair,
            pairs_per_page=opts.pairs_per_page,
            strategy=opts.strategy,
            seed=opts.seed,
        )
        return

    build_single_piece_dataset(opts.in_dir, opts.out_single)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()