#!/usr/bin/env python3
"""
HEdit Installation Validation Script

This script validates that HEdit is properly installed and all components work.
Run this after installation to ensure everything is set up correctly.
"""

import sys
import importlib.util

def check_python_version():
    """Report whether the running interpreter meets the Python 3.8 minimum.

    Prints a status line describing the detected version.

    Returns:
        True when ``sys.version_info`` is at least (3, 8), otherwise False.
    """
    print("Checking Python version...")
    major, minor, micro = sys.version_info[:3]
    # Tuple comparison is equivalent to "major > 3 or (major == 3 and minor >= 8)".
    if (major, minor) >= (3, 8):
        print(f"✅ Python {major}.{minor}.{micro} (OK)")
        return True
    print(f"❌ Python {major}.{minor} detected. Python >= 3.8 required.")
    return False

def check_dependency(name, package=None):
    """Report whether the module backing a dependency can be located.

    Args:
        name: Human-readable dependency name used in the status line.
        package: Import name to look up (e.g. ``"sklearn"`` for
            scikit-learn). Defaults to *name* when omitted or empty.

    Returns:
        True if ``importlib.util.find_spec`` locates the module, else False.
    """
    package = package or name
    try:
        spec = importlib.util.find_spec(package)
    except (ImportError, ValueError):
        # For a dotted name, find_spec imports the parent package first and
        # raises ModuleNotFoundError (an ImportError) if it is missing. It
        # raises ValueError for an already-imported module whose __spec__ is
        # None. Both mean "not usable", so report rather than crash.
        spec = None
    if spec is None:
        print(f"❌ {name} not found")
        return False
    print(f"✅ {name} installed")
    return True

def check_dependencies():
    """Check every required third-party package via ``check_dependency``.

    Every package is checked (no short-circuit), so a single run lists all
    missing dependencies.

    Returns:
        True only if every package was found, otherwise False.
    """
    print("\nChecking dependencies...")
    required = (
        ("PyTorch", "torch"),
        ("Transformers", "transformers"),
        ("NumPy", "numpy"),
        ("SciPy", "scipy"),
        ("scikit-learn", "sklearn"),
        ("tqdm", "tqdm"),
        ("matplotlib", "matplotlib"),
    )
    # Build a full list (not a generator) so every check runs and prints
    # before all() inspects the outcomes.
    outcomes = [check_dependency(label, module) for label, module in required]
    return all(outcomes)

def check_hedit_package():
    """Check that the ``hedit`` package imports and exposes its public classes.

    Prints one status line for the package and one per expected class. Every
    class is checked before returning, so all missing names are reported in
    a single run.

    Returns:
        True if ``hedit`` imports and every expected attribute exists and is
        not None, otherwise False.
    """
    print("\nChecking HEdit package...")
    expected = (
        "AnchorTokenDetector",
        "TriggerTokenDetector",
        "DeltaKVPredictor",
        "KVCorrectionMLP",
        "MLPTrainer",
    )
    try:
        import hedit

        # A missing __version__ should not abort the whole check.
        version = getattr(hedit, "__version__", "unknown")
        print(f"✅ HEdit package found (version {version})")

        all_ok = True
        for name in expected:
            # getattr with a default turns a missing attribute into a reported
            # failure instead of an AttributeError escaping this function.
            if getattr(hedit, name, None) is not None:
                print(f"  ✅ {name} available")
            else:
                print(f"  ❌ {name} not available")
                all_ok = False

        return all_ok

    except ImportError as e:
        print(f"❌ Cannot import HEdit: {e}")
        print("\nTry running: pip install -e .")
        return False

def check_cuda():
    """Report CUDA availability through PyTorch.

    When CUDA is absent, the check prints a CPU-only warning and still
    passes. The check fails only if importing or querying ``torch`` raises.

    Returns:
        True unless an exception occurred while checking, otherwise False.
    """
    print("\nChecking CUDA...")
    try:
        import torch

        if not torch.cuda.is_available():
            print("⚠️  CUDA not available (CPU-only mode)")
            return True
        gpu_name = torch.cuda.get_device_name(0)
        print(f"✅ CUDA available (GPU: {gpu_name})")
        print(f"   CUDA version: {torch.version.cuda}")
        return True
    except Exception as e:
        print(f"❌ Error checking CUDA: {e}")
        return False

