"""
Test script for TwoStageRosetta implementation.
"""

import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..'))

from rosetta.baseline.multi_stage import TwoStageRosetta
import torch


def test_initialization():
    """Check that TwoStageRosetta accepts the expected constructor arguments.

    The checkpoint directory is a dummy path, so actually constructing the
    pipeline is expected to fail while loading models/checkpoints. Because the
    "expected failure" check below is a loose substring heuristic (a
    ``TypeError`` like "unexpected keyword argument 'context_model_path'"
    contains "model"), the keyword arguments are first bound against the
    class signature so that an interface mismatch is reported as a failure
    instead of being mistaken for a loading error.

    Returns:
        bool: True if construction succeeds, or fails with an error that looks
        like a model/checkpoint loading error; False if the arguments do not
        match the constructor signature or any other error occurs.
    """
    import inspect

    print("Testing TwoStageRosetta initialization...")

    init_kwargs = {
        "context_model_path": "Qwen/Qwen3-4B",
        "rosetta_checkpoint_dir": "/tmp/test_checkpoints",  # Dummy checkpoint path
        "rosetta_subfolder": "final",
        "device": "cpu",  # Use CPU for testing
        "max_new_tokens": 512,
    }

    # Validate the call shape without running the constructor, so a signature
    # mismatch cannot be hidden by the substring heuristic further down.
    try:
        inspect.signature(TwoStageRosetta).bind(**init_kwargs)
    except TypeError as e:
        print(f"✗ TwoStageRosetta does not accept the expected arguments: {e}")
        return False

    try:
        # Expected to fail at model loading because the checkpoint path is a dummy.
        TwoStageRosetta(**init_kwargs)
    except Exception as e:
        message = str(e)
        # Heuristic: errors mentioning checkpoints, models or config files are
        # treated as the expected loading failure.
        if (
            "checkpoints_dir" in message
            or "model" in message.lower()
            or "config.json" in message
        ):
            print("✓ TwoStageRosetta initialization structure is correct (expected model loading error)")
            return True
        print(f"✗ Unexpected error during initialization: {e}")
        return False

    print("✓ TwoStageRosetta initialization structure is correct")
    return True


def test_method_signatures():
    """Check that TwoStageRosetta defines every required method.

    Only the presence of each attribute on the class is checked; parameter
    lists are not compared. Every missing method is reported before the
    function returns, rather than stopping at the first one.

    Returns:
        bool: True if all required methods exist on the class, else False.
    """
    print("\nTesting method signatures...")

    required_methods = [
        'generate',
        'process',
        'answer_with_context',
        'get_background_context',
        '_prepare_rosetta_inputs',
    ]

    all_present = True
    for method_name in required_methods:
        if hasattr(TwoStageRosetta, method_name):
            print(f"✓ Method '{method_name}' exists")
        else:
            print(f"✗ Method '{method_name}' missing")
            # Keep checking so one run reports every missing method.
            all_present = False

    return all_present


def test_inheritance():
    """Verify that TwoStageRosetta is a subclass of TwoStageInference.

    Returns:
        bool: the outcome of the ``issubclass`` check; a status line is
        printed in either case.
    """
    print("\nTesting inheritance...")

    from rosetta.baseline.multi_stage import TwoStageInference

    inherits = issubclass(TwoStageRosetta, TwoStageInference)
    if not inherits:
        print("✗ TwoStageRosetta does not inherit from TwoStageInference")
        return False

    print("✓ TwoStageRosetta properly inherits from TwoStageInference")
    return True


def test_rosetta_specific_methods():
    """Check that Rosetta-specific methods are usable on an instance.

    A subclass that bypasses ``TwoStageRosetta.__init__`` is instantiated so
    no models are loaded; it only sets a few plain attributes. The check then
    confirms that ``_prepare_rosetta_inputs`` is present and callable on that
    instance.

    Returns:
        bool: True if ``_prepare_rosetta_inputs`` is present and callable,
        False if it is missing/not callable or the mock cannot be created.
    """
    print("\nTesting Rosetta-specific methods...")

    class MockTwoStageRosetta(TwoStageRosetta):
        def __init__(self):
            # Deliberately skip the parent constructor (it tries to load models).
            self.rosetta_config = {}
            self.eval_config = {}
            self.device = "cpu"
            self.max_new_tokens = 1024
            self.background_prompt = "test"

    try:
        mock_instance = MockTwoStageRosetta()

        # Existence alone is not enough: the attribute must be a callable method.
        if callable(getattr(mock_instance, '_prepare_rosetta_inputs', None)):
            print("✓ _prepare_rosetta_inputs method exists and is callable")
        else:
            print("✗ _prepare_rosetta_inputs method missing or not callable")
            return False

        return True
    except Exception as e:
        print(f"✗ Error testing Rosetta-specific methods: {e}")
        return False


def main():
    """Execute every test function in order and print a summary.

    A test that raises is reported and counted as failed; the remaining
    tests still run.

    Returns:
        bool: True only when every test reported success.
    """
    print("Running TwoStageRosetta tests...")
    print("=" * 50)

    suite = (
        test_inheritance,
        test_method_signatures,
        test_rosetta_specific_methods,
        test_initialization,
    )

    succeeded = 0
    for case in suite:
        try:
            if case():
                succeeded += 1
        except Exception as err:
            # A crashing test counts as a failure but must not abort the run.
            print(f"✗ Test {case.__name__} failed with error: {err}")

    print("\n" + "=" * 50)
    print(f"Test Results: {succeeded}/{len(suite)} tests passed")

    all_ok = succeeded == len(suite)
    if all_ok:
        print("✓ All tests passed! TwoStageRosetta implementation looks good.")
    else:
        print("✗ Some tests failed. Please check the implementation.")

    return all_ok


if __name__ == "__main__":
    # Map the suite result onto a process exit status so CI can detect failures.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)
