#!/usr/bin/env python3
"""
Shard a dataset into smaller batches for distributed processing.
Saves each batch in HuggingFace's Arrow format using save_to_disk().
"""

import argparse
import os
import json
from datasets import load_dataset, load_from_disk


def _load_source_dataset(dataset_path, dataset_split, hf_cache):
    """Load the dataset to shard and return it as a single ``Dataset``.

    Resolution order:

    * a directory containing ``dataset_info.json`` (a single ``Dataset`` written
      by ``save_to_disk()``) is loaded with ``load_from_disk``; ``dataset_split``
      is ignored;
    * a directory containing ``dataset_dict.json`` (a ``DatasetDict`` written by
      ``save_to_disk()``) is loaded with ``load_from_disk`` and its
      ``dataset_split`` entry is returned;
    * any other path (plain local directory, data file, or Hub dataset name) is
      passed to ``load_dataset`` with ``split=dataset_split`` and
      ``cache_dir=hf_cache``.

    Raises:
        ValueError: if a saved ``DatasetDict`` has no ``dataset_split`` entry.
    """
    if os.path.isdir(dataset_path):
        if os.path.exists(os.path.join(dataset_path, "dataset_info.json")):
            print("Loading from local directory (saved with save_to_disk)...")
            return load_from_disk(dataset_path)
        if os.path.exists(os.path.join(dataset_path, "dataset_dict.json")):
            # save_to_disk() output cannot be read by load_dataset(), so a saved
            # DatasetDict must go through load_from_disk() and pick the split here.
            print(f"Loading from local directory (saved with save_to_disk, split: {dataset_split})...")
            dataset_dict = load_from_disk(dataset_path)
            if dataset_split not in dataset_dict:
                raise ValueError(
                    f"Split {dataset_split!r} not found in {dataset_path}; "
                    f"available splits: {list(dataset_dict)}"
                )
            return dataset_dict[dataset_split]
        print(f"Loading from local directory (split: {dataset_split})...")
    else:
        print(f"Loading from HuggingFace Hub or file (split: {dataset_split})...")
    return load_dataset(dataset_path, split=dataset_split, cache_dir=hf_cache)


def _write_metadata(output_dir, metadata):
    """Write *metadata* as indented JSON to ``<output_dir>/metadata.json``; return that path."""
    metadata_path = os.path.join(output_dir, "metadata.json")
    with open(metadata_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    return metadata_path


def main(args):
    """Shard ``args.dataset_path`` into ``batch_NNNN`` directories under ``args.output_dir``.

    ``args`` must provide ``dataset_path``, ``dataset_split``, ``output_dir``,
    ``batch_size`` and ``hf_cache`` attributes (as produced by the script's
    argument parser). Writes ``metadata.json`` describing the sharding, then one
    ``save_to_disk()`` directory per batch of at most ``batch_size`` contiguous
    samples; the last batch holds the remainder.

    Raises:
        ValueError: if ``args.batch_size`` is not positive, or if a saved
            ``DatasetDict`` lacks the requested split.
    """
    batch_size = args.batch_size
    # Validate before the (possibly slow) dataset load: 0 would divide by zero
    # below, and a negative size would silently produce no batches.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    print(f"Loading dataset from: {args.dataset_path}")
    dataset = _load_source_dataset(args.dataset_path, args.dataset_split, args.hf_cache)

    total_samples = len(dataset)
    # Ceiling division so a partial final batch is counted.
    num_batches = (total_samples + batch_size - 1) // batch_size

    print(f"Dataset size: {total_samples} samples")
    print(f"Batch size: {batch_size}")
    print(f"Number of batches: {num_batches}")

    os.makedirs(args.output_dir, exist_ok=True)

    metadata = {
        "dataset_path": args.dataset_path,
        "dataset_split": args.dataset_split,
        "total_samples": total_samples,
        "batch_size": batch_size,
        "num_batches": num_batches,
        "columns": list(dataset.features.keys()),
    }
    metadata_path = _write_metadata(args.output_dir, metadata)
    print(f"Saved metadata to: {metadata_path}")

    print("\nSharding dataset into batches...")
    for i, start_idx in enumerate(range(0, total_samples, batch_size)):
        end_idx = min(start_idx + batch_size, total_samples)
        batch = dataset.select(range(start_idx, end_idx))

        batch_dir = os.path.join(args.output_dir, f"batch_{i:04d}")
        batch.save_to_disk(batch_dir)

        print(f"Saved batch {i}: samples {start_idx}-{end_idx-1} ({len(batch)} samples) to {batch_dir}")

    print(f"\nSharding complete! All batches saved to: {args.output_dir}")


def positive_int(value):
    """Parse *value* as an ``int`` for argparse, rejecting zero and negatives.

    Raises:
        argparse.ArgumentTypeError: if *value* is not an integer or is <= 0.
    """
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def build_parser():
    """Return the command-line argument parser for this script."""
    parser = argparse.ArgumentParser(description="Shard a dataset into smaller batches")
    parser.add_argument("--dataset-path", type=str, required=True,
                        help="Dataset path (HuggingFace Hub name, local data file, or local directory)")
    parser.add_argument("--dataset-split", type=str, default="train",
                        help="Dataset split to use (ignored when --dataset-path is a single "
                             "Dataset saved with save_to_disk)")
    parser.add_argument("--output-dir", type=str, required=True,
                        help="Directory to save sharded batches")
    parser.add_argument("--batch-size", type=positive_int, default=100,
                        help="Number of samples per batch (must be > 0)")
    parser.add_argument("--hf-cache", type=str, default=None,
                        help="HuggingFace cache directory")
    return parser


if __name__ == "__main__":
    main(build_parser().parse_args())