#!/usr/bin/env python3
"""Compare the Gemma tokenizer's direct output with the output of tokenize_function."""

import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from transformers import AutoTokenizer
from experiment_utils import tokenize_function, apply_chat_template
from dataclasses import dataclass

@dataclass()
class MockConfig:
    """Lightweight stand-in for the config object passed to ``tokenize_function``.

    Only a single setting is provided: the maximum sequence length.
    """

    # Maximum number of tokens per example.
    max_seq_length: int = 64

def decode_token_id(tokenizer, token_id):
    """Return a printable, single-line string for one token id.

    Args:
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``
            (e.g. a Hugging Face tokenizer).
        token_id: A single token id. ``-100`` is treated as the label-mask
            sentinel and is not decoded.

    Returns:
        ``'MASKED'`` for ``-100``. Otherwise, the decoded token text with
        special tokens kept and with carriage returns, newlines and tabs
        escaped, so every token fits on one row of the comparison table.
    """
    if token_id == -100:
        return 'MASKED'
    decoded = tokenizer.decode([token_id], skip_special_tokens=False)
    # Escape control whitespace anywhere in the token, not only an exact "\n".
    # A token may decode to several newlines (e.g. "\n\n") or contain a tab,
    # which would otherwise break the table layout when printed.
    return decoded.replace("\r", "\\r").replace("\n", "\\n").replace("\t", "\\t")

def _first_row(values):
    """Return the first row of a batched sequence, or ``values`` unchanged if it is already flat.

    A sequence counts as batched when its first element itself has a length
    (e.g. a list of lists). A flat sequence of scalar ids is returned as-is.
    """
    if len(values) == 0:
        return values
    try:
        len(values[0])
    except TypeError:
        # First element is a scalar: this is already a single row.
        return values
    return values[0]


def compare_tokenizations():
    """Print a per-token table comparing direct tokenization with ``tokenize_function``.

    Loads the ``google/gemma-2b-it`` tokenizer and renders a two-turn chat with
    ``apply_chat_template``. It then tokenizes the resulting text twice:

    * directly with the tokenizer, truncated/padded to
      ``MockConfig.max_seq_length``;
    * via ``tokenize_function(..., mask_only_assistant_reply=True)``.

    Each row shows the token index, the directly-tokenized token, the
    ``tokenize_function`` token, its label (``MASKED`` for ``-100``), and the
    two attention-mask values. If the sequences differ in length, a warning
    is printed and only the common prefix is shown.
    """
    # Setup
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    config = MockConfig()

    # Create conversation
    messages = [
        {"role": "user", "content": "What is 2+2?"},
        {"role": "assistant", "content": "2+2 equals 4."},
    ]

    # apply_chat_template is called with a batch-style dict (column -> list of rows).
    examples = {"messages": [messages]}
    templated = apply_chat_template(examples, tokenizer)
    text = templated['text'][0]

    print(f"Input text: {repr(text)}\n")

    # Direct tokenization (what Gemma tokenizer produces) of a single string.
    orig_tokenized = tokenizer(text, truncation=True, padding="max_length", max_length=config.max_seq_length)
    orig_input_ids = orig_tokenized['input_ids']
    orig_attention_mask = orig_tokenized['attention_mask']

    # Tokenize with mask_only_assistant_reply
    tokenized_funct = tokenize_function(templated, tokenizer, config, mask_only_assistant_reply=True)

    print(tokenized_funct['input_ids'])

    # Unbatch every field the same way. 'labels' must get the same treatment
    # as 'input_ids'/'attention_mask'; otherwise labels[i] would be a whole row.
    funct_input_ids = _first_row(tokenized_funct['input_ids'])
    funct_attention_mask = _first_row(tokenized_funct['attention_mask'])
    funct_labels = _first_row(tokenized_funct['labels'])

    columns = (orig_input_ids, funct_input_ids, funct_labels, orig_attention_mask, funct_attention_mask)
    if len({len(col) for col in columns}) > 1:
        print(
            "WARNING: sequence lengths differ "
            f"(direct input_ids={len(orig_input_ids)}, direct attention_mask={len(orig_attention_mask)}, "
            f"funct input_ids={len(funct_input_ids)}, funct labels={len(funct_labels)}, "
            f"funct attention_mask={len(funct_attention_mask)}); "
            "only the common prefix is compared.\n"
        )

    print(
        f"{'idx':>3} | {'direct token':>20} | {'funct token':>20} | "
        f"{'funct label':>20} | {'orig attn':>10} | {'funct attn':>10}"
    )

    # Compare each token position side by side.
    for i, (orig_id, funct_id, label, orig_mask, funct_mask) in enumerate(zip(*columns)):
        print(
            f"{i:3d} | {decode_token_id(tokenizer, orig_id):>20} | "
            f"{decode_token_id(tokenizer, funct_id):>20} | "
            f"{decode_token_id(tokenizer, label):>20} | "
            f"{orig_mask:>10} | {funct_mask:>10}"
        )


if __name__ == "__main__":
    compare_tokenizations()