"""
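Train a byte-level BPE tokenizer on the longmino PDF dataset, report its
compression (characters per token) on training and held-out samples, and
save the result as a JSON file.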
"""

from datasets import load_from_disk
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, decoders
from pathlib import Path
import random
from tqdm import tqdm

# Configuration
# Path handed to `load_from_disk` below.
DATASET_PATH = "science_tech_2e15_dataset"
# Target vocabulary size: 2**15 = 32,768 entries ("32k").
VOCAB_SIZE = 1 << 15
# Number of documents drawn for the training and the held-out test split.
NUM_TRAIN_SAMPLES = 10_000
NUM_TEST_SAMPLES = 1_000
# The file name embeds the vocab size in units of 1024 (32 for the value above).
OUTPUT_PATH = Path("tokenizer_pdf_{}k.json".format(VOCAB_SIZE // 1024))

# Special tokens (passed to the trainer so they are reserved in the vocab)
SPECIAL_TOKENS = ["<|endoftext|>"]

# Startup banner summarising the run configuration.
_BANNER_RULE = "=" * 80
print(_BANNER_RULE)
print("Training BPE Tokenizer on PDF Dataset")
print(_BANNER_RULE)
for _label, _value in (
    ("Dataset", DATASET_PATH),
    ("Vocab size", f"{VOCAB_SIZE:,}"),
    ("Training samples", f"{NUM_TRAIN_SAMPLES:,}"),
    ("Test samples", f"{NUM_TEST_SAMPLES:,}"),
    ("Special tokens", SPECIAL_TOKENS),
):
    print(f"{_label}: {_value}")

# Load dataset
print(f"\nLoading dataset from {DATASET_PATH}...")
ds = load_from_disk(DATASET_PATH)
total_docs = len(ds)
print(f"  Total documents: {total_docs:,}")

# Fail fast when the dataset cannot supply a test split. Without this check
# the slices below silently yield an empty test set and the script only
# crashes (division by zero) after the expensive training step has finished.
if total_docs <= NUM_TRAIN_SAMPLES:
    raise ValueError(
        f"Dataset has {total_docs:,} documents; need more than "
        f"{NUM_TRAIN_SAMPLES:,} to leave any documents for the test split"
    )
if total_docs < NUM_TRAIN_SAMPLES + NUM_TEST_SAMPLES:
    print(
        f"  ⚠ Only {total_docs - NUM_TRAIN_SAMPLES:,} documents available "
        f"for the test split (requested {NUM_TEST_SAMPLES:,})"
    )

# Sample training and test documents
print(f"\nSampling documents...")
# Fixed seed so the same documents are selected on every run.
random.seed(42)

all_indices = list(range(total_docs))
random.shuffle(all_indices)

# Disjoint slices of one shuffled permutation: no document is in both splits.
train_indices = all_indices[:NUM_TRAIN_SAMPLES]
test_indices = all_indices[NUM_TRAIN_SAMPLES:NUM_TRAIN_SAMPLES + NUM_TEST_SAMPLES]

print(f"  Training indices: {len(train_indices):,}")
print(f"  Test indices: {len(test_indices):,}")

# Extract text
print(f"\nExtracting text...")
# Each row is read by index and its 'text' field kept.
train_texts = [ds[i]['text'] for i in train_indices]
test_texts = [ds[i]['text'] for i in test_indices]

# Character totals are the numerators of the chars-per-token metric
# reported during evaluation.
train_total_chars = sum(len(text) for text in train_texts)
test_total_chars = sum(len(text) for text in test_texts)

print(f"  Training text: {train_total_chars:,} chars ({train_total_chars/1e6:.1f}M)")
print(f"  Test text: {test_total_chars:,} chars ({test_total_chars/1e6:.1f}M)")

# Initialize tokenizer
print(f"\nInitializing BPE tokenizer...")
tokenizer = Tokenizer(models.BPE())

# No normalizer (preserve text as-is)
tokenizer.normalizer = None

# ByteLevel pre-tokenizer: text is mapped to byte-level symbols before BPE,
# with no space prepended to the input.
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)

# ByteLevel decoder: reverses the byte-level mapping when decoding IDs.
tokenizer.decoder = decoders.ByteLevel()

# Train
print(f"\nTraining tokenizer (this may take a few minutes)...")
trainer = trainers.BpeTrainer(
    vocab_size=VOCAB_SIZE,
    special_tokens=SPECIAL_TOKENS,
    # Seed the vocabulary with every byte-level symbol. Otherwise a byte that
    # never occurs in the training sample has no token and, because the model
    # has no unknown token, is silently dropped when encoding new text.
    initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
    show_progress=True,
)

tokenizer.train_from_iterator(train_texts, trainer=trainer)

print(f"✓ Training complete!")
print(f"  Final vocab size: {tokenizer.get_vocab_size():,}")

# Verify special tokens: show the ID assigned to each one.
print(f"\nSpecial tokens:")
for token in SPECIAL_TOKENS:
    token_id = tokenizer.token_to_id(token)
    print(f"  {token}: ID {token_id}")

# Evaluation helpers
def _count_tokens(texts):
    """Return the total number of token IDs produced by encoding each text."""
    total = 0
    for text in tqdm(texts):
        total += len(tokenizer.encode(text).ids)
    return total


def _format_ratio(ratio):
    """Format a chars-per-token ratio, or 'n/a' when it could not be computed."""
    return "n/a" if ratio is None else f"{ratio:.2f}"


# Evaluate on training data
print(f"\n{'='*80}")
print(f"Evaluating on Training Data:")
print(f"{'='*80}")
# Use as many training documents as there are test documents so both
# measurements are taken over the same number of documents.
train_eval_texts = train_texts[:len(test_texts)]
train_eval_total_chars = sum(len(text) for text in train_eval_texts)
train_eval_total_tokens = _count_tokens(train_eval_texts)

# None marks "no tokens produced"; dividing by zero here would abort the
# script before the trained tokenizer is saved.
train_eval_chars_per_token = (
    train_eval_total_chars / train_eval_total_tokens
    if train_eval_total_tokens
    else None
)
print(f"  Total chars: {train_eval_total_chars:,}")
print(f"  Total tokens: {train_eval_total_tokens:,}")
print(f"  Chars per token: {_format_ratio(train_eval_chars_per_token)}")

# Evaluate on test data
print(f"\n{'='*80}")
print(f"Evaluating on Test Data:")
print(f"{'='*80}")

test_total_tokens = _count_tokens(test_texts)

test_chars_per_token = (
    test_total_chars / test_total_tokens if test_total_tokens else None
)
print(f"  Total chars: {test_total_chars:,}")
print(f"  Total tokens: {test_total_tokens:,}")
print(f"  Chars per token: {_format_ratio(test_chars_per_token)}")

# Compare
print(f"\n{'='*80}")
print(f"Training vs Test Comparison:")
print(f"{'='*80}")

if train_eval_chars_per_token is None or test_chars_per_token is None:
    # Nothing meaningful to compare; keep going so the tokenizer is still saved.
    gap = None
    gap_pct = None
    print(f"  ⚠ Comparison skipped: no tokens were produced for at least one split")
else:
    # Positive gap: training text compresses better than held-out text.
    gap = train_eval_chars_per_token - test_chars_per_token
    gap_pct = (gap / test_chars_per_token) * 100

    print(f"  Train chars/token: {train_eval_chars_per_token:.2f}")
    print(f"  Test chars/token:  {test_chars_per_token:.2f}")
    print(f"  Gap: {gap:+.2f} ({gap_pct:+.1f}%)")

    if abs(gap_pct) < 5:
        print(f"  ✓ Excellent generalization!")
    elif gap_pct >= 5:
        print(f"  ⚠ Some overfitting detected")
    else:
        # gap_pct <= -5; the previous `< -5` test printed nothing at exactly -5.
        print(f"  ⚠ Test performs better than train (unusual)")

# Save tokenizer
_SECTION_RULE = "=" * 80
print("\n" + _SECTION_RULE)
print(f"Saving tokenizer to: {OUTPUT_PATH}")
print(_SECTION_RULE)

# Write the trained tokenizer to OUTPUT_PATH.
tokenizer.save(str(OUTPUT_PATH))
print("✓ Saved!")

# Test decoding: encode a sample containing the special token, decode it
# again with special tokens kept, and print all three forms for inspection.
print("\nTesting decode with special tokens...")
test_text = "Hello world!<|endoftext|>"
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded.ids, skip_special_tokens=False)
for _label, _value in (
    ("Original", test_text),
    ("Token IDs", encoded.ids),
    ("Decoded", decoded),
):
    print(f"  {_label}: {_value}")

# Closing banner
print("\n" + _SECTION_RULE)
print("Done! Tokenizer ready for use.")
print(_SECTION_RULE)
