#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Download a small evaluation cache from a large HF dataset by sampling across many shards.

This script uses torchtitan.datasets.hf_datasets.create_single_shard_dataset_for_repo
to enumerate file shards on the Hub, then collects up to
ceil(total_samples / num_shards) samples per shard (4 with the default 2500 samples
over 800 shards), visiting shards in index order until the target sample count is
reached. Samples are saved as JSONL (gzip) with a single 'text' field, plus a small
MANIFEST.json.

Example:
  HF_TOKEN=xxxx python download_eval_samples.py \
    --repo-id HuggingFaceFW/fineweb-edu --split train \
    --num-shards 800 --total-samples 2500 \
    --output-dir ./eval_cache

You can later load this with Hugging Face Datasets streaming:
  from datasets import load_dataset
  ds = load_dataset("json", data_files={"train": "./eval_cache/*.jsonl.gz"}, streaming=True)["train"]

Or add a dataset entry in torchtitan's mapping that loads the above with the 'json' builder.
"""

from __future__ import annotations

import argparse
import gzip
import json
import math
import os
import sys
from datetime import datetime, timezone
from typing import Any, Optional

from torchtitan.datasets.hf_datasets import create_single_shard_dataset_for_repo


def parse_args() -> argparse.Namespace:
    """Parse the command line for the eval-sample downloader.

    Returns:
        Namespace exposing ``repo_id``, ``split``, ``num_shards``,
        ``total_samples``, ``seed``, ``max_files_per_shard``, ``output_dir``,
        ``outfile`` and ``hf_token``.
    """
    # Declarative option table: (flag, add_argument keyword arguments).
    # Order matters only for how the options appear in --help.
    option_table = (
        (
            "--repo-id",
            dict(
                type=str,
                required=True,
                help="HF dataset repo id, e.g. 'HuggingFaceFW/fineweb-edu'",
            ),
        ),
        (
            "--split",
            dict(
                type=str,
                default="train",
                choices=["train", "validation"],
                help="Split name to prefer when filtering files by filename hints",
            ),
        ),
        (
            "--num-shards",
            dict(
                type=int,
                default=800,
                help="Total virtual shards to partition the repo file list into",
            ),
        ),
        (
            "--total-samples",
            dict(
                type=int,
                default=2500,
                help="Total number of text samples to collect",
            ),
        ),
        (
            "--seed",
            dict(
                type=int,
                default=42,
                help="Seed used for stable file ordering before sharding",
            ),
        ),
        (
            "--max-files-per-shard",
            dict(
                type=int,
                default=100,
                help="Upper bound of files per shard to open (limits breadth per shard)",
            ),
        ),
        (
            "--output-dir",
            dict(
                type=str,
                default="./eval_cache",
                help="Directory to write the JSONL.GZ and manifest",
            ),
        ),
        (
            "--outfile",
            dict(
                type=str,
                default=None,
                help="Optional explicit output filename (defaults to derived name)",
            ),
        ),
        (
            "--hf-token",
            dict(
                type=str,
                default=None,
                help="HF token (falls back to HF_TOKEN env var)",
            ),
        ),
    )

    cli = argparse.ArgumentParser(
        description="Sample a small eval set across many shards from a HF dataset repo",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()


def _extract_text(sample: dict[str, Any]) -> Optional[str]:
    """Return the whitespace-stripped ``'text'`` field of *sample*, if usable.

    Only the ``'text'`` key is consulted, so the output stays compatible with
    consumers that expect exactly that field.

    Returns:
        The stripped string, or ``None`` when *sample* is not a dict, the
        ``'text'`` value is not a ``str``, or fewer than 10 characters remain
        after stripping.
    """
    raw = sample.get("text") if isinstance(sample, dict) else None
    if not isinstance(raw, str):
        return None
    cleaned = raw.strip()
    if len(cleaned) < 10:
        return None
    return cleaned


def _ensure_dir(path: str) -> None:
    """Create directory *path*, including missing parents.

    An already existing directory is left untouched (``exist_ok=True``).
    """
    os.makedirs(name=path, exist_ok=True)


def _default_outfile(
    output_dir: str, repo_id: str, split: str, total: int, seed: int
) -> str:
    """Derive the default output path for a sampling run.

    The filename encodes the repo id (with ``/`` replaced by ``__`` so it is a
    single path component), the split, the requested sample count and the seed.

    Returns:
        ``<output_dir>/<repo>__<split>__n<total>__seed<seed>.jsonl.gz``
    """
    repo_slug = "__".join(repo_id.split("/"))
    filename = f"{repo_slug}__{split}__n{total}__seed{seed}.jsonl.gz"
    return os.path.join(output_dir, filename)


def _write_manifest(manifest_path: str, manifest: dict[str, Any]) -> None:
    """Write *manifest* as indented JSON to *manifest_path*, replacing any old file."""
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)


def _iter_shard_samples(ds: Any, shard_index: int):
    """Yield samples from *ds*; end the shard with a warning if iteration raises.

    Errors raised by the consumer's loop body are not intercepted here; only
    failures that occur while producing the next sample from *ds* are.
    """
    try:
        yield from ds
    except Exception as e:
        print(f"[warn] streaming from shard {shard_index} failed; moving on: {e}")


def main() -> int:
    """Sample text records shard by shard and write them to a gzip JSONL file.

    Steps:
      1. Resolve the HF token (``--hf-token`` first, then ``HF_TOKEN``).
      2. Create the output directories and write ``MANIFEST.json`` describing
         the run; it is rewritten at the end with the actual ``num_written``.
      3. For each shard index, build a dataset with
         ``create_single_shard_dataset_for_repo`` and take up to
         ``ceil(total_samples / num_shards)`` unique, valid texts from it.
      4. Print a JSON status summary; note on stderr if fewer samples than
         requested were written.

    Shards that fail to open or fail while streaming are skipped with a
    warning. Errors writing the output file are not swallowed.

    Returns:
        Process exit status; always ``0`` when the run completes.
    """
    args = parse_args()

    # Treat an empty token (flag or env var) the same as "not provided".
    hf_token = args.hf_token or os.getenv("HF_TOKEN") or None
    if hf_token is None:
        # Not strictly required for public datasets, but warn the user
        print(
            "[warn] HF_TOKEN not set; proceeding unauthenticated (may be slower or blocked)"
        )

    _ensure_dir(args.output_dir)
    outfile = args.outfile or _default_outfile(
        args.output_dir, args.repo_id, args.split, args.total_samples, args.seed
    )
    # An explicit --outfile may point outside --output-dir; make sure its
    # parent exists before gzip.open tries to create the file.
    _ensure_dir(os.path.dirname(os.path.abspath(outfile)))

    per_shard_budget = max(1, math.ceil(args.total_samples / max(1, args.num_shards)))

    # Minimal manifest for provenance. Written up front so it exists even if
    # the run is interrupted, then rewritten below with the final count.
    manifest_path = os.path.join(args.output_dir, "MANIFEST.json")
    manifest = {
        "repo_id": args.repo_id,
        "split": args.split,
        "num_shards": int(args.num_shards),
        "seed": int(args.seed),
        "max_files_per_shard": int(args.max_files_per_shard),
        "total_samples": int(args.total_samples),
        "created_utc": datetime.now(tz=timezone.utc).isoformat(),
        "outfile": os.path.abspath(outfile),
    }
    _write_manifest(manifest_path, manifest)

    num_written = 0
    # Hashes rather than full texts keep the dedup set small.
    seen_text_hashes: set[int] = set()

    # Open the gzip writer once; write incrementally
    with gzip.open(outfile, mode="wt", encoding="utf-8") as gzout:
        for shard_index in range(args.num_shards):
            if num_written >= args.total_samples:
                break
            try:
                ds = create_single_shard_dataset_for_repo(
                    repo_id=args.repo_id,
                    shard_index=shard_index,
                    num_shards=args.num_shards,
                    token=hf_token,
                    split=args.split,
                    seed=args.seed,
                    max_files_per_shard=args.max_files_per_shard,
                )
            except Exception as e:
                print(
                    f"[warn] failed to create shard {shard_index}/{args.num_shards}: {e}"
                )
                continue

            taken_from_this_shard = 0
            for sample in _iter_shard_samples(ds, shard_index):
                # Only record preparation is best-effort; the write below is
                # kept outside the try so output I/O errors are not mistaken
                # for bad samples.
                try:
                    text = _extract_text(sample)
                    if not text:
                        continue
                    text_hash = hash(text)
                    if text_hash in seen_text_hashes:
                        continue
                    line = json.dumps({"text": text}, ensure_ascii=False) + "\n"
                    # Surface text that cannot be encoded as UTF-8 (e.g. lone
                    # surrogates) here instead of inside the file write.
                    line.encode("utf-8")
                except Exception as e:
                    # Skip any problematic records and continue streaming
                    print(f"[warn] skipping a bad sample from shard {shard_index}: {e}")
                    continue

                seen_text_hashes.add(text_hash)
                gzout.write(line)
                num_written += 1
                taken_from_this_shard += 1
                if (
                    taken_from_this_shard >= per_shard_budget
                    or num_written >= args.total_samples
                ):
                    break

    # Record what was actually produced, not just what was requested.
    manifest["num_written"] = num_written
    _write_manifest(manifest_path, manifest)

    print(
        json.dumps(
            {
                "status": "ok",
                "written": num_written,
                "outfile": os.path.abspath(outfile),
                "manifest": os.path.abspath(manifest_path),
            },
            indent=2,
        )
    )

    if num_written < args.total_samples:
        print(
            f"[note] Wrote {num_written} < requested {args.total_samples}. Some shards may have had few valid samples.",
            file=sys.stderr,
        )
    return 0


if __name__ == "__main__":
    # Run only when executed as a script; main()'s return value becomes the
    # process exit status.
    sys.exit(main())
