#!/usr/bin/env python3
"""
Load the MADQA dataset from local storage.

The dataset is included in this supplementary material under ./data/

Usage:
    from datasets import load_from_disk
    
    # Load all splits
    dataset = load_from_disk("./data")
    
    # Access specific splits
    train = dataset["train"]
    dev = dataset["dev"]
    test = dataset["test"]
"""

import argparse
from itertools import islice
from pathlib import Path

from datasets import load_from_disk


def load_dataset_local(data_dir: str = "./data"):
    """
    Read a dataset previously saved to disk.

    The path is checked up front so that a missing location is reported
    with a short, explicit message before the loader is invoked.

    Args:
        data_dir: Location of the saved dataset directory.

    Returns:
        Whatever ``load_from_disk`` returns for that path (expected to be a
        DatasetDict holding the train/dev/test splits).

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist.
    """
    root = Path(data_dir)
    if root.exists():
        return load_from_disk(str(root))
    raise FileNotFoundError(f"Dataset not found at {root}")


def _print_example(index, example, question_width=100):
    """
    Print a short, human-readable summary of one dataset example.

    Args:
        index: Position of the example within its split (shown in brackets).
        example: Mapping with the example's fields. ``question``,
            ``document_category`` and ``domain`` are read directly;
            ``answer_variants`` and ``evidence`` are looked up with ``get``.
        question_width: Maximum number of question characters to display.
    """
    question = example['question']
    # Only append the ellipsis when text was actually cut off, so a short
    # question is not displayed as if it had been truncated.
    suffix = "..." if len(question) > question_width else ""
    print(f"\n[{index}] Question: {question[:question_width]}{suffix}")
    print(f"    Category: {example['document_category']}")
    print(f"    Domain: {example['domain']}")

    # Absent or empty answers/evidence are reported as "[non-disclosed]".
    answers = example.get('answer_variants')
    if answers:
        print(f"    Answers: {answers[0][:2]}...")
    else:
        print("    Answers: [non-disclosed]")

    evidence = example.get('evidence')
    if evidence:
        print(f"    Evidence: {evidence}")
    else:
        print("    Evidence: [non-disclosed]")


def main() -> None:
    """
    Command-line entry point: load the dataset and print a few examples.

    Options:
        --data-dir / -d: Path to the dataset directory (default ``./data``).
        --split: Restrict output to ``train``, ``dev`` or ``test``; when
            omitted, every split found in the loaded dataset is shown.
        --sample / -n: Number of examples to print per split (default 3).
            Zero or negative values print no examples.

    Exits through ``parser.error`` (usage message on stderr) when the
    requested split is not present in the loaded dataset.
    """
    parser = argparse.ArgumentParser(
        description="Load and explore the benchmark dataset"
    )
    parser.add_argument(
        "--data-dir", "-d",
        default="./data",
        help="Path to the dataset directory"
    )
    parser.add_argument(
        "--split",
        choices=["train", "dev", "test"],
        help="Specific split to show (default: show all)"
    )
    parser.add_argument(
        "--sample", "-n",
        type=int,
        default=3,
        help="Number of samples to display"
    )

    args = parser.parse_args()

    # Load dataset
    print(f"Loading dataset from {args.data_dir}...")
    dataset = load_dataset_local(args.data_dir)

    print("\nDataset structure:")
    print(dataset)

    if args.split:
        splits = [args.split]
    else:
        splits = list(dataset.keys())

    # islice() rejects negative stops; clamping keeps "show nothing" for them.
    sample_count = max(args.sample, 0)

    for split in splits:
        try:
            examples = dataset[split]
        except KeyError:
            # Only reachable for a user-requested split: report it as a CLI
            # error instead of letting a raw KeyError traceback escape.
            parser.error(
                f"split {split!r} not found in dataset at {args.data_dir}"
            )

        print(f"\n{'='*60}")
        print(f"{split.upper()} split ({len(examples)} examples)")
        print('='*60)

        # islice stops before fetching an example that would not be printed.
        for i, example in enumerate(islice(examples, sample_count)):
            _print_example(i, example)


if __name__ == "__main__":
    main()
