import torch
import torch.nn.functional as F
import numpy as np

def calculate_vcm(cav_vector, test_segments_embeddings):
    """Return the mean cosine similarity between a CAV and segment embeddings.

    Args:
        cav_vector: Concept activation vector. It gets a leading axis added
            and is L2-normalized along axis 1, so it is presumably 1-D
            ``(dim,)`` -- confirm against callers.
        test_segments_embeddings: Embeddings, L2-normalized along axis 1;
            presumably shaped ``(n_segments, dim)`` -- confirm against callers.

    Returns:
        The similarity averaged over all embeddings, as a Python float.
    """
    unit_segments = F.normalize(test_segments_embeddings, dim=1)
    unit_cav = F.normalize(cav_vector.unsqueeze(0), dim=1)

    # Both operands are unit length along axis 1, so the broadcast product
    # summed over that axis is the cosine similarity of each embedding.
    per_segment_similarity = torch.sum(unit_cav * unit_segments, dim=1)
    return per_segment_similarity.mean().item()

def calculate_ccm(cav_vector, model, class_idx, validation_images):
    """Return the mean cosine alignment between a CAV and class-logit gradients.

    Each image is run through ``model`` as a batch of one. The gradient of
    the ``class_idx`` logit with respect to the latent features returned by
    the model is L2-normalized and dotted with the L2-normalized CAV; the
    per-image results are averaged.

    Args:
        cav_vector: 1-D concept activation vector (it is passed to
            ``torch.dot``, which requires 1-D operands).
        model: Module whose forward pass returns ``(logits, features)``.
            ``features`` is either the latent tensor itself or a dict of
            tensors, in which case the first value is used. As a side effect
            the model is switched to eval mode and its gradients are zeroed.
        class_idx: Index of the logit to differentiate.
        validation_images: Iterable of unbatched image tensors. The tensors
            are not modified.

    Returns:
        The mean per-image score as a NumPy float, or NaN when
        ``validation_images`` yields nothing.
    """
    model.eval()
    model.zero_grad()

    cav_norm = F.normalize(cav_vector, dim=0)
    scores = []

    # enable_grad keeps this working even when the caller sits inside a
    # torch.no_grad() block (e.g. an evaluation loop).
    with torch.enable_grad():
        for img in validation_images:
            # Differentiate through a detached batch-of-one view so the
            # caller's tensor keeps its own requires_grad flag. Requiring
            # grad on the input guarantees the latent is part of a graph
            # even if every model parameter is frozen.
            batch = img.detach().unsqueeze(0).requires_grad_(True)
            logits, features = model(batch)

            if isinstance(features, dict):
                latent = next(iter(features.values()))
            else:
                latent = features  # ViT

            grads = torch.autograd.grad(logits[0, class_idx], latent)[0]

            if grads.dim() > 2:
                # NOTE(review): this reduction only works for 4-D gradients,
                # presumably (batch, channels, H, W); a 3-D token layout
                # would raise here -- confirm the expected ViT feature shape.
                grads = grads.mean(dim=[2, 3])

            grads_norm = F.normalize(grads[0], dim=0)
            scores.append(torch.dot(cav_norm, grads_norm).item())

    if not scores:
        # Same value np.mean([]) would give, without the RuntimeWarning.
        return np.float64(np.nan)
    return np.mean(scores)

def insertion_deletion_metric(model, image, saliency_map, step=0.05, mode='deletion'):
    """
    Faithfulness metrics (Insertion/Deletion).

    Pixels are ranked by saliency (highest first). In ``'insertion'`` mode
    they are progressively copied from ``image`` onto an all-zero canvas; in
    any other mode (deletion) they are progressively zeroed out of a copy of
    ``image``. After every step the softmax probability of the class the
    model predicts for the unmodified image is recorded, and the area under
    that curve is returned.

    Args:
        model: Callable returning ``(logits, extra)``; only ``logits`` is
            used. Its train/eval mode is not changed here.
        image: Tensor of shape ``(1, C, H, W)``. It is not modified.
        saliency_map: Tensor with exactly ``H * W`` elements, one per pixel.
        step: Fraction of pixels changed per step. The curve is sampled at
            ``int(1 / step)`` equal intervals from 0% to 100% of the pixels.
        mode: ``'insertion'`` or ``'deletion'``; anything that is not
            ``'insertion'`` is treated as deletion.

    Returns:
        Trapezoidal area under the probability curve over [0, 1]
        (NumPy float).

    Raises:
        ValueError: If the batch size is not 1, ``step`` is outside
            ``(0, 1]``, or ``saliency_map`` does not have ``H * W`` elements.
    """
    B, C, H, W = image.shape
    if B != 1:
        # Flattening to (1, C, -1) would silently mix pixels across samples.
        raise ValueError(f"expected a single image (batch size 1), got batch size {B}")
    if not 0 < step <= 1:
        raise ValueError(f"step must be in (0, 1], got {step}")

    n_pixels = H * W
    saliency_flat = saliency_map.reshape(-1)
    if saliency_flat.numel() != n_pixels:
        raise ValueError(
            f"saliency_map has {saliency_flat.numel()} elements, expected H*W={n_pixels}"
        )

    _, indices = torch.sort(saliency_flat, descending=True)
    indices = indices.to(image.device)

    # Rounding to 9 decimals first absorbs float noise such as
    # 1 / step == 19.999999999 before truncating to an integer.
    n_steps = max(1, int(round(1.0 / step, 9)))

    img_mod = image.detach().clone().reshape(1, C, n_pixels)
    inserting = mode == 'insertion'
    if inserting:
        canvas = torch.zeros_like(img_mod)
    else:
        canvas = img_mod.clone()  # Start full

    scores = []
    with torch.no_grad():
        # The tracked class is always the prediction on the unmodified image,
        # not on the blank canvas that insertion mode starts from.
        logits, _ = model(img_mod.view(1, C, H, W))
        probs = F.softmax(logits, dim=1)
        orig_cls = int(torch.argmax(probs[0]))

        if inserting:
            logits, _ = model(canvas.view(1, C, H, W))
            probs = F.softmax(logits, dim=1)
        scores.append(probs[0, orig_cls].item())

        done = 0
        # Run through n_steps inclusive so the curve ends with every pixel
        # inserted/deleted and the integral spans the full [0, 1] range.
        for i in range(1, n_steps + 1):
            limit = (i * n_pixels) // n_steps
            # Only the newly selected pixels need touching; earlier ones
            # already hold their final value.
            fresh = indices[done:limit]
            done = limit

            if inserting:
                canvas[:, :, fresh] = img_mod[:, :, fresh]
            else:
                canvas[:, :, fresh] = 0.0

            logits, _ = model(canvas.view(1, C, H, W))
            probs = F.softmax(logits, dim=1)
            scores.append(probs[0, orig_cls].item())

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return trapezoid(scores, dx=1.0 / n_steps)