"""
Simple test script for TokenAligner functionality.
Tests the alignment between different tokenizers with various strategies.
"""

import torch
import sys
import os
from pathlib import Path

# Put the directory one level above this script (the project root) on the
# import path so ``rosetta`` can be imported straight from a source checkout.
_project_root = os.path.join(os.path.dirname(__file__), os.pardir)
sys.path.append(_project_root)

from transformers import AutoTokenizer
from rosetta.model.aligner import TokenAligner, AlignmentStrategy


# Tokenizer checkpoints for the two sides of the alignment; each can be
# overridden through the environment variable of the same name.
SLM_MODEL_PATH = os.getenv("SLM_MODEL_PATH", "public/public_models/Qwen3-0.6B")
LLM_MODEL_PATH = os.getenv("LLM_MODEL_PATH", "public/public_models/gemma-3-1b-it")

def test_basic_alignment():
    """Align several sample sentences with the FIRST strategy.

    For every sample, prints the SLM ids, the aligned LLM ids, the LLM's own
    tokenization, the 1-to-1 mapping statistics, and a per-position
    SLM -> LLM token table.
    """
    separator = "=" * 80
    print(separator)
    print("Testing Basic Token Alignment")
    print(separator)

    # The two tokenizers whose vocabularies are being aligned.
    slm_tok = AutoTokenizer.from_pretrained(SLM_MODEL_PATH)
    llm_tok = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

    aligner = TokenAligner(
        slm_tokenizer=slm_tok,
        llm_tokenizer=llm_tok,
        strategy=AlignmentStrategy.FIRST,
        verbose=True,
    )

    samples = (
        "Hello world!",
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is transforming the world.",
        "🚀 AI and machine learning! 🤖",  # non-ASCII / emoji input
        "Special characters: @#$%^&*()",
    )

    for case_no, sample in enumerate(samples, start=1):
        print(f"\nTest Case {case_no}: '{sample}'")
        print("-" * 60)

        info = aligner.align_sequence(sample, return_details=True)

        # One summary line per id sequence: label, length, then the ids.
        for label, key in (
            ("SLM tokens", 'slm_token_ids'),
            ("Aligned LLM tokens", 'aligned_llm_token_ids'),
            ("Original LLM tokens", 'original_llm_token_ids'),
        ):
            ids = info[key]
            print(f"{label} ({len(ids)}): {ids}")

        print(
            f"1-to-1 mapping: {info['one_to_one_count']}/{info['num_tokens']} "
            f"({info['one_to_one_rate']:.1%})"
        )

        print("\nToken mapping:")
        pairs = zip(info['slm_token_ids'], info['aligned_llm_token_ids'])
        for pos, (slm_id, llm_id) in enumerate(pairs):
            slm_piece = info['slm_decoded'][pos]
            llm_piece = info['aligned_llm_decoded'][pos]
            print(f"  [{pos:2d}] SLM {slm_id:6d} ('{slm_piece}') -> LLM {llm_id:6d} ('{llm_piece}')")


