"""
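Pre-tokenize the PDF dataset and save it as flat binary files.
Creates (inside OUTPUT_DIR, with each file name prefixed by OUTPUT_PREFIX):
- <prefix>_train.bin: token IDs as uint16, with an EOS token appended to each document
- <prefix>_train_offsets.bin: document boundaries as uint64 (N + 1 entries for N
  documents; document i is tokens[offsets[i]:offsets[i+1]])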
"""

from datasets import load_from_disk
from tokenizers import Tokenizer
from pathlib import Path
import numpy as np
from tqdm import tqdm

# Configuration: input locations and output naming for this run.
OUTPUT_PREFIX = "science_tech_2e16"          # prefix of every generated file name
OUTPUT_DIR = Path("pretokenized_pdfs")       # directory that receives the .bin files
DATASET_PATH = "science_tech_2e16_dataset"   # dataset directory passed to load_from_disk
TOKENIZER_PATH = "tokenizer_pdf_32k.json"    # tokenizer file passed to Tokenizer.from_file

# Startup banner summarising the run configuration.
print(
    "="*80,
    "Pre-tokenize PDF Dataset",
    "="*80,
    sep="\n",
)
print(
    f"Dataset: {DATASET_PATH}",
    f"Tokenizer: {TOKENIZER_PATH}",
    f"Output: {OUTPUT_DIR / OUTPUT_PREFIX}",
    sep="\n",
)

# Load dataset
print("\nLoading dataset...")
ds = load_from_disk(DATASET_PATH)
print(f"  Total documents: {len(ds):,}")

# Load tokenizer
print("\nLoading tokenizer...")
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vocab_size = tokenizer.get_vocab_size()
print(f"  Vocab size: {vocab_size:,}")

# Token IDs are written to disk as uint16 further down, which can only hold
# values 0..65535. Fail before the (long) tokenization pass rather than
# after it. Assumes token IDs are contiguous from 0 — TODO confirm for
# tokenizers with sparse ID assignments.
if vocab_size > 2**16:
    raise SystemExit(
        f"ERROR: vocab size {vocab_size:,} does not fit in uint16 token storage!"
    )

# Get EOS token ID
eos_token = "<|endoftext|>"
eos_token_id = tokenizer.token_to_id(eos_token)
print(f"  EOS token '{eos_token}': ID {eos_token_id}")

if eos_token_id is None:
    # SystemExit with a message prints it to stderr and exits with status 1.
    # Unlike the `exit()` helper injected by the `site` module, it is always
    # available (e.g. under `python -S` or in frozen builds).
    raise SystemExit("ERROR: EOS token not found in tokenizer!")

# Prepare output
# parents=True lets OUTPUT_DIR be a nested path whose parent directories do
# not exist yet; exist_ok=True keeps re-runs from failing.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
tokens_path = OUTPUT_DIR / f"{OUTPUT_PREFIX}_train.bin"
offsets_path = OUTPUT_DIR / f"{OUTPUT_PREFIX}_train_offsets.bin"

# Tokenize all documents in parallel
from concurrent.futures import ProcessPoolExecutor
import os

# Use half of the CPUs, but never fewer than one worker: os.cpu_count() may
# return None, and on a single-CPU machine `1 // 2` is 0, which
# ProcessPoolExecutor rejects with ValueError.
num_workers = max(1, (os.cpu_count() or 1) // 2)
print(f"\nTokenizing documents with {num_workers} workers...")

def tokenize_doc(doc_text):
    """Encode one document and terminate it with the EOS token.

    Uses the module-level ``tokenizer`` and ``eos_token_id``.

    Args:
        doc_text: Text of a single document.

    Returns:
        A list of token IDs for ``doc_text`` followed by ``eos_token_id``.
    """
    encoding = tokenizer.encode(doc_text)
    return [*encoding.ids, eos_token_id]

# Process in batches to maintain order
from array import array

batch_size = 1000
# Accumulate token IDs in a compact unsigned-16-bit buffer instead of a Python
# list of int objects: the corpus can reach billions of tokens, and a list
# costs several times more memory per token. array('H') still supports len()
# and converts to a NumPy uint16 array through the buffer protocol. It also
# raises OverflowError right away if an ID does not fit in 16 bits.
all_tokens = array("H")
# offsets[i] is the index in all_tokens where document i starts; the final
# entry is the total token count, so there are (number of documents + 1) entries.
offsets = [0]

# NOTE(review): tokenize_doc reads module-level globals and this script has no
# `if __name__ == "__main__":` guard, so it appears to rely on the "fork"
# start method — confirm before running where "spawn"/"forkserver" is the default.
with ProcessPoolExecutor(max_workers=num_workers) as executor:
    for batch_start in tqdm(range(0, len(ds), batch_size), desc="Processing batches", unit="batch"):
        batch_end = min(batch_start + batch_size, len(ds))

        # Slicing returns the batch column-wise, so the texts come out as one
        # list without building a sub-dataset or a dict per row.
        batch_texts = ds[batch_start:batch_end]["text"]

        # executor.map yields results in input order, which keeps the
        # documents (and therefore the offsets) in dataset order.
        for tokens in executor.map(tokenize_doc, batch_texts):
            all_tokens.extend(tokens)
            offsets.append(len(all_tokens))

# Summary statistics
total_tokens = len(all_tokens)
num_docs = len(offsets) - 1  # offsets holds one more entry than there are documents
# An empty dataset yields zero documents; report 0 instead of dividing by zero.
avg_tokens = total_tokens / num_docs if num_docs else 0

print("\n✓ Tokenization complete!")
print(f"  Total tokens: {total_tokens:,} ({total_tokens/1e9:.2f}B)")
print(f"  Total sequences: {num_docs:,}")
print(f"  Average tokens per doc: {avg_tokens:,.0f}")

# Write tokens (raw uint16 values, no header)
print(f"\nWriting tokens to: {tokens_path}")
tokens_array = np.array(all_tokens, dtype=np.uint16)
tokens_array.tofile(tokens_path)
tokens_size_mb = tokens_path.stat().st_size / (1024**2)
print(f"  Size: {tokens_size_mb:.1f} MB")

# Write offsets (raw uint64 values, no header)
print(f"\nWriting offsets to: {offsets_path}")
offsets_array = np.array(offsets, dtype=np.uint64)
offsets_array.tofile(offsets_path)
offsets_size_mb = offsets_path.stat().st_size / (1024**2)
print(f"  Size: {offsets_size_mb:.1f} MB")

# Closing banner: list the generated files and show how to read them back.
print(
    f"\n{'='*80}",
    "Done! Files created:",
    f"  {tokens_path}",
    f"  {offsets_path}",
    sep="\n",
)
print(
    f"\nUsage in training:",
    f"  tokens = np.memmap('{tokens_path}', dtype=np.uint16, mode='r')",
    f"  offsets = np.memmap('{offsets_path}', dtype=np.uint64, mode='r')",
    f"  # Get document i: tokens[offsets[i]:offsets[i+1]]",
    "="*80,
    sep="\n",
)
