import torch
import numpy as np
from config import Config
from model_utils import MaskedResNet, MaskedViT
from segmentation_utils import ConceptDiscoverer, GroundedSAMWrapper
from cav import CAVManager
from conex import ConEx
from metrics import calculate_vcm, calculate_ccm, insertion_deletion_metric
from PIL import Image
import torchvision.transforms as T

def load_dummy_data(size=224, occlude=50, device=None):
    """Return a random placeholder test image and a matching occlusion mask.

    Args:
        size: Height and width of the square image and mask (default 224).
        occlude: Side length of the top-left square that is zeroed in the
            mask (default 50). Values larger than ``size`` zero the whole mask.
        device: Device to move both tensors to; defaults to ``Config.DEVICE``
            when None.

    Returns:
        ``(img, mask)``. ``img`` has shape ``(1, 3, size, size)`` with values
        drawn uniformly from [0, 1). ``mask`` has shape ``(1, 1, size, size)``
        and is all ones except the top-left ``occlude`` x ``occlude`` patch,
        which is zero.
    """
    if device is None:
        device = Config.DEVICE
    # Generate on CPU and then move, as before, so seeded runs keep the same
    # RNG stream regardless of the target device.
    img = torch.rand(1, 3, size, size).to(device)
    mask = torch.ones(1, 1, size, size).to(device)
    mask[:, :, :occlude, :occlude] = 0
    return img, mask

def _build_model(model_type):
    """Instantiate the masked backbone named by ``model_type`` on Config.DEVICE in eval mode."""
    if model_type == 'resnet':
        model = MaskedResNet()
    elif model_type == 'vit':
        model = MaskedViT()
    else:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'resnet' or 'vit'."
        )
    model = model.to(Config.DEVICE)
    model.eval()
    return model


def _build_cavs(cav_manager, concepts, n_samples=10):
    """Build a CAV per concept, keeping only those ``create_cav`` returns truthy for."""
    cavs = {}
    for concept in concepts:
        # NOTE(review): placeholder data. Positive and negative sets come from
        # the same distribution (uniform noise, all-ones masks), so the
        # resulting CAVs cannot separate anything meaningful. Replace with
        # real concept crops and counterexamples.
        pos_samples = [
            (torch.rand(3, 224, 224), torch.ones(1, 224, 224))
            for _ in range(n_samples)
        ]
        neg_samples = [
            (torch.rand(3, 224, 224), torch.ones(1, 224, 224))
            for _ in range(n_samples)
        ]
        cav_data = cav_manager.create_cav(concept, pos_samples, neg_samples)
        if cav_data:
            cavs[concept] = cav_data
    return cavs


def main(model_type='resnet', class_name="hornbill", target_class=10):
    """Run the ConEx demo pipeline end to end on placeholder data.

    Steps:
        1. Load the backbone.
        2. Discover and validate concepts for ``class_name``.
        3. Build a CAV per validated concept.
        4. Explain a random test image for ``target_class``.
        5. Print the deletion score of the resulting saliency map.

    Args:
        model_type: ``'resnet'`` (MaskedResNet) or ``'vit'`` (MaskedViT).
        class_name: Class whose concepts are discovered and validated.
        target_class: Class index the explanation targets (arbitrary by default).

    Raises:
        ValueError: If ``model_type`` is not recognised.
    """
    print("Initializing ConEx Framework...")

    model = _build_model(model_type)
    print(f"Model {model_type} loaded.")

    discoverer = ConceptDiscoverer()
    concepts = discoverer.get_concepts_for_class(class_name)
    print(f"Initial Concepts for {class_name}: {concepts}")

    sam = GroundedSAMWrapper(Config)
    # NOTE(review): an empty image list is passed here. Confirm that
    # filter_concepts is meant to validate concepts without example images.
    validated_concepts = sam.filter_concepts([], class_name, concepts)
    print(f"Validated Concepts: {validated_concepts}")

    cav_manager = CAVManager(model, Config)
    print("Constructing CAVs...")
    cavs = _build_cavs(cav_manager, validated_concepts)
    if not cavs:
        print("No CAVs could be constructed; skipping explanation.")
        return

    conex_engine = ConEx(model, cavs, Config)

    test_img, _ = load_dummy_data()

    # create_cav may reject concepts; only request ones that have a CAV
    # (order follows validated_concepts).
    explained_concepts = [c for c in validated_concepts if c in cavs]

    print("Generating Explanations...")
    saliency_map, contributions = conex_engine.explain(
        test_img, target_class, explained_concepts
    )

    print(f"Concept Contributions: {contributions}")
    print(f"Saliency Map Shape: {saliency_map.shape}")

    # Bilinear interpolate needs (N, C, H, W). Pad leading dims if explain()
    # returned a lower-rank tensor (the batch size is 1 here).
    saliency = saliency_map
    while saliency.dim() < 4:
        saliency = saliency.unsqueeze(0)
    # Resize to the test image's own spatial size. align_corners=False is the
    # current default, made explicit to avoid the legacy-behaviour warning.
    saliency_resized = torch.nn.functional.interpolate(
        saliency, size=tuple(test_img.shape[-2:]), mode='bilinear', align_corners=False
    )
    del_score = insertion_deletion_metric(model, test_img, saliency_resized, mode='deletion')
    print(f"Deletion Score: {del_score:.4f}")

# Launch the demo only when run as a script, not when imported as a module.
if __name__ == '__main__':
    main()