#!/usr/bin/env python3
"""
Sanity-check script for the TorchTitan DCP -> HuggingFace conversion.
It exercises the conversion functions on a small synthetic state dict (without
loading the full checkpoint) and inspects the checkpoint directory.
"""

import sys
import os
from pathlib import Path

# Add this script's own directory (not the current working directory) to
# sys.path so the sibling module convert_dcp_to_hf can be imported.
sys.path.append(str(Path(__file__).parent))

# Import the functions under test; exit with status 1 if the conversion
# module cannot be imported, since none of the tests below can run without it.
try:
    from convert_dcp_to_hf import create_hf_config_from_torchtitan, convert_torchtitan_to_hf_state_dict
    print("✓ Successfully imported conversion functions")
except ImportError as e:
    print(f"✗ Failed to import conversion functions: {e}")
    sys.exit(1)

# Third-party dependencies; exit with status 1 and an install hint if missing.
# NOTE(review): `dcp` and `LlamaConfig` are not referenced elsewhere in this
# file — presumably imported only to confirm the packages are installed.
try:
    import torch
    import torch.distributed.checkpoint as dcp
    from transformers import LlamaConfig, LlamaForCausalLM
    print("✓ All required dependencies are available")
except ImportError as e:
    print(f"✗ Missing dependency: {e}")
    print("Please install: pip install transformers torch")
    sys.exit(1)

