#!/usr/bin/env python3
"""
Simple test script to verify NPT implementation works
"""

import os
import sys
import traceback

import torch
from yacs.config import CfgNode as CN

# Test basic imports
def test_imports():
    """Smoke-check that the NPT model, configuration and trainer modules import.

    The three import groups are attempted in order. The first group that raises
    is reported with its error message and stops the check.

    Returns:
        bool: True when every group imports cleanly, False at the first error.
    """
    print("Testing imports...")

    def _load_npt_models():
        from npt_models import NPTCustomCLIP, NPTPromptLearner, extract_attention_weights

    def _load_config():
        from configs.implemention import get_cfg_default

    def _load_trainers():
        from trainers import build_optimizer, build_lr_scheduler

    # (loader, line printed on success, prefix printed before the error text)
    import_checks = [
        (_load_npt_models, "✓ NPT models imported successfully", "✗ NPT import error: "),
        (_load_config, "✓ Configuration imports successful", "✗ Config import error: "),
        (_load_trainers, "✓ Trainer imports successful", "✗ Trainer import error: "),
    ]
    for loader, success_msg, failure_prefix in import_checks:
        try:
            loader()
            print(success_msg)
        except Exception as err:
            print(failure_prefix + str(err))
            return False
    return True

def _ensure_cfg_node(parent, key):
    """Return ``parent[key]``, first creating an empty ``CN`` there if it is missing.

    Use this instead of assigning a fresh ``CN()``. That way, any defaults
    already stored under ``key`` by the default config are preserved.
    """
    if key not in parent:
        parent[key] = CN()
    return parent[key]


def test_npt_prompt_learner():
    """Build an NPTPromptLearner on a CLIP ViT-B/16 backbone and run its forward pass.

    Three dummy class names are used. The forward output is expected to contain
    one prompt per class plus one extra nuisance prompt, so its leading
    dimension must equal ``len(classnames) + 1``. A mismatch counts as a
    failure. The class-prompt and nuisance-prompt accessors are also called,
    and their shapes are printed.

    Returns:
        bool: True if construction, the forward pass and the shape check all
        succeed. False otherwise, after printing the error and its traceback.
    """
    print("\nTesting NPT PromptLearner...")
    try:
        from npt_models import NPTPromptLearner
        from utils.train_util import load_clip_to_cpu
        from configs.implemention import get_cfg_default

        # Minimal config. Only the keys this test needs are set, and existing
        # nodes are reused rather than replaced, so other defaults from
        # get_cfg_default() are kept.
        cfg = get_cfg_default()
        locoop_cfg = _ensure_cfg_node(_ensure_cfg_node(cfg, "TRAINER"), "LOCOOP")
        locoop_cfg.N_CTX = 16
        locoop_cfg.CSC = False
        locoop_cfg.CTX_INIT = ""
        locoop_cfg.CLASS_TOKEN_POSITION = "end"
        _ensure_cfg_node(cfg, "INPUT").SIZE = [224, 224]
        _ensure_cfg_node(_ensure_cfg_node(cfg, "MODEL"), "BACKBONE").NAME = "ViT-B/16"

        # Load CLIP model
        clip_model = load_clip_to_cpu(cfg)

        # Test NPT PromptLearner
        classnames = ["class1", "class2", "class3"]  # dummy classes
        prompt_learner = NPTPromptLearner(cfg, classnames, clip_model)

        # Forward pass: one prompt per class plus the nuisance prompt.
        prompts = prompt_learner.forward()
        print(f"✓ NPT PromptLearner created prompts with shape: {prompts.shape}")
        expected_rows = len(classnames) + 1
        if prompts.shape[0] != expected_rows:
            print(f"✗ Expected {expected_rows} prompts (classes + nuisance), got {prompts.shape[0]}")
            return False
        print(f"✓ Expected shape: ({len(classnames) + 1}, seq_len, dim) - includes nuisance prompt")

        # Test individual methods
        class_prompts = prompt_learner.get_class_prompts()
        nuisance_prompt = prompt_learner.get_nuisance_prompt()

        print(f"✓ Class prompts shape: {class_prompts.shape}")
        print(f"✓ Nuisance prompt shape: {nuisance_prompt.shape}")

        return True

    except Exception as e:
        print(f"✗ NPT PromptLearner test failed: {e}")
        # The one-line message alone rarely pinpoints where model setup broke.
        traceback.print_exc()
        return False

