"""Download and prepare LongBench v2 dataset from HuggingFace."""

import argparse
import json
import pickle
from pathlib import Path
from typing import Optional

from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer


def download_longbench_v2(
    output_dir: str,
    tokenizer_name: str = 'meta-llama/Llama-3.1-8B-Instruct',
    difficulty: Optional[str] = None,
    length: Optional[str] = None,
    domain: Optional[str] = None,
    max_samples: Optional[int] = None,
) -> None:
    """Download LongBench v2, filter it, and save it as pickle and JSON.

    Loads the ``train`` split of ``THUDM/LongBench-v2``, keeps the examples
    that match every requested filter, measures each kept context with the
    given tokenizer, prints summary statistics, and writes
    ``longbench_v2<suffix>.pkl`` and ``longbench_v2<suffix>.json`` into
    ``output_dir``. The suffix encodes the active filters and sample limit.

    Args:
        output_dir: Output directory for processed data (created if missing)
        tokenizer_name: HuggingFace tokenizer name
        difficulty: Filter by difficulty ('easy' or 'hard'), None for all
        length: Filter by length ('short', 'medium', 'long'), None for all
        domain: Filter by domain (exact match), None for all
        max_samples: Maximum number of samples to process, None for all.
            ``0`` is honoured as "keep nothing" rather than "no limit".

    Raises:
        ValueError: If ``max_samples`` is negative.
    """
    # Fail fast, before any download or directory creation.
    if max_samples is not None and max_samples < 0:
        raise ValueError(f"max_samples must be non-negative, got {max_samples}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Get token IDs for A, B, C, D (useful for logit-based evaluation).
    # Only the first token of each encoded letter is kept.
    answer_token_ids = {
        letter: tokenizer.encode(letter, add_special_tokens=False)[0]
        for letter in 'ABCD'
    }
    print(f"Answer token IDs: {answer_token_ids}")

    # Load dataset
    print("Loading LongBench v2 dataset from HuggingFace...")
    dataset = load_dataset('THUDM/LongBench-v2', split='train')

    print(f"Total samples in dataset: {len(dataset)}")

    # A falsy filter value (None or '') means "do not filter on this field".
    filters = {'difficulty': difficulty, 'length': length, 'domain': domain}
    # Values encountered per filter field, used only to explain an empty result.
    seen_values = {field: set() for field in filters}

    # Fields copied verbatim from each dataset example, in output order.
    copied_fields = (
        '_id', 'domain', 'sub_domain', 'difficulty', 'length', 'question',
        'choice_A', 'choice_B', 'choice_C', 'choice_D', 'answer', 'context',
    )

    filtered_data = []
    stats = {
        'difficulty': {},
        'length': {},
        'domain': {},
        'context_lengths': [],
    }

    for example in tqdm(dataset, desc="Processing and filtering"):
        # Checked before any work so that max_samples=0 really keeps nothing.
        if max_samples is not None and len(filtered_data) >= max_samples:
            break

        for field in seen_values:
            seen_values[field].add(example[field])

        # Apply filters
        if any(wanted and example[field] != wanted
               for field, wanted in filters.items()):
            continue

        # Track statistics
        for field in ('difficulty', 'length', 'domain'):
            counts = stats[field]
            counts[example[field]] = counts.get(example[field], 0) + 1

        # Tokenize to get length
        context_tokens = tokenizer.encode(example['context'], add_special_tokens=False)
        stats['context_lengths'].append(len(context_tokens))

        # Store processed example
        processed_example = {key: example[key] for key in copied_fields}
        processed_example['context_length_tokens'] = len(context_tokens)
        filtered_data.append(processed_example)

    if not filtered_data:
        # The files are still written, but an empty result is almost always a
        # mistyped filter, so say so and show what the dataset contains.
        print("\nWarning: no samples were selected; the output files will be empty.")
        for field, wanted in filters.items():
            if wanted and seen_values[field]:
                available = sorted(str(value) for value in seen_values[field])
                print(f"  Requested {field}={wanted!r}; values present: {available}")

    _print_statistics(len(filtered_data), stats)

    suffix = _output_suffix(difficulty, length, domain, max_samples)

    payload = {
        'samples': filtered_data,
        'tokenizer_name': tokenizer_name,
        'answer_token_ids': answer_token_ids,
        'statistics': stats,
    }

    # Save as pickle
    output_file = output_path / f"longbench_v2{suffix}.pkl"
    print(f"\nSaving to {output_file}")

    with open(output_file, 'wb') as f:
        pickle.dump(payload, f)

    # Also save as JSON for easy inspection; the raw per-sample length list is
    # left out of the JSON statistics (each sample carries its own length).
    json_file = output_path / f"longbench_v2{suffix}.json"
    print(f"Saving JSON to {json_file}")

    json_payload = {
        **payload,
        'statistics': {k: v for k, v in stats.items() if k != 'context_lengths'},
    }
    with open(json_file, 'w', encoding='utf-8') as f:
        json.dump(json_payload, f, indent=2)

    print("\nDone!")


def _output_suffix(
    difficulty: Optional[str],
    length: Optional[str],
    domain: Optional[str],
    max_samples: Optional[int],
) -> str:
    """Build the filename suffix that encodes the active filters and limit."""
    parts = []
    if difficulty:
        parts.append(f"_diff_{difficulty}")
    if length:
        parts.append(f"_len_{length}")
    if domain:
        parts.append(f"_dom_{domain.replace(' ', '_')}")
    # Compared against None so that a limit of 0 is still reflected in the name.
    if max_samples is not None:
        parts.append(f"_n_{max_samples}")
    return "".join(parts)


def _print_statistics(num_samples: int, stats: dict) -> None:
    """Print per-category sample counts and context-length summary figures."""
    rule = '=' * 60
    print(f"\n{rule}")
    print("Dataset Statistics (after filtering):")
    print(rule)
    print(f"Total samples: {num_samples}")

    for title, field in (
        ('Difficulty', 'difficulty'),
        ('Length', 'length'),
        ('Domain', 'domain'),
    ):
        print(f"\nBy {title}:")
        for value, count in sorted(stats[field].items()):
            print(f"  {value}: {count}")

    if stats['context_lengths']:
        # Imported lazily: numpy is only needed when there is data to summarize.
        import numpy as np
        lengths = np.array(stats['context_lengths'])
        print("\nContext Length Statistics (in tokens):")
        print(f"  Min: {lengths.min():,}")
        print(f"  Max: {lengths.max():,}")
        print(f"  Mean: {lengths.mean():,.1f}")
        print(f"  Median: {np.median(lengths):,.1f}")
        print(f"  25th percentile: {np.percentile(lengths, 25):,.1f}")
        print(f"  75th percentile: {np.percentile(lengths, 75):,.1f}")


def main():
    """Command-line entry point: parse the options and run the download."""
    cli = argparse.ArgumentParser(description='Download LongBench v2 dataset')

    # One row per option: the flag and the keyword arguments for add_argument.
    option_table = (
        ('--output_dir', {
            'type': str,
            'default': 'data/longbench_v2',
            'help': 'Output directory',
        }),
        ('--tokenizer', {
            'type': str,
            'default': 'meta-llama/Llama-3.1-8B-Instruct',
            'help': 'Tokenizer name',
        }),
        ('--difficulty', {
            'type': str,
            'choices': ['easy', 'hard'],
            'default': None,
            'help': 'Filter by difficulty',
        }),
        ('--length', {
            'type': str,
            'choices': ['short', 'medium', 'long'],
            'default': None,
            'help': 'Filter by context length category',
        }),
        ('--domain', {
            'type': str,
            'default': None,
            'help': 'Filter by domain (e.g., "Single-Doc QA", "Multi-Doc QA", etc.)',
        }),
        ('--max_samples', {
            'type': int,
            'default': None,
            'help': 'Maximum number of samples to process',
        }),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    opts = cli.parse_args()

    # --tokenizer is the only flag whose name differs from the parameter name.
    download_longbench_v2(
        output_dir=opts.output_dir,
        tokenizer_name=opts.tokenizer,
        difficulty=opts.difficulty,
        length=opts.length,
        domain=opts.domain,
        max_samples=opts.max_samples,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
