#!/usr/bin/env python3
"""Build a HF-ready tabular dataset using the same generator as generate_offline_tabular.py.

Key properties
- Uses TabularSampler (MLP-SCM prior) with optional TabICL-style x normalization and y z-score.
- Emits hierarchical JSONL (or Parquet with --format parquet) shards grouped by combos:
  <output_root>/train/nc={Nc}_d={D}/part-XXXXX.jsonl (or part-XXXXX.parquet)
- Train-only by default (no validation/test). Controls examples per combo via --per_combo.
- Rotates shards with --shard_size to keep files manageable (similar to our HF tooling).

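Each record:
  {id, combo, combo_id, nc, d, nb, nt,
   x:  [[...]*D]*Nc, y:  [..]*Nc,   # context points
   xb: [[...]*D]*Nb, yb: [..]*Nb,   # buffer points (Nb = --num_buffer)
   xt: [[...]*D]*Nt, yt: [..]*Nt}   # target points (Nt = --num_target)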
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Dict, Tuple, List

import numpy as np
import torch
import json

# pyarrow is an optional dependency: it is only used when shards are written
# as Parquet (--format parquet). PA_AVAILABLE records whether the import
# succeeded so that the JSONL path keeps working without pyarrow installed.
try:
    import pyarrow as pa
    import pyarrow.parquet as pq
    PA_AVAILABLE = True
except Exception:
    PA_AVAILABLE = False

from src.data.tabular_sampler import TabularSampler


def parse_int_list(csv: str) -> list[int]:
    """Parse a comma-separated string of integers into a list.

    Surrounding whitespace is ignored and empty or whitespace-only tokens
    (e.g. from a trailing comma or ``"1, ,2"``) are skipped.

    Args:
        csv: Comma-separated integers, e.g. ``"8,16, 32"``.

    Returns:
        The parsed integers in input order; an empty list for an empty string.

    Raises:
        ValueError: If a non-empty token is not a valid integer literal.
    """
    # Strip before filtering so that whitespace-only tokens are dropped
    # instead of reaching int(), which would raise on them.
    tokens = (tok.strip() for tok in csv.split(","))
    return [int(tok) for tok in tokens if tok]


def dtype_from_str(name: str) -> torch.dtype:
    """Translate a floating-point dtype name into the matching ``torch.dtype``.

    Supported names are ``float16``, ``bfloat16``, ``float32`` and ``float64``.

    Raises:
        KeyError: If ``name`` is not one of the supported names.
    """
    supported = ("float16", "bfloat16", "float32", "float64")
    # Each supported name is also the attribute name of the dtype on `torch`.
    lookup = {label: getattr(torch, label) for label in supported}
    return lookup[name]


def open_writer(base_dir: Path):
    """Create ``base_dir`` if needed and open its first JSONL shard for writing.

    Returns:
        A 4-tuple ``(handle, path, shard_idx, count_in_file)`` where ``handle``
        is the open UTF-8 text file for ``part-00000.jsonl``, ``path`` is its
        location, ``shard_idx`` is 0 and ``count_in_file`` is 0. The caller is
        responsible for closing ``handle``.
    """
    first_shard = 0
    records_written = 0
    base_dir.mkdir(parents=True, exist_ok=True)
    shard_path = base_dir / "part-{:05d}.jsonl".format(first_shard)
    handle = open(shard_path, "w", encoding="utf-8")
    return handle, shard_path, first_shard, records_written


def main():
    """CLI entry point: sample synthetic tabular tasks and write them as shards.

    For every (nc, d) combination from ``--nc_list`` x ``--d_list`` a
    ``TabularSampler`` is built and ``--per_combo`` examples are generated in
    chunks. Each example becomes one record written below
    ``<output_root>/train/nc={nc}_d={d}/`` as ``part-XXXXX.jsonl`` or
    ``part-XXXXX.parquet`` (see ``--format``), rotating to a new shard after
    ``--shard_size`` records. A ``manifest.json`` with per-combo counts is
    written to ``<output_root>`` once generation has finished.

    Raises:
        SystemExit: On invalid arguments (non-positive ``--shard_size``,
            negative ``--per_combo``, or ``--format parquet`` without pyarrow).
        KeyError / ValueError: Propagated from dtype or integer-list parsing.
    """
    ap = argparse.ArgumentParser(description="Build HF JSONL using TabICL/MLP-SCM sampler (train-only)")
    ap.add_argument("--output_root", required=True, help="Output dataset root folder")
    ap.add_argument("--seed", type=int, default=123)
    ap.add_argument("--nc_list", default="8,16,32,64,128,256,512,1024")
    ap.add_argument("--d_list", default="1,2,3,4,5,6,7,8,9,10")
    ap.add_argument("--per_combo", type=int, default=32, help="Examples per (nc,d) combo")
    ap.add_argument("--shard_size", type=int, default=5000, help="Max records per shard file")
    ap.add_argument("--format", default="jsonl", choices=["jsonl", "parquet"], help="Output file format per shard")
    # Normalization and dtype/device (mirror generate_offline_tabular.py)
    ap.add_argument("--normalize_x", action="store_true", default=False)
    ap.add_argument("--x_norm_method", default="power", choices=["power", "quantile", "quantile_rtdl", "none"])
    ap.add_argument("--x_outlier_threshold", type=float, default=4.0)
    # y normalization is on by default. A store_true flag with default=True can
    # never be switched off, so an explicit negative flag shares the same dest.
    ap.add_argument("--normalize_y", action="store_true", default=True,
                    help="Normalize y (enabled by default)")
    ap.add_argument("--no_normalize_y", dest="normalize_y", action="store_false",
                    help="Disable y normalization")
    ap.add_argument("--dtype", default="float32", choices=["float16", "bfloat16", "float32", "float64"])
    ap.add_argument("--device", default="cpu")
    # Fixed buffer/target sizes (included in each record). Use 32/512 by default.
    ap.add_argument("--num_buffer", type=int, default=32, help="Fixed buffer size (Nb) per sample")
    ap.add_argument("--num_target", type=int, default=512, help="Fixed target size (Nt) per sample")
    args = ap.parse_args()

    # Validate up front so that bad invocations fail before any sampling work.
    if args.shard_size < 1:
        ap.error("--shard_size must be >= 1")
    if args.per_combo < 0:
        ap.error("--per_combo must be >= 0")
    if args.format == "parquet" and not PA_AVAILABLE:
        ap.error("pyarrow is required for --format parquet")

    # Seeds for reproducibility
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    out_root = Path(args.output_root)
    out_root.mkdir(parents=True, exist_ok=True)

    # De-duplicate while preserving order: a repeated (nc, d) combo would reset
    # its manifest entry (count and id) while still appending to the same shards.
    nc_list = list(dict.fromkeys(parse_int_list(args.nc_list)))
    d_list = list(dict.fromkeys(parse_int_list(args.d_list)))

    manifest = {
        "combos": {},  # "nc=128_d=5" -> {id, count, nc, d}
        "next_combo_id": 0,
        "splits": {"train": 0},
        "per_combo": args.per_combo,
    }

    # Writers/Buffers per (split, combo)
    # - For JSONL: (file, path, shard_idx, count_in_file)
    # - For Parquet: (rows_buffer: List[dict], base_dir, shard_idx, count_in_file)
    writers_jsonl: Dict[Tuple[str, str], Tuple[object, Path, int, int]] = {}
    buffers_parquet: Dict[Tuple[str, str], Tuple[List[dict], Path, int, int]] = {}

    def get_writer_jsonl(split: str, combo: str):
        """Return the JSONL writer state for (split, combo), rotating a full shard."""
        key = (split, combo)
        base_dir = out_root / split / combo
        if key not in writers_jsonl:
            writers_jsonl[key] = open_writer(base_dir)
        f, path, shard_idx, count_in_file = writers_jsonl[key]
        if count_in_file >= args.shard_size:
            # Rotate lazily (only when another record arrives) so that no
            # empty trailing shard is ever created.
            f.close()
            shard_idx += 1
            path = base_dir / f"part-{shard_idx:05d}.jsonl"
            f = path.open("w", encoding="utf-8")
            count_in_file = 0
            writers_jsonl[key] = (f, path, shard_idx, count_in_file)
        return writers_jsonl[key]

    def make_schema_parquet():
        """Build the Arrow schema used for every Parquet shard."""
        return pa.schema([
            ("id", pa.int64()),
            ("combo", pa.string()),
            ("combo_id", pa.int64()),
            ("nc", pa.int64()),
            ("d", pa.int64()),
            ("nb", pa.int64()),
            ("nt", pa.int64()),
            ("x", pa.list_(pa.list_(pa.float32()))),
            ("y", pa.list_(pa.float32())),
            ("xb", pa.list_(pa.list_(pa.float32()))),
            ("yb", pa.list_(pa.float32())),
            ("xt", pa.list_(pa.list_(pa.float32()))),
            ("yt", pa.list_(pa.float32())),
        ])

    def flush_parquet(split: str, combo: str, rows: List[dict], shard_idx: int):
        """Write the buffered rows of one (split, combo) as Parquet shard `shard_idx`."""
        if not PA_AVAILABLE:
            # Explicit raise instead of assert: asserts are stripped under -O.
            raise RuntimeError("pyarrow is required for --format parquet")
        base_dir = out_root / split / combo
        base_dir.mkdir(parents=True, exist_ok=True)
        path = base_dir / f"part-{shard_idx:05d}.parquet"
        schema = make_schema_parquet()
        # Build columns in schema order
        cols = {name: [r.get(name) for r in rows] for name in schema.names}
        arrays = [pa.array(cols[name], type=schema.field(name).type) for name in schema.names]
        table = pa.Table.from_arrays(arrays, schema=schema)
        pq.write_table(table, path)

    def get_buffer_parquet(split: str, combo: str):
        """Return the Parquet row buffer state for (split, combo), creating it on first use."""
        key = (split, combo)
        base_dir = out_root / split / combo
        if key not in buffers_parquet:
            buffers_parquet[key] = ([], base_dir, 0, 0)  # rows, base_dir, shard_idx, count_in_file
        return buffers_parquet[key]

    def to_numpy(tensor):
        """Convert a sampler output tensor to a CPU NumPy array."""
        tensor = tensor.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so Tensor.numpy() rejects it;
            # upcast to float32 first.
            tensor = tensor.to(torch.float32)
        return tensor.numpy()

    max_gen_batch = 1024  # internal generation chunking to limit peak memory
    num_buffer = int(args.num_buffer)
    num_target = int(args.num_target)

    # Loop combos and generate per_combo examples
    sample_id = 0
    try:
        for nc in nc_list:
            for d in d_list:
                combo = f"nc={nc}_d={d}"
                cid = manifest["next_combo_id"]
                manifest["combos"][combo] = {"id": cid, "count": 0, "nc": int(nc), "d": int(d)}
                manifest["next_combo_id"] += 1

                # Tabular sampler configured for fixed D and desired normalization
                sampler = TabularSampler(
                    dim_x=[int(d)],
                    dim_y=1,
                    is_causal=True,
                    num_causes=None,
                    num_layers=4,
                    hidden_dim=64,
                    noise_std=0.01,
                    sampling="mixed",
                    normalize_y=bool(args.normalize_y),
                    normalize_x=bool(args.normalize_x),
                    x_norm_method=args.x_norm_method,
                    x_outlier_threshold=float(args.x_outlier_threshold),
                    device=str(args.device),
                    dtype=dtype_from_str(args.dtype),
                )

                needed = int(args.per_combo)
                while needed > 0:
                    B = min(max_gen_batch, needed)
                    batch = sampler.generate_batch(
                        batch_size=B,
                        num_context=int(nc),
                        num_buffer=num_buffer,
                        num_target=num_target,
                        context_range=None,
                    )
                    # Move to CPU numpy; expected shapes per the sampler call:
                    xc, yc = to_numpy(batch.xc), to_numpy(batch.yc)      # [B,Nc,D], [B,Nc,1]
                    xb, yb = to_numpy(batch.xb), to_numpy(batch.yb)      # [B,Nb,D], [B,Nb,1]
                    xt, yt = to_numpy(batch.xt), to_numpy(batch.yt)      # [B,Nt,D], [B,Nt,1]

                    # Write records
                    for i in range(B):
                        rec = {
                            "id": int(sample_id),
                            "combo": combo,
                            "combo_id": int(cid),
                            "nc": int(nc),
                            "d": int(d),
                            "nb": num_buffer,
                            "nt": num_target,
                            "x": xc[i].tolist(),
                            # reshape to the exact expected length doubles as a
                            # sanity check on the sampler's output size.
                            "y": yc[i].reshape(int(nc)).tolist(),
                            "xb": xb[i].tolist(),
                            "yb": yb[i].reshape(num_buffer).tolist(),
                            "xt": xt[i].tolist(),
                            "yt": yt[i].reshape(num_target).tolist(),
                        }
                        if args.format == "jsonl":
                            f, path, shard_idx, count_in_file = get_writer_jsonl("train", combo)
                            f.write(json.dumps(rec) + "\n")
                            writers_jsonl[("train", combo)] = (f, path, shard_idx, count_in_file + 1)
                        else:  # parquet
                            rows, base_dir, shard_idx, count_in_file = get_buffer_parquet("train", combo)
                            rows.append(rec)
                            count_in_file += 1
                            if count_in_file >= args.shard_size:
                                flush_parquet("train", combo, rows, shard_idx)
                                rows = []
                                shard_idx += 1
                                count_in_file = 0
                            buffers_parquet[("train", combo)] = (rows, base_dir, shard_idx, count_in_file)
                        manifest["combos"][combo]["count"] += 1
                        manifest["splits"]["train"] += 1
                        sample_id += 1
                    needed -= B

        # Flush partially filled Parquet shards (only after a successful run).
        if args.format == "parquet":
            for (split, combo), (rows, _, shard_idx, _) in buffers_parquet.items():
                if rows:
                    flush_parquet(split, combo, rows, shard_idx)
    finally:
        # Always release JSONL handles, even if sampling or writing failed.
        for f, _, _, _ in writers_jsonl.values():
            f.close()

    with (out_root / "manifest.json").open("w", encoding="utf-8") as mf:
        json.dump(manifest, mf, indent=2)

    print(f"Wrote dataset to {out_root}; total samples: {sample_id}")
    print("Example layout:")
    # Show a combo and extension that match this run rather than a fixed
    # JSONL path that may not exist.
    example_combo = next(iter(manifest["combos"]), "nc=128_d=5")
    print("  ", out_root / "train" / example_combo / f"part-00000.{args.format}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
