"""Analyze LongBench v2 length categories to determine exact token boundaries."""

import argparse
import numpy as np
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm


def _print_category_stats(label: str, lengths: list) -> None:
    """Print the sample count, min-max range, mean and median of one category.

    Args:
        label: Category name (printed upper-cased as the section header).
        lengths: Context lengths in tokens for the samples in this category.
    """
    print(f"\n{label.upper()} ({len(lengths)} samples):")
    if not lengths:
        # min/max/mean/median are undefined for an empty category; report it
        # instead of letting numpy raise (min/max) or warn and return nan.
        print("  (no samples)")
        return
    print(f"  Range: {min(lengths):,} - {max(lengths):,} tokens")
    print(f"  Mean: {np.mean(lengths):,.1f} tokens")
    print(f"  Median: {np.median(lengths):,.1f} tokens")


def _print_summary(by_category: dict) -> None:
    """Print the observed token boundaries per category and flag overlaps.

    Args:
        by_category: Mapping with keys ``'short'``, ``'medium'`` and ``'long'``
            to lists of context lengths in tokens (lists may be empty).
    """
    short = by_category['short']
    medium = by_category['medium']
    long_ = by_category['long']

    print("\n" + "=" * 70)
    print("SUMMARY")
    print("=" * 70)
    # The bounds are lengths of actual samples, so they are inclusive.
    print(f"Short:  <= {max(short):,} tokens" if short else "Short:  (no samples)")
    print(
        f"Medium: {min(medium):,} - {max(medium):,} tokens"
        if medium else "Medium: (no samples)"
    )
    print(f"Long:   >= {min(long_):,} tokens" if long_ else "Long:   (no samples)")

    # Two adjacent categories overlap when the lower one's longest sample is at
    # least as long as the upper one's shortest. Only check when both exist.
    if short and medium and max(short) >= min(medium):
        print("\n⚠️  WARNING: Overlap between short and medium!")
    if medium and long_ and max(medium) >= min(long_):
        print("\n⚠️  WARNING: Overlap between medium and long!")


def _print_power_of_two_distribution(all_lengths: list) -> None:
    """Print how many samples fall into each power-of-two token bucket.

    Buckets are half-open ``[low, high)``. A leading ``< 8K`` bucket ensures
    short contexts are counted rather than silently dropped. Empty buckets
    are not printed.

    Args:
        all_lengths: Context lengths in tokens for every sample.
    """
    print("\n" + "=" * 70)
    print("DISTRIBUTION BY POWERS OF 2")
    print("=" * 70)

    ranges = [
        ("< 8K", 0, 2**13),
        ("8K - 16K", 2**13, 2**14),
        ("16K - 32K", 2**14, 2**15),
        ("32K - 64K", 2**15, 2**16),
        ("64K - 128K", 2**16, 2**17),
        ("128K - 256K", 2**17, 2**18),
        ("256K - 512K", 2**18, 2**19),
        ("512K - 1M", 2**19, 2**20),
        ("1M - 2M", 2**20, 2**21),
        ("2M+", 2**21, float('inf')),
    ]

    for label, low, high in ranges:
        count = sum(1 for n_tokens in all_lengths if low <= n_tokens < high)
        if count > 0:
            print(f"  {label:15s}: {count:3d} samples")


def analyze_length_categories(tokenizer_name: str):
    """Analyze the token length boundaries for short/medium/long categories.

    Loads the ``train`` split of ``THUDM/LongBench-v2`` via ``load_dataset``.
    It tokenizes every sample's ``context`` with ``tokenizer_name``, without
    special tokens, and groups the resulting lengths by the sample's
    ``length`` label. It then prints:

    * per-category statistics,
    * a summary of the observed boundaries with overlap warnings, and
    * a power-of-two histogram over all samples.

    Args:
        tokenizer_name: HuggingFace tokenizer to use for tokenization.
    """
    print("Loading LongBench v2 from HuggingFace...")
    dataset = load_dataset('THUDM/LongBench-v2', split='train')
    print(f"Loaded {len(dataset)} samples")

    print(f"\nLoading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    print("\nTokenizing contexts...")
    by_category = {'short': [], 'medium': [], 'long': []}
    all_lengths = []
    unrecognized = 0
    for example in tqdm(dataset):
        n_tokens = len(
            tokenizer.encode(example['context'], add_special_tokens=False)
        )
        all_lengths.append(n_tokens)
        category = by_category.get(example['length'])
        if category is None:
            # Don't drop these silently: they still count in the histogram
            # below and are reported after the per-category stats.
            unrecognized += 1
        else:
            category.append(n_tokens)
    print()

    print("=" * 70)
    print("LENGTH CATEGORY BOUNDARIES")
    print("=" * 70)

    for label, lengths in by_category.items():
        _print_category_stats(label, lengths)

    if unrecognized:
        print(
            f"\n⚠️  WARNING: {unrecognized} samples have an unrecognized "
            "'length' label (counted only in the power-of-2 distribution)"
        )

    _print_summary(by_category)
    _print_power_of_two_distribution(all_lengths)


def main():
    """Parse the command line and run the LongBench v2 length analysis.

    The only option is ``--tokenizer``: the HuggingFace tokenizer name used
    to count context tokens.
    """
    cli = argparse.ArgumentParser(
        description='Analyze LongBench v2 length categories'
    )
    cli.add_argument(
        '--tokenizer',
        default='meta-llama/Llama-3.1-8B-Instruct',
        type=str,
        help='HuggingFace tokenizer name',
    )
    options = cli.parse_args()
    analyze_length_categories(options.tokenizer)


if __name__ == '__main__':
    main()
