#!/usr/bin/env python3
"""Iterate batches from a Hugging Face JSONL tabular dataset using existing utilities.

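- Loads a split via `datasets` using hf:// glob patterns (JSONL shards by default; Parquet shards are also supported)
- Uses UniformComboBatchSampler to ensure shape-homogeneous batches by (nc,d)
- Uses collate_stack_same_shape or collate_full_same_shape to stack dict samples into tensors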

"""

from __future__ import annotations

import argparse
from typing import Tuple

import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset

from src.data.bucketed_sampler import UniformComboBatchSampler
from src.data.collate_strict import collate_stack_same_shape, collate_full_same_shape


class HFDataset(Dataset):
    """Adapter exposing a wrapped, indexable dataset through the torch ``Dataset`` interface.

    The wrapped object is stored as ``self.ds``; it only needs to support
    ``len()`` and integer subscripting.
    """

    def __init__(self, ds):
        # Keep a reference only; nothing is copied or materialised here.
        self.ds = ds

    def __len__(self) -> int:  # type: ignore[override]
        """Return the number of rows in the wrapped dataset."""
        size = len(self.ds)
        return size

    def __getitem__(self, idx: int):  # type: ignore[override]
        """Return the row at ``idx``, coercing the index to a plain ``int`` first."""
        # Samplers may hand over NumPy integer scalars; normalise before indexing.
        position = int(idx)
        return self.ds[position]


def load_split(repo_id: str, split: str, *, force: bool = False, cache_dir: str | None = None, file_format: str = "auto"):
    """Load one split of a Hub dataset from its ``part-*`` shard files.

    Shards are located with the glob
    ``hf://datasets/{repo_id}/{split}/**/part-*.<ext>``.

    Args:
        repo_id: Hub dataset repo id (e.g. ``user/name``).
        split: Split name; also the sub-directory searched for shards.
        force: When true, pass ``download_mode="force_redownload"`` and
            ``verification_mode="no_checks"`` so the datasets cache and any
            stale dataset_info are bypassed.
        cache_dir: Optional datasets cache directory.
        file_format: ``"parquet"`` loads Parquet shards only; ``"auto"`` tries
            JSONL first and falls back to Parquet if that attempt raises; any
            other value (including ``"json"``) loads JSONL shards only.

    Returns:
        The object returned by ``load_dataset`` for the requested split.

    Raises:
        Exception: whatever the last attempted format raised. In ``"auto"``
            mode the Parquet error is chained to the earlier JSONL error so
            neither failure is hidden.
    """
    # (builder name, shard extension) pairs to try, in order.
    if file_format == "parquet":
        candidates = [("parquet", "parquet")]
    elif file_format == "auto":
        candidates = [("json", "jsonl"), ("parquet", "parquet")]
    else:
        candidates = [("json", "jsonl")]

    kwargs: dict[str, object] = {"streaming": False}
    if force:
        # Force re-download and avoid split-size verification using stale dataset_info
        kwargs["download_mode"] = "force_redownload"
        kwargs["verification_mode"] = "no_checks"
    if cache_dir:
        kwargs["cache_dir"] = cache_dir

    first_error = None
    last_attempt = len(candidates) - 1
    for attempt, (builder, ext) in enumerate(candidates):
        # Build the glob per format instead of string-replacing the extension,
        # so a repo id or split that happens to contain ".jsonl" is left intact.
        pattern = f"hf://datasets/{repo_id}/{split}/**/part-*.{ext}"
        try:
            return load_dataset(builder, data_files={split: pattern}, split=split, **kwargs)
        except Exception as err:
            if attempt == last_attempt:
                if first_error is not None:
                    raise err from first_error
                raise
            # Remember the failure and fall through to the next format.
            first_error = err


def resolve_combo_ids(ds) -> Tuple[np.ndarray, int]:
    """Return dense per-row combo ids and the number of distinct combos.

    Args:
        ds: Object exposing ``column_names`` and column access by name
            (``ds["combo_id"]`` or ``ds["combo"]``).

    Returns:
        ``(combo_ids, K)`` where ``combo_ids`` is an ``int32`` array with one
        entry per row, each in ``[0, K-1]``, and ``K`` is the number of
        distinct combos. Ids are assigned by the sorted order of the distinct
        source values, so non-contiguous ``combo_id`` values are compacted.

    Raises:
        KeyError: presumably, if neither ``combo_id`` nor ``combo`` exists
            (depends on how ``ds`` handles unknown column names).
    """
    if "combo_id" in ds.column_names:
        raw = np.asarray(ds["combo_id"], dtype=np.int64)
        # np.unique returns the sorted distinct ids; the inverse array is each
        # row's index into them, i.e. the dense id, computed without a Python loop.
        uniques, inverse = np.unique(raw, return_inverse=True)
        return inverse.astype(np.int32), len(uniques)

    # No numeric ids available: rank the combo labels instead.
    combos = ds["combo"]
    index_of = {label: i for i, label in enumerate(sorted(set(combos)))}
    dense = np.fromiter((index_of[c] for c in combos), count=len(combos), dtype=np.int32)
    return dense, len(index_of)


def main() -> None:
    """CLI entry point: load a split, bucket it by combo, and print batch shapes.

    Steps:
      1. Parse command-line arguments.
      2. Load the requested split with ``load_split`` and derive dense combo
         ids with ``resolve_combo_ids``.
      3. Build a ``UniformComboBatchSampler`` and a ``DataLoader`` whose
         collate function is chosen by ``--mode`` (``auto`` picks the "full"
         collate when the columns xb/yb/xt/yt are all present).
      4. Print a summary line and the tensor shapes of up to ``--max-batches``
         batches. A limit of zero or less prints no batches at all.
    """
    ap = argparse.ArgumentParser(description="Iterate a few batches from HF tabular dataset")
    ap.add_argument("--repo-id", required=True, help="HF dataset repo id (e.g., user/name)")
    ap.add_argument("--split", default="train", choices=["train", "validation", "test"], help="Dataset split")
    ap.add_argument("--batch-size", type=int, default=32, help="Items per batch (same (nc,d))")
    ap.add_argument("--num-workers", type=int, default=0, help="PyTorch DataLoader workers")
    ap.add_argument("--shuffle", action="store_true", help="Shuffle within each (nc,d) bucket")
    ap.add_argument("--keep-last", action="store_true", help="Keep last partial batches (drop_last=False)")
    ap.add_argument("--max-batches", type=int, default=5, help="Number of batches to print and then stop")
    ap.add_argument("--force", action="store_true", help="Force redownload from Hub (bypass datasets cache)")
    ap.add_argument("--cache-dir", default=None, help="Custom datasets cache dir to avoid stale dataset_info")
    ap.add_argument("--mode", default="auto", choices=["auto", "context", "full"], help="Collate mode: auto-detect or force context/full")
    ap.add_argument("--file-format", default="auto", choices=["auto", "json", "parquet"], help="Shard file format on the Hub")
    args = ap.parse_args()

    ds = load_split(args.repo_id, args.split, force=args.force, cache_dir=args.cache_dir, file_format=args.file_format)
    combo_ids, K = resolve_combo_ids(ds)

    sampler = UniformComboBatchSampler(
        combo_ids=combo_ids,
        num_combos=K,
        batch_size=args.batch_size,
        shuffle=args.shuffle,
        drop_last=not args.keep_last,
    )

    # Choose collate based on mode or, in auto mode, on the presence of the
    # extra xb/yb/xt/yt columns.
    if args.mode == "auto":
        want_full = {"xb", "yb", "xt", "yt"}.issubset(ds.column_names)
    else:
        want_full = args.mode == "full"
    collate_fn = collate_full_same_shape if want_full else collate_stack_same_shape

    dl = DataLoader(
        HFDataset(ds),
        batch_sampler=sampler,
        num_workers=args.num_workers,
        collate_fn=collate_fn,
    )

    print(f"Loaded split='{args.split}' from {args.repo_id}")
    print(f"  total samples: {len(ds):,}; combos: {K}; batch_size: {args.batch_size}")

    # Honour a non-positive limit: checking only after printing would always
    # emit at least one batch.
    if args.max_batches <= 0:
        return

    for batch_no, batch in enumerate(dl, start=1):
        x, y = batch["x"], batch["y"]
        # x must be rank 3; the original author's note gives x:[B,Nc,D], y:[B,Nc,1].
        _, Nc, D = x.shape
        if "xb" in batch:
            xb, yb = batch["xb"], batch["yb"]
            xt, yt = batch["xt"], batch["yt"]
            Nb, Nt = xb.shape[1], xt.shape[1]
            print(f"Batch {batch_no}: x{tuple(x.shape)} y{tuple(y.shape)} xb{tuple(xb.shape)} yb{tuple(yb.shape)} xt{tuple(xt.shape)} yt{tuple(yt.shape)} (Nc={Nc}, D={D}, Nb={Nb}, Nt={Nt})")
        else:
            print(f"Batch {batch_no}: x{tuple(x.shape)} y{tuple(y.shape)}  (Nc={Nc}, D={D})")
        # Stop right after the last requested batch so no extra batch is fetched.
        if batch_no >= args.max_batches:
            break


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
