"""Prepare PG-19 dataset with Llama tokenization."""

import argparse
from pathlib import Path
import pickle

from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm


def prepare_pg19_dataset(
    output_dir: str,
    tokenizer_name: str = 'meta-llama/Llama-3.2-1B',
    max_length: int = 2048,
    split: str = 'test',
    *,
    num_proc: int = 32,
):
    """Tokenize a PG-19 split and save it as fixed-length token chunks.

    Every book is tokenized without special tokens and cut into
    non-overlapping chunks of exactly ``max_length`` tokens; the trailing
    partial chunk of each book is dropped. The chunks are pickled to
    ``<output_dir>/pg19_<split>_<max_length>.pkl`` as a dict with keys
    ``'samples'`` (list of token-id lists), ``'tokenizer_name'``,
    ``'max_length'`` and ``'split'``.

    Args:
        output_dir: Output directory for processed data (created if missing).
        tokenizer_name: HuggingFace tokenizer name.
        max_length: Length, in tokens, of every saved chunk. Must be positive.
        split: Dataset split to process.
        num_proc: Number of processes forwarded to ``load_dataset``.

    Raises:
        ValueError: If ``max_length`` is not a positive integer.
    """
    # Validate before any expensive or side-effecting work (mkdir, downloads).
    if max_length <= 0:
        raise ValueError(f"max_length must be a positive integer, got {max_length}")

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Load tokenizer
    print(f"Loading tokenizer: {tokenizer_name}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    # Load dataset
    print(f"Loading PG-19 dataset (split: {split})...")
    dataset = load_dataset(
        'emozilla/pg19', split=split, trust_remote_code=True, num_proc=num_proc,
    )

    # Process and tokenize
    tokenized_samples = []
    total_tokens = 0  # every token produced, including dropped partial tails
    num_docs = 0

    for example in tqdm(dataset, desc="Processing"):
        tokens = tokenizer.encode(example['text'], add_special_tokens=False)
        total_tokens += len(tokens)
        num_docs += 1

        # Keep only full chunks: start offsets i with i + max_length <= len(tokens).
        # When the book is shorter than max_length the range is empty.
        tokenized_samples.extend(
            tokens[i:i + max_length]
            for i in range(0, len(tokens) - max_length + 1, max_length)
        )

    kept_tokens = len(tokenized_samples) * max_length

    print("\nDataset statistics:")
    print(f"  Documents: {num_docs:,}")
    print(f"  Total samples: {len(tokenized_samples)}")
    print(f"  Total tokens: {total_tokens:,}")
    print(f"  Tokens kept in full chunks: {kept_tokens:,}")
    # Guard the division: an empty split would otherwise raise ZeroDivisionError.
    if num_docs:
        print(f"  Avg tokens per document: {total_tokens / num_docs:.1f}")
    print(f"  Sequence length: {max_length}")
    if not tokenized_samples:
        print(
            f"  WARNING: no document has at least {max_length} tokens; "
            "no samples were produced."
        )

    # Save processed data
    output_file = output_path / f"pg19_{split}_{max_length}.pkl"
    print(f"\nSaving to {output_file}")

    with open(output_file, 'wb') as f:
        pickle.dump({
            'samples': tokenized_samples,
            'tokenizer_name': tokenizer_name,
            'max_length': max_length,
            'split': split,
        }, f)

    print("Done!")


def main():
    """Command-line entry point: parse options and run ``prepare_pg19_dataset``.

    Options:
        --output_dir  Directory for the output pickle (default: data/processed).
        --tokenizer   HuggingFace tokenizer name (default: meta-llama/Llama-3.2-1B).
        --max_length  Chunk length in tokens (default: 2048).
        --split       One of train / test / validation (default: test).
    """
    cli = argparse.ArgumentParser(description='Prepare PG-19 dataset')

    # (flag, add_argument keyword arguments), listed in --help order.
    option_specs = (
        (
            '--output_dir',
            dict(type=str, default='data/processed', help='Output directory'),
        ),
        (
            '--tokenizer',
            dict(type=str, default='meta-llama/Llama-3.2-1B', help='Tokenizer name'),
        ),
        (
            '--max_length',
            dict(type=int, default=2048, help='Maximum sequence length'),
        ),
        (
            '--split',
            dict(
                type=str,
                default='test',
                choices=['train', 'test', 'validation'],
                help='Dataset split',
            ),
        ),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    opts = cli.parse_args()

    prepare_pg19_dataset(
        output_dir=opts.output_dir,
        tokenizer_name=opts.tokenizer,
        max_length=opts.max_length,
        split=opts.split,
    )


if __name__ == '__main__':
    # Run the CLI only when this file is executed directly, not when imported.
    main()
