"""
Analyze tokenizer vocabulary: special tokens, common tokens, and samples
at exponentially increasing rarity levels.
"""

from tokenizers import Tokenizer
from pathlib import Path

TOKENIZER_PATH = "tokenizer_pdf_32k.json"

print("="*80)
print("Tokenizer Vocabulary Analysis")
print("="*80)

# Load tokenizer
print(f"\nLoading tokenizer: {TOKENIZER_PATH}")
# Fail early with the resolved path: the error raised by Tokenizer.from_file
# for a missing file is a generic OS error that does not name the file.
tokenizer_file = Path(TOKENIZER_PATH)
if not tokenizer_file.is_file():
    raise FileNotFoundError(f"Tokenizer file not found: {tokenizer_file.resolve()}")
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)
vocab_size = tokenizer.get_vocab_size()
print(f"  Vocab size: {vocab_size:,}")

# Get vocabulary as token -> id, then invert it for id -> token lookups
vocab = tokenizer.get_vocab()
id_to_token = {v: k for k, v in vocab.items()}

# Special tokens (typically IDs 0-7 in our setup)
print(f"\n{'='*80}")
print("Special Tokens:")
print("="*80)

special_token_names = ["<|endoftext|>", "<reponame>", "<filename>", "<gh_stars>",
                       "<fim_prefix>", "<fim_middle>", "<fim_suffix>", "<fim_pad>"]

# Report every expected special token. Ones the tokenizer does not define are
# flagged explicitly rather than silently skipped.
for token_name in special_token_names:
    token_id = tokenizer.token_to_id(token_name)
    if token_id is None:
        print(f"  ID   ---: {token_name} (not found)")
    else:
        print(f"  ID {token_id:5d}: {token_name}")

# Dump whatever tokens sit at the lowest IDs. repr() makes whitespace and
# control characters visible, so each token stays on a single output line.
print("\n  Low IDs (0-20):")
for i in range(21):
    token = id_to_token.get(i)
    if token is not None:
        print(f"    ID {i:5d}: {token!r}")

# First 20 non-special tokens by ID.
# NOTE(review): the heading calls these "most common", but for a trained BPE
# model the IDs right after the specials are usually the base alphabet rather
# than the most frequent merges — confirm the heading matches intent.
print(f"\n{'='*80}")
print("Most Common 20 Tokens (by ID):")
print("="*80)

# Skip the special tokens that are actually present, wherever their IDs are,
# instead of assuming exactly eight of them at IDs 0-7.
special_ids = {tokenizer.token_to_id(name) for name in special_token_names} - {None}

shown = 0
for i in range(vocab_size):
    if shown >= 20:
        break
    token = id_to_token.get(i)
    if token is None or i in special_ids:
        continue
    # repr() keeps whitespace/control-character tokens on one line
    print(f"  ID {i:5d}: {token!r}")
    shown += 1

# Exponentially increasing rarity levels
print(f"\n{'='*80}")
print("Sample Tokens at Exponentially Increasing Rarity:")
print("="*80)

# One sample window per power of two, starting at 2^5 and continuing for as
# long as 2^power is still a valid ID. This scales with the vocabulary size
# instead of stopping at a hard-coded 2^15.
power = 5
while 2 ** power < vocab_size:
    base_id = 2 ** power
    print(f"\n  Around 2^{power} = {base_id:,}:")

    # Sample up to 5 IDs centred on base_id (offsets -2..+2), clipped to the vocab.
    for sample_id in range(max(base_id - 2, 0), min(base_id + 3, vocab_size)):
        token = id_to_token.get(sample_id)
        if token is None:
            continue
        # Truncate long tokens for display; repr() keeps whitespace visible.
        display_token = token if len(token) <= 50 else token[:47] + "..."
        print(f"    ID {sample_id:5d}: {display_token!r}")

    power += 1

# Show some stats about the vocab
print(f"\n{'='*80}")
print("Vocabulary Statistics:")
print("="*80)

# Count token types. "Multi-character" means any length other than 1, so it
# includes an empty-string token if one exists. Whitespace-only tokens are
# counted at any length; multi-character runs such as several spaces were
# previously missed because only single-character tokens were checked.
all_tokens = list(id_to_token.values())
single_char_tokens = sum(1 for tok in all_tokens if len(tok) == 1)
multi_char_tokens = len(all_tokens) - single_char_tokens
whitespace_tokens = sum(1 for tok in all_tokens if tok.isspace())

print(f"  Single-character tokens: {single_char_tokens:,}")
print(f"  Multi-character tokens: {multi_char_tokens:,}")
print(f"  Whitespace tokens: {whitespace_tokens:,}")

# Show longest tokens (first 100 chars each; repr() keeps them on one line)
print("\n  Longest 5 tokens:")
sorted_by_length = sorted(id_to_token.items(), key=lambda item: len(item[1]), reverse=True)
for token_id, token in sorted_by_length[:5]:
    print(f"    ID {token_id:5d} ({len(token):3d} chars): {token[:100]!r}")

# Closing banner: blank line, horizontal rule, status line, horizontal rule.
_rule = "=" * 80
print(f"\n{_rule}")
print("Done!")
print(_rule)