def create_dummy_state_dict(
    *,
    n_layers=4,
    hidden_size=512,
    vocab_size=1000,
    intermediate_size=1376,
    max_seq_len=1024,
    seed=None,
):
    """Create a dummy TorchTitan state dict for testing.

    Keys follow the TorchTitan Llama naming used by the converter
    (``tok_embeddings``, ``layers.{i}.attention.wq``, ``feed_forward.w1``, ...).
    All tensors are random. The key/value projections are square, i.e. the
    same width as the query projection.

    Args:
        n_layers: Number of transformer layers to emit.
        hidden_size: Model (embedding) dimension.
        vocab_size: Rows of the embedding and output matrices.
        intermediate_size: Hidden width of the feed-forward block
            (default 1376, roughly 8/3 * 512).
        max_seq_len: Rows of the complex ``freqs_cis`` buffer.
        seed: If given, tensors are drawn from a private ``torch.Generator``
            seeded with this value, so repeated calls return identical data.
            ``None`` (default) uses the global RNG, as before.

    Returns:
        dict mapping TorchTitan parameter names to tensors.
    """
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    def rand(*shape, dtype=None):
        # dtype=None keeps torch's default floating dtype, matching plain randn().
        return torch.randn(*shape, generator=generator, dtype=dtype)

    state_dict = {}

    # Embedding and output
    state_dict["tok_embeddings.weight"] = rand(vocab_size, hidden_size)
    state_dict["output.weight"] = rand(vocab_size, hidden_size)
    state_dict["norm.weight"] = rand(hidden_size)

    # Frequency buffer (will be skipped in conversion)
    state_dict["freqs_cis"] = rand(max_seq_len, hidden_size // 8, dtype=torch.complex64)

    # Transformer layers
    for i in range(n_layers):
        prefix = f"layers.{i}"

        # Attention weights
        for proj in ("wq", "wk", "wv", "wo"):
            state_dict[f"{prefix}.attention.{proj}.weight"] = rand(hidden_size, hidden_size)

        # FFN weights: w1/w3 map hidden -> intermediate, w2 maps back.
        state_dict[f"{prefix}.feed_forward.w1.weight"] = rand(intermediate_size, hidden_size)
        state_dict[f"{prefix}.feed_forward.w2.weight"] = rand(hidden_size, intermediate_size)
        state_dict[f"{prefix}.feed_forward.w3.weight"] = rand(intermediate_size, hidden_size)

        # Layer norms
        state_dict[f"{prefix}.attention_norm.weight"] = rand(hidden_size)
        state_dict[f"{prefix}.ffn_norm.weight"] = rand(hidden_size)

    return state_dict

def test_config_creation():
    """Build an HF config from a dummy TorchTitan state dict and print its key fields.

    Returns:
        True if ``create_hf_config_from_torchtitan`` succeeds and the fields
        can be printed; False (after printing the error) otherwise.
    """
    print("\n=== Testing Config Creation ===")

    state = create_dummy_state_dict()

    try:
        cfg = create_hf_config_from_torchtitan(state)
        print("✓ Config created successfully")
        for field in ("vocab_size", "hidden_size", "num_hidden_layers", "num_attention_heads"):
            print(f"  - {field}: {getattr(cfg, field)}")
        return True
    except Exception as err:
        print(f"✗ Failed to create config: {err}")
        return False

def test_state_dict_conversion():
    """Convert a dummy TorchTitan state dict and check the resulting mapping.

    Checks that a few representative TorchTitan keys map to the expected
    HuggingFace keys with unchanged tensor shapes, and that the ``freqs_cis``
    buffer does not appear in the converted dict under any key name.

    Returns:
        True if every check passes, False otherwise (the reason is printed).
    """
    print("\n=== Testing State Dict Conversion ===")

    dummy_state_dict = create_dummy_state_dict()

    try:
        hf_state_dict = convert_torchtitan_to_hf_state_dict(dummy_state_dict)
        print("✓ State dict converted successfully")
        print(f"  - Original keys: {len(dummy_state_dict)}")
        print(f"  - Converted keys: {len(hf_state_dict)}")

        # Check some key mappings
        expected_mappings = [
            ("tok_embeddings.weight", "model.embed_tokens.weight"),
            ("layers.0.attention.wq.weight", "model.layers.0.self_attn.q_proj.weight"),
            ("norm.weight", "model.norm.weight"),
            ("output.weight", "lm_head.weight"),
        ]

        for tt_key, hf_key in expected_mappings:
            if tt_key not in dummy_state_dict or hf_key not in hf_state_dict:
                print(f"  ✗ Missing mapping: {tt_key} -> {hf_key}")
                return False
            # Values are not compared: the converter may legitimately reorder
            # rows (e.g. of q/k projections) — TODO confirm. The shape must be
            # preserved either way.
            src_shape = tuple(dummy_state_dict[tt_key].shape)
            dst_shape = tuple(hf_state_dict[hf_key].shape)
            if src_shape != dst_shape:
                print(f"  ✗ Shape mismatch: {tt_key} {src_shape} -> {hf_key} {dst_shape}")
                return False
            print(f"  ✓ {tt_key} -> {hf_key}")

        # freqs_cis must be dropped entirely, not just renamed.
        leaked = [key for key in hf_state_dict if "freqs_cis" in key]
        if leaked:
            print(f"  ✗ freqs_cis buffer should have been skipped (found: {leaked})")
            return False
        print("  ✓ freqs_cis buffer correctly skipped")

        return True
    except Exception as e:
        print(f"✗ Failed to convert state dict: {e}")
        return False

def test_hf_model_creation():
    """Load converted weights into a HuggingFace Llama model and run a forward pass.

    The converted state dict is loaded with ``strict=False`` so the key
    mismatch can be inspected instead of raising. The test fails in three cases:

    * a model weight was left unloaded, i.e. still randomly initialized.
      Rotary ``inv_freq`` buffers are the exception: they are derived from
      the config rather than learned.
    * the converter produced keys the model does not know.
    * the logits are not shaped ``(batch, seq_len, vocab_size)``.

    Returns:
        True on success, False otherwise (the reason is printed).
    """
    print("\n=== Testing HF Model Creation ===")

    dummy_state_dict = create_dummy_state_dict()

    try:
        # Create config
        config = create_hf_config_from_torchtitan(dummy_state_dict)

        # Convert state dict
        hf_state_dict = convert_torchtitan_to_hf_state_dict(dummy_state_dict)

        # Create HF model (inference mode: disables any dropout)
        model = LlamaForCausalLM(config)
        model.eval()

        # strict=False reports key mismatches instead of raising; tensor shape
        # mismatches still raise and are caught below.
        missing_keys, unexpected_keys = model.load_state_dict(hf_state_dict, strict=False)

        # Some transformers versions store rotary inv_freq as a persistent
        # buffer, so it shows up as "missing" even for a complete conversion.
        missing_weights = [k for k in missing_keys if not k.endswith("rotary_emb.inv_freq")]
        if missing_weights:
            print(f"✗ {len(missing_weights)} model weights were not loaded, e.g. {missing_weights[:5]}")
            return False
        if unexpected_keys:
            print(f"✗ {len(unexpected_keys)} converted keys are unknown to the model, e.g. {list(unexpected_keys)[:5]}")
            return False

        print("✓ HuggingFace model created and loaded successfully")
        if missing_keys:
            print(f"  - Ignored {len(missing_keys)} missing rotary inv_freq buffers")

        # Test forward pass
        batch_size, seq_len = 1, 10
        input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
        with torch.no_grad():
            output = model(input_ids)

        expected_shape = (batch_size, seq_len, config.vocab_size)
        actual_shape = tuple(output.logits.shape)
        if actual_shape != expected_shape:
            print(f"  ✗ Unexpected logits shape {actual_shape}, expected {expected_shape}")
            return False
        print(f"  ✓ Forward pass successful, output shape: {output.logits.shape}")

        return True
    except Exception as e:
        print(f"✗ Failed to create HF model: {e}")
        return False

def test_checkpoint_path_validation(checkpoint_path="/data"):
    """Check that ``checkpoint_path`` looks like a distributed-checkpoint directory.

    The directory must contain a ``.metadata`` file and at least one
    ``*.distcp`` shard file.

    Args:
        checkpoint_path: Directory to inspect. Defaults to ``"/data"``, the
            location this script originally hard-coded.

    Returns:
        True only if the path is a directory containing both the metadata file
        and at least one ``.distcp`` file; False otherwise (reason printed).
    """
    print("\n=== Testing Checkpoint Path ===")

    if not os.path.exists(checkpoint_path):
        print(f"✗ Checkpoint path does not exist: {checkpoint_path}")
        print("  Please update the path in example_convert.sh")
        return False
    # A regular file would make os.listdir below raise.
    if not os.path.isdir(checkpoint_path):
        print(f"✗ Checkpoint path is not a directory: {checkpoint_path}")
        print("  Please update the path in example_convert.sh")
        return False

    print(f"✓ Checkpoint path exists: {checkpoint_path}")
    ok = True

    # Check for .distcp files
    distcp_files = [f for f in os.listdir(checkpoint_path) if f.endswith('.distcp')]
    print(f"  - Found {len(distcp_files)} .distcp files")
    if not distcp_files:
        print("  ✗ No .distcp files found")
        ok = False

    # Check for metadata
    metadata_path = os.path.join(checkpoint_path, ".metadata")
    if os.path.isfile(metadata_path):
        print("  ✓ Metadata file exists")
    else:
        print("  ✗ Metadata file missing")
        ok = False

    return ok

def main():
    """Run every test in order, print a summary, and return a process exit code.

    A test that raises instead of returning False is reported and counted as
    a failure.

    Returns:
        0 if all tests passed, 1 otherwise.
    """
    print("TorchTitan to HuggingFace Conversion Test")
    print("=" * 50)

    suite = (
        test_config_creation,
        test_state_dict_conversion,
        test_hf_model_creation,
        test_checkpoint_path_validation,
    )

    outcomes = []
    for run_test in suite:
        try:
            outcomes.append(run_test())
        except Exception as exc:
            print(f"✗ Test failed with exception: {exc}")
            outcomes.append(False)

    print("\n" + "=" * 50)
    print("Test Summary:")
    print(f"Passed: {sum(outcomes)}/{len(outcomes)}")

    if not all(outcomes):
        print("✗ Some tests failed. Please check the errors above.")
        return 1

    print("✓ All tests passed! You can now run the conversion.")
    print("\nTo convert your checkpoint, run:")
    print("  cd LongSee/torchtitan_longsee/convert_to_HF")
    print("  ./example_convert.sh")
    return 0

if __name__ == "__main__":
    # Propagate main()'s return value (0 = success, 1 = failure) as the exit status.
    raise SystemExit(main())