"""
Simple usage example of MLPProjector for unified memory
"""

import torch
import sys
import os

# Add the project root to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from rosetta.model.projector import create_projector
from transformers import DynamicCache


def test_mlp_projector():
    """Run one forward pass through an MLPProjector and report shapes.

    Builds a 2-layer ``MLPProjector`` (source_dim=768, target_dim=1024,
    hidden_dim=512). It feeds the projector random source/target tensors
    whose leading dims are (batch_size=2, seq_len=8, num_heads=12). It then
    prints the input shapes, the output shape and the total parameter count.
    Nothing is asserted; the function only exercises the forward pass.
    """
    banner = "=" * 40
    print("Testing MLPProjector")
    print(banner)

    src_dim, tgt_dim = 768, 1024
    projector = create_projector(
        "MLPProjector",
        source_dim=src_dim,
        target_dim=tgt_dim,
        hidden_dim=512,
        num_layers=2,
    )

    # Leading dims shared by both inputs: (batch_size, seq_len, num_heads).
    lead_shape = (2, 8, 12)
    source_tensors = torch.randn(*lead_shape, src_dim)
    target_tensors = torch.randn(*lead_shape, tgt_dim)

    for label, tensor in (("Source", source_tensors), ("Target", target_tensors)):
        print(f"{label} shape: {tensor.shape}")

    # Inference only: no autograd graph is needed for a shape report.
    with torch.no_grad():
        output = projector(source_tensors, target_tensors)

    print(f"Output shape: {output.shape}")
    param_count = sum(param.numel() for param in projector.parameters())
    print(f"Parameters: {param_count:,}")
    print()


def test_kv_cache():
    """Project a DynamicCache through an MLPProjector and report shapes.

    Uses an ``MLPProjector`` with source_dim == target_dim == 64. It fills a
    source and a target ``DynamicCache`` with one random key/value pair at
    layer index 0, of shape (batch_size=2, num_heads=8, seq_len=10,
    head_dim=64). It then calls ``projector.cache_project`` and prints the
    layer-0 key/value shapes before and after projection.
    """
    divider = "=" * 40
    print("Testing KV Cache Projection")
    print(divider)

    proj_dim = 64
    projector = create_projector(
        "MLPProjector",
        source_dim=proj_dim,
        target_dim=proj_dim,
        hidden_dim=128,
        num_layers=2,
    )

    # (batch_size, num_heads, seq_len, head_dim)
    kv_shape = (2, 8, 10, 64)

    def _random_cache():
        # Fresh cache holding one random key and value tensor at layer 0.
        cache = DynamicCache()
        k = torch.randn(kv_shape)
        v = torch.randn(kv_shape)
        cache.update(k, v, 0)
        return cache

    # Source is built first, then target, keeping the random draw order fixed.
    source_cache = _random_cache()
    target_cache = _random_cache()

    print(f"Cache key shape: {source_cache.key_cache[0].shape}")
    print(f"Cache value shape: {source_cache.value_cache[0].shape}")

    with torch.no_grad():
        projected_cache = projector.cache_project(source_cache, target_cache)

    print(f"Projected key shape: {projected_cache.key_cache[0].shape}")
    print(f"Projected value shape: {projected_cache.value_cache[0].shape}")
    print()


def main():
    """Print a header, run both projector demos in order, then report success.

    Any exception raised by a demo propagates, so the success line is only
    printed when both demos finish.
    """
    for header_line in ("Simple MLPProjector Usage Test", "=" * 50, ""):
        print(header_line)

    for demo in (test_mlp_projector, test_kv_cache):
        demo()

    print("✅ All tests completed successfully!")


if __name__ == "__main__":
    main()