#!/usr/bin/env python3
"""
Test script for RAM++ ADE20K model

This script tests the model initialization and basic functionality
without requiring a full training environment.
"""

import sys
import os

# Put this script's directory and its parent on the import path so the
# project-local ``models`` / ``training`` packages used below can be found
# when the script is launched directly (presumably from the repo tree).
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(_SCRIPT_DIR)
sys.path.append(os.path.dirname(_SCRIPT_DIR))

def test_model_import():
    """Smoke-check that the RAM++ ADE20K model module can be imported.

    Both ``RAM_plus_ADE20K`` and ``load_ram_plus_ade20k_pretrained`` are
    requested, so a missing name is reported the same way as a missing module.

    Returns:
        bool: True when the import succeeds, False otherwise (error printed).
    """
    outcome = False
    try:
        from models.ram_plus_ade20k import (  # noqa: F401 -- importing is the test
            RAM_plus_ADE20K,
            load_ram_plus_ade20k_pretrained,
        )
        print("✅ Successfully imported RAM++ ADE20K model")
        outcome = True
    except Exception as exc:
        print(f"❌ Failed to import model: {exc}")
    return outcome

def test_model_creation():
    """Build the ADE20K model without a RAM++ backbone and report its layout.

    Every attribute that gets reported is read before anything is printed.
    A model missing ``num_classes`` or a sized ``adapter`` therefore produces
    only a failure line, not a success line followed by a failure line.

    Returns:
        bool: True if the model was built and inspected, False otherwise
        (the error is printed).
    """
    try:
        from models.ram_plus_ade20k import RAM_plus_ADE20K

        # Create model without RAM++ backbone for testing
        model = RAM_plus_ADE20K(ram_plus_model=None)
        num_classes = model.num_classes
        # len() assumes the adapter is a sized container of layers
        # (e.g. nn.Sequential) -- TODO confirm against the model definition.
        num_adapter_layers = len(model.adapter)

        print("✅ Model created successfully")
        print(f"   - Number of ADE20K classes: {num_classes}")
        print(f"   - Adapter architecture: {num_adapter_layers} layers")
        return True
    except Exception as e:
        print(f"❌ Failed to create model: {e}")
        return False

def test_adapter():
    """Feed dummy RAM++ logits through the adapter and check the output shape.

    The shape is verified before any success message is printed. The check
    raises explicitly instead of using ``assert``, so it still runs under
    ``python -O``.

    Returns:
        bool: True if the adapter maps a (2, 4584) input to a (2, 150)
        output, False otherwise (the reason is printed).
    """
    try:
        import torch
        from models.ram_plus_ade20k import RAM_plus_ADE20K

        model = RAM_plus_ADE20K(ram_plus_model=None)

        batch_size = 2
        # NOTE(review): 4584 is assumed to be the width of the RAM++ logits
        # the adapter consumes -- confirm against the model definition.
        dummy_ram_logits = torch.randn(batch_size, 4584)
        # 150 = number of ADE20K classes the adapter is expected to emit.
        expected_shape = (batch_size, 150)

        # Only the output shape matters here, so skip autograd bookkeeping.
        with torch.no_grad():
            ade20k_logits = model.adapter(dummy_ram_logits)

        # Explicit check (not `assert`) so it is not stripped by `python -O`.
        if ade20k_logits.shape != expected_shape:
            raise AssertionError(f"Wrong output shape: {ade20k_logits.shape}")

        print("✅ Adapter test passed")
        print(f"   - Input shape: {dummy_ram_logits.shape}")
        print(f"   - Output shape: {ade20k_logits.shape}")
        print(f"   - Expected output shape: {expected_shape}")
        return True
    except Exception as e:
        print(f"❌ Adapter test failed: {e}")
        return False

def test_training_script():
    """Check that the ADE20K training module imports cleanly.

    Requesting ``calculate_ade20k_metrics`` makes Python execute the whole
    ``training.train_ade20k`` module, including its own imports.

    Returns:
        bool: True on a clean import, False otherwise (error printed).
    """
    try:
        from training.train_ade20k import calculate_ade20k_metrics  # noqa: F401
        print("✅ Successfully imported training script")
    except Exception as exc:
        print(f"❌ Failed to import training script: {exc}")
        return False
    return True

def main():
    """Run every smoke test, print a summary, and report overall success.

    Each test function returns a truthy value on success. The result is
    normalised with ``bool()``, so a test that returns ``None`` counts as a
    failure instead of crashing the ``sum()`` below. An exception escaping a
    test is caught here and also counted as a failure, so the remaining tests
    still run.

    Returns:
        bool: True only if every test passed.
    """
    print("🔍 Testing RAM++ ADE20K implementation...")
    print("=" * 50)

    tests = [
        ("Model Import", test_model_import),
        ("Model Creation", test_model_creation),
        ("Adapter Functionality", test_adapter),
        ("Training Script Import", test_training_script),
    ]

    results = []
    for test_name, test_func in tests:
        print(f"\n📋 Running: {test_name}")
        try:
            passed = bool(test_func())
        except Exception as e:
            print(f"❌ {test_name} failed with exception: {e}")
            passed = False
        results.append(passed)

    print("\n" + "=" * 50)
    print("🎯 Test Results Summary:")
    for (test_name, _), passed in zip(tests, results):
        status = "✅ PASS" if passed else "❌ FAIL"
        print(f"   {test_name}: {status}")

    total_passed = sum(results)
    all_passed = total_passed == len(tests)
    print(f"\n📊 Overall: {total_passed}/{len(tests)} tests passed")

    if all_passed:
        print("🎉 All tests passed! Your implementation is ready.")
    else:
        print("⚠️  Some tests failed. Check the error messages above.")

    return all_passed

if __name__ == "__main__":
    # Translate the boolean outcome into a process exit status for shells/CI:
    # 0 when every test passed, 1 otherwise.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)