"""
Fix tokenizer decoder to properly decode special tokens.
"""

from tokenizers import Tokenizer, decoders

# Vocabulary-size tag used to build the tokenizer file name.
VOCAB = "32k"
# Tokenizer JSON file that is loaded, patched, and re-saved below.
TOKENIZER_PATH = f"tokenizer_{VOCAB}.json"

print("="*80)
print("Fixing Tokenizer Decoder")
print("="*80)

# Load tokenizer
print(f"\nLoading tokenizer: {TOKENIZER_PATH}")
tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

print(f"Current decoder: {tokenizer.decoder}")

# Replace whatever decoder is configured with a ByteLevel decoder.
# Per the original author's note this is meant to match the ByteLevel
# pre-tokenizer used during training -- confirm against the training config.
print("\nSetting ByteLevel decoder...")
tokenizer.decoder = decoders.ByteLevel()

print(f"New decoder: {tokenizer.decoder}")

# Smoke-test the new decoder on a fixed sample of token IDs.
# NOTE(review): the low IDs (0, 1, 2) are presumably special tokens --
# verify against the vocabulary in the tokenizer file.
test_ids = [1, 641, 22, 6401, 2, 641, 21, 703, 206, 694, 0]
print(f"\nTest token IDs: {test_ids}")

# Decode each ID on its own so the raw vocab entry and its decoded text
# can be compared side by side.
print("\nIndividual decoding:")
for tid in test_ids:
    token = tokenizer.id_to_token(tid)
    decoded = tokenizer.decode([tid], skip_special_tokens=False)
    # id_to_token() returns None for an ID outside the vocabulary; the `!s`
    # conversion keeps the `20s` format spec from raising TypeError on None
    # while leaving the output for real tokens unchanged.
    print(f"  ID {tid:4d} -> '{token!s:20s}' -> '{decoded}'")

# Decode the whole sequence in one call, keeping special tokens visible.
print("\nFull sequence:")
decoded = tokenizer.decode(test_ids, skip_special_tokens=False)
print(f"  Result: '{decoded}'")

# Save the patched tokenizer next to the original.
# NOTE: the suffix is appended after the extension, so the output file is
# named e.g. "tokenizer_32k.json_fixed" (kept as-is for compatibility with
# anything that already expects that name).
output_path = TOKENIZER_PATH + "_fixed"
print(f"\nSaving fixed tokenizer to: {output_path}")
tokenizer.save(output_path)

print("\nDone! Tokenizer now has proper decoder configured.")