def check_files(root=None):
    """Check that the expected project files are present.

    Args:
        root: Directory the relative paths below are resolved against.
            Defaults to the current working directory, so without it the
            script must be run from the project root.

    Returns:
        True if every expected path exists and is a regular file, otherwise
        False. Every path is checked and reported, even after a miss.
    """
    print("\nChecking important files...")
    from pathlib import Path

    # Path() is the current directory, so Path() / "x" == Path("x"): the
    # default keeps the original cwd-relative behavior.
    base = Path(root) if root is not None else Path()

    files = [
        "README.md",
        "requirements.txt",
        "setup.py",
        "hedit/__init__.py",
        "hedit/anchor_detector.py",
        "hedit/trigger_detector.py",
        "hedit/kv_predictor.py",
        "hedit/trainer.py",
        "examples/demo_training.py",
        "examples/demo_inference.py",
    ]

    all_ok = True
    for file in files:
        # is_file() rather than exists(): a directory at one of these paths
        # does not satisfy the check.
        if (base / file).is_file():
            print(f"  ✅ {file}")
        else:
            print(f"  ❌ {file} not found")
            all_ok = False

    return all_ok

def run_basic_test():
    """Smoke-test a tiny ``DeltaKVPredictor`` forward pass.

    Builds ``DeltaKVPredictor(hidden_dim=64, kv_dim=16, mlp_hidden_dim=32)``,
    calls it on random tensors of shapes (2, 64), (2, 16) and (2, 16), and
    expects two outputs each of shape (2, 16). Any exception, including a
    failed import, is reported and treated as a failure.

    Returns:
        True if both outputs have the expected shape, otherwise False.
    """
    print("\nRunning basic functionality test...")
    batch, hidden, kv = 2, 64, 16
    try:
        import torch
        from hedit import DeltaKVPredictor

        predictor = DeltaKVPredictor(hidden_dim=hidden, kv_dim=kv, mlp_hidden_dim=32)

        # Inputs in positional order: hidden state, anchor K, anchor V.
        inputs = (
            torch.randn(batch, hidden),
            torch.randn(batch, kv),
            torch.randn(batch, kv),
        )
        delta_k, delta_v = predictor(*inputs)

        expected_shape = (batch, kv)
        if not (delta_k.shape == expected_shape and delta_v.shape == expected_shape):
            print("❌ Basic functionality test failed: incorrect output shapes")
            return False
        print("✅ Basic functionality test passed")
        return True

    except Exception as e:
        print(f"❌ Basic functionality test failed: {e}")
        return False

def main():
    """Run every validation check, print a summary, and return an exit code.

    Each check is isolated: an exception escaping one check is printed and
    recorded as a failure, and the remaining checks still run.

    Returns:
        0 if every check passed, 1 otherwise (passed to ``sys.exit``).
    """
    banner = "=" * 60
    print(banner)
    print("HEdit Installation Validation")
    print(banner)

    checks = (
        ("Python version", check_python_version),
        ("Dependencies", check_dependencies),
        ("HEdit package", check_hedit_package),
        ("CUDA", check_cuda),
        ("Project files", check_files),
        ("Basic functionality", run_basic_test),
    )

    results = []
    for label, check in checks:
        try:
            outcome = check()
        except Exception as e:
            print(f"❌ Error during {label} check: {e}")
            outcome = False
        results.append((label, outcome))

    # Summary
    print("\n" + banner)
    print("Validation Summary")
    print(banner)

    for label, outcome in results:
        verdict = "✅ PASS" if outcome else "❌ FAIL"
        print(f"{verdict}: {label}")

    passed = sum(bool(outcome) for _, outcome in results)
    total = len(results)
    print(f"\nTotal: {passed}/{total} checks passed")

    if passed != total:
        print("\n⚠️  Some checks failed. Please review the output above.")
        return 1
    print("\n🎉 All checks passed! HEdit is ready to use.")
    return 0

if __name__ == "__main__":
    # Use main()'s 0/1 result as the process exit status.
    raise SystemExit(main())