def test_attention_extraction():
    """Run extract_attention_weights on a random batch through CLIP ViT-B/16's visual tower.

    The returned weights are expected to be 2-D with one row per image,
    i.e. ``(batch_size, num_patches)``. Any other shape counts as a failure.
    The min/max of the weights are printed for inspection.

    Returns:
        bool: True if the model loads, extraction runs and the shape check
        passes. False otherwise, after printing the error and its traceback.
    """
    print("\nTesting attention extraction...")
    try:
        from npt_models import extract_attention_weights
        from utils.train_util import load_clip_to_cpu
        from configs.implemention import get_cfg_default

        # Set only the backbone name. Assigning fresh CN() nodes would discard
        # every other MODEL default that get_cfg_default() provides.
        cfg = get_cfg_default()
        if "MODEL" not in cfg:
            cfg.MODEL = CN()
        if "BACKBONE" not in cfg.MODEL:
            cfg.MODEL.BACKBONE = CN()
        cfg.MODEL.BACKBONE.NAME = "ViT-B/16"

        # Load CLIP model
        clip_model = load_clip_to_cpu(cfg)

        # CLIP's encode_image casts its input to the model dtype before calling
        # the visual tower. Do the same here, because clip_model.visual is
        # called directly and the loaded weights may be fp16. Fall back to
        # float32 (the original behaviour) if the model exposes no ``dtype``.
        batch_size = 2
        model_dtype = getattr(clip_model, "dtype", torch.float32)
        image_input = torch.randn(batch_size, 3, 224, 224).to(model_dtype)

        # Test attention extraction
        background_weights = extract_attention_weights(clip_model.visual, image_input)
        print(f"✓ Attention extraction successful, weights shape: {background_weights.shape}")
        if background_weights.dim() != 2 or background_weights.shape[0] != batch_size:
            print(f"✗ Expected shape: ({batch_size}, num_patches), got {tuple(background_weights.shape)}")
            return False
        print(f"✓ Expected shape: ({batch_size}, num_patches)")
        print(f"✓ Weights range: [{background_weights.min().item():.3f}, {background_weights.max().item():.3f}]")

        return True

    except Exception as e:
        print(f"✗ Attention extraction test failed: {e}")
        # The one-line message alone rarely pinpoints where the forward pass broke.
        traceback.print_exc()
        return False

def main():
    """Run the NPT smoke tests in order, stopping at the first failing stage.

    Stages: imports, then the NPT prompt learner, then attention extraction.
    The banner for a failing stage is printed before returning.

    Returns:
        bool: True when all stages pass, False as soon as one fails.
    """
    print("NPT Implementation Test")
    print("=" * 30)

    # Ordered (stage, banner printed if that stage fails) pairs.
    stages = (
        (test_imports, "\n❌ Import tests failed. Check dependencies."),
        (test_npt_prompt_learner, "\n❌ NPT PromptLearner tests failed."),
        (test_attention_extraction, "\n❌ Attention extraction tests failed."),
    )
    for run_stage, failure_banner in stages:
        if not run_stage():
            print(failure_banner)
            return False

    print("\n" + "=" * 30)
    print("✅ All NPT tests passed successfully!")
    print("The implementation is ready for training.")
    return True

if __name__ == "__main__":
    # Exit status 0 when every check passed, 1 when any check failed.
    sys.exit(0 if main() else 1)