def test_strategy_comparison():
    """Align one sentence with both FIRST and LONGEST and report divergences.

    Prints each strategy's aligned LLM ids and 1-to-1 rate. If the two
    aligned id lists differ, every position where they disagree is listed
    with its decoded LLM token.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("Testing Strategy Comparison")
    print(rule)

    slm_tok = AutoTokenizer.from_pretrained(SLM_MODEL_PATH)
    llm_tok = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

    def build(strategy):
        # Both aligners share the tokenizers; only the strategy differs.
        return TokenAligner(
            slm_tokenizer=slm_tok,
            llm_tokenizer=llm_tok,
            strategy=strategy,
            verbose=False,
        )

    first_aligner = build(AlignmentStrategy.FIRST)
    longest_aligner = build(AlignmentStrategy.LONGEST)

    sentence = "The future of artificial intelligence is bright!"

    print(f"Test text: '{sentence}'")
    print("-" * 60)

    first_res = first_aligner.align_sequence(sentence, return_details=True)
    longest_res = longest_aligner.align_sequence(sentence, return_details=True)

    # Per-strategy summary; the second section is preceded by a blank line.
    for lead, label, res in (("", "FIRST", first_res), ("\n", "LONGEST", longest_res)):
        print(f"{lead}{label} Strategy:")
        print(f"  Aligned tokens: {res['aligned_llm_token_ids']}")
        print(f"  1-to-1 rate: {res['one_to_one_count']}/{res['num_tokens']} ({res['one_to_one_rate']:.1%})")

    first_ids = first_res['aligned_llm_token_ids']
    longest_ids = longest_res['aligned_llm_token_ids']
    if first_ids != longest_ids:
        print("\n⚠️  Strategies produced different results!")
        print("Differences:")
        for pos, (a, b) in enumerate(zip(first_ids, longest_ids)):
            if a == b:
                continue
            print(
                f"  Position {pos}: FIRST={a}('{first_res['aligned_llm_decoded'][pos]}') "
                f"vs LONGEST={b}('{longest_res['aligned_llm_decoded'][pos]}')"
            )
    else:
        print("\n✅ Both strategies produced identical results")


def test_special_tokens():
    """Align text that embeds chat-template markers and tag SLM special tokens.

    The input contains ``<|im_start|>`` / ``<|im_end|>`` as literal text. Each
    aligned position is printed, suffixed with ``[SPECIAL]`` when the SLM
    token id is one of the SLM tokenizer's special-token ids.
    """
    print("\n" + "=" * 80)
    print("Testing Special Token Handling")
    print("=" * 80)

    # Load tokenizers
    slm_tokenizer = AutoTokenizer.from_pretrained(SLM_MODEL_PATH)
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

    aligner = TokenAligner(
        slm_tokenizer=slm_tokenizer,
        llm_tokenizer=llm_tokenizer,
        strategy=AlignmentStrategy.FIRST,
        verbose=True
    )

    # Presumably these markers are special tokens for the default SLM
    # tokenizer; verify if SLM_MODEL_PATH is overridden.
    test_text = "Hello <|im_start|>user<|im_end|> How are you?"

    print(f"Test text with special tokens: '{test_text}'")
    print("-" * 60)

    details = aligner.align_sequence(test_text, return_details=True)

    print(f"1-to-1 mapping: {details['one_to_one_count']}/{details['num_tokens']} ({details['one_to_one_rate']:.1%})")
    print("Token breakdown:")

    # ``all_special_ids`` is a computed property in transformers (rebuilt on
    # every access), so snapshot it once as a set for O(1) lookups.
    special_ids = set(slm_tokenizer.all_special_ids)

    for i, (slm_id, llm_id) in enumerate(zip(details['slm_token_ids'], details['aligned_llm_token_ids'])):
        slm_str = details['slm_decoded'][i]
        llm_str = details['aligned_llm_decoded'][i]
        # int() normalises the id so set membership also works for 0-d tensors,
        # which hash by identity rather than by value.
        is_special = int(slm_id) in special_ids
        print(f"  [{i:2d}] SLM {slm_id:6d} ('{slm_str}') -> LLM {llm_id:6d} ('{llm_str}') {'[SPECIAL]' if is_special else ''}")


def test_performance():
    """Time repeated ``align_sequence`` calls to gauge the effect of caching.

    The same text is aligned in two consecutive batches of ``iterations``
    calls. If the aligner caches results, only the very first call of the
    first batch is cold. The reported speedup is therefore a rough indicator,
    not a per-call cold-vs-warm ratio.
    """
    import time

    print("\n" + "=" * 80)
    print("Testing Performance and Caching")
    print("=" * 80)

    # Load tokenizers
    slm_tokenizer = AutoTokenizer.from_pretrained(SLM_MODEL_PATH)
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

    aligner = TokenAligner(
        slm_tokenizer=slm_tokenizer,
        llm_tokenizer=llm_tokenizer,
        strategy=AlignmentStrategy.FIRST,
        verbose=False
    )

    test_text = "This is a performance test for token alignment."
    iterations = 10

    # perf_counter() is monotonic and high-resolution. time.time() can be too
    # coarse for short loops, and it jumps when the wall clock is adjusted.
    # First batch: includes the first (potentially cold) call.
    start_time = time.perf_counter()
    for _ in range(iterations):
        aligner.align_sequence(test_text)
    first_run_time = time.perf_counter() - start_time

    # Second batch: every call follows at least one earlier identical call.
    start_time = time.perf_counter()
    for _ in range(iterations):
        aligner.align_sequence(test_text)
    second_run_time = time.perf_counter() - start_time

    print(f"Test text: '{test_text}'")
    print(f"First run ({iterations} iterations): {first_run_time:.4f} seconds")
    print(f"Second run ({iterations} iterations, cached): {second_run_time:.4f} seconds")
    # Guard against ZeroDivisionError when the warm batch is below timer resolution.
    if second_run_time > 0:
        print(f"Cache speedup: {first_run_time / second_run_time:.2f}x")
    else:
        print("Cache speedup: n/a (second run below timer resolution)")


def test_visualization():
    """Render the alignment visualization for the FIRST and LONGEST strategies.

    Both aligners are built first; then each one visualizes the same
    sentence under its own heading.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("Testing Visualization")
    print(rule)

    slm_tok = AutoTokenizer.from_pretrained(SLM_MODEL_PATH)
    llm_tok = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)

    # Construct every aligner before rendering anything.
    labelled_aligners = [
        (
            label,
            TokenAligner(
                slm_tokenizer=slm_tok,
                llm_tokenizer=llm_tok,
                strategy=strategy,
                verbose=True,
            ),
        )
        for label, strategy in (
            ("FIRST", AlignmentStrategy.FIRST),
            ("LONGEST", AlignmentStrategy.LONGEST),
        )
    ]

    sentence = "The quick brown fox jumps over the lazy dog."

    for idx, (label, aligner) in enumerate(labelled_aligners):
        # Every heading after the first is separated by a blank line.
        lead = "\n" if idx else ""
        print(f"{lead}{label} Strategy Visualization:")
        aligner.visualize_alignment(sentence)


def test_messages_mode():
    """Align a short chat conversation in messages mode and inspect the result.

    Calls ``align_chat_messages`` on a system/user/assistant conversation and
    prints:

    * the returned sections;
    * the padded SLM/LLM lengths;
    * the true-counts of the message and padding masks;
    * the message-only text recovered from each padded id sequence;
    * the full padded text for both tokenizers.
    """
    print("\n" + "=" * 80)
    print("Testing Messages Mode Alignment")
    print("=" * 80)

    slm_tokenizer = AutoTokenizer.from_pretrained(SLM_MODEL_PATH)
    llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_PATH)
    aligner = TokenAligner(
        slm_tokenizer=slm_tokenizer,
        llm_tokenizer=llm_tokenizer,
        strategy=AlignmentStrategy.FIRST,
        verbose=False
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain transformers briefly."},
        {"role": "assistant", "content": "Transformers are a type of neural network architecture that are used for natural language processing."}
    ]
    # Keyword spelling (``remove_last_surfix``) must match the
    # TokenAligner.align_chat_messages signature.
    details = aligner.align_chat_messages(
        messages,
        add_generation_prompt=False,
        return_details=True,
        remove_last_surfix=True,
    )

    sections = details.get('sections', [])
    print(f"Sections (type, slm_range, llm_range): {[(s['type'], s['slm_range'], s['llm_range']) for s in sections]}")
    slm_ids_padded = details['slm_ids_padded']
    llm_ids_padded = details['llm_ids_padded']
    print(f"Total lengths -> SLM: {len(slm_ids_padded)}, LLM: {len(llm_ids_padded)}")

    # Message mask: truthy at padded positions used below as message content.
    slm_mask = details.get('message_mask', [])
    print(f"Message mask true count: {sum(1 for v in slm_mask if v)}/{len(slm_mask)}")

    slm_padding_mask = details.get('slm_padding_mask', [])
    print(f"SLM padding mask true count: {sum(1 for v in slm_padding_mask if v)}/{len(slm_padding_mask)}")
    llm_padding_mask = details.get('llm_padding_mask', [])
    print(f"LLM padding mask true count: {sum(1 for v in llm_padding_mask if v)}/{len(llm_padding_mask)}")

    # Decode message-only portion from SLM padded ids
    slm_message_ids = [tid for tid, is_msg in zip(slm_ids_padded, slm_mask) if is_msg]
    slm_message_text = slm_tokenizer.decode(
        slm_message_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False
    )
    print(f"SLM message-only text: '{slm_message_text}'")

    # The same mask is applied to the LLM ids, which presumes both padded
    # sequences have equal length. zip() would silently truncate otherwise,
    # so surface any mismatch instead of hiding it.
    if len(llm_ids_padded) != len(slm_mask):
        print(f"⚠️  Length mismatch: LLM padded ids ({len(llm_ids_padded)}) vs message mask ({len(slm_mask)})")
    llm_message_ids = [tid for tid, is_msg in zip(llm_ids_padded, slm_mask) if is_msg]
    llm_message_text = llm_tokenizer.decode(
        llm_message_ids,
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False
    )
    print(f"LLM message-only text: '{llm_message_text}'")

    # ``skip_special_tokens`` is the decode() keyword; keeping special tokens
    # makes template/padding markers visible in the output.
    padded_slm_text = slm_tokenizer.decode(slm_ids_padded, skip_special_tokens=False)
    padded_llm_text = llm_tokenizer.decode(llm_ids_padded, skip_special_tokens=False)
    print(f"Padded SLM text: '{padded_slm_text}'")
    print(f"Padded LLM text: '{padded_llm_text}'")


def main():
    """Run every test function in order.

    Stops at the first test that raises, prints the error and its traceback,
    and exits with status 1. This lets shells and CI detect the failure;
    merely printing would leave the exit status at 0.
    """
    print("TokenAligner Test Suite")
    print("=" * 80)

    try:
        test_basic_alignment()
        test_strategy_comparison()
        test_special_tokens()
        test_performance()
        test_visualization()
        test_messages_mode()
    except Exception as e:
        # Top-level boundary: report and convert to a non-zero exit status.
        print(f"\n❌ Test failed with error: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    else:
        print("\n" + "=" * 80)
        print("✅ All tests completed successfully!")
        print("=" * 80)


# Run the suite only when executed as a script, not when imported (e.g. by pytest).
if __name__ == "__main__":
    main()
