import torch
import torch.nn.functional as F
from lrp import LRPCalculator

class ConEx:
    """Concept-based explanations from CAVs, LRP relevance and occlusion.

    For every requested concept, the explainer:
      1. projects the model's latent feature map onto the concept's CAV to get a
         concept map,
      2. weights that map by LRP relevance projected onto the (ReLU'd) CAV,
      3. scales the result by the relative drop in the target-class probability
         when the concept region is zeroed out of the input image.
    """

    # Guards divisions against zero denominators.
    _EPS = 1e-8

    def __init__(self, model, cavs, config):
        """Store the model, concept vectors and config, and build the LRP helper.

        Args:
            model: Callable returning ``(logits, latent_map)`` for an image batch.
            cavs: Mapping from concept name to a dict holding at least
                ``'vector'`` (the CAV) and ``'mu_pos'`` (used for normalization).
            config: Object exposing ``DEVICE``; the input image is moved there.
        """
        self.model = model
        self.cavs = cavs
        self.config = config
        self.lrp = LRPCalculator(model)

    def get_cwv(self, cav_vector):
        """Return the concept weight vector for ``cav_vector`` (currently the identity)."""
        return cav_vector

    @staticmethod
    def _like(value, ref):
        """Return ``value`` as a tensor on ``ref``'s device with ``ref``'s dtype."""
        # No copy is made when ``value`` already matches; CAVs stored on CPU or in
        # another precision would otherwise break conv2d / dot / multiply.
        return torch.as_tensor(value, device=ref.device, dtype=ref.dtype)

    def generate_concept_map(self, image, concept_name, latent_map):
        """Project the latent feature map onto a concept's CAV, normalized to [0, 1].

        Args:
            image: Unused; kept for interface compatibility.
            concept_name: Key into ``self.cavs``.
            latent_map: Feature map, presumably ``(B, C, h, w)``; ``C`` must equal
                the CAV length.

        Returns:
            ``relu(<latent, cav>)`` per spatial location, divided by
            ``relu(cav . mu_pos)`` and clipped at 1. Shape ``(B, 1, h, w)``.
        """
        cav_data = self.cavs[concept_name]
        cav_vec = self._like(cav_data['vector'], latent_map).reshape(-1)
        mu_pos = self._like(cav_data['mu_pos'], latent_map).reshape(-1)

        # A 1x1 convolution with the CAV as its only filter is a per-pixel dot
        # product between the channel vector and the concept direction.
        weight = cav_vec.view(1, -1, 1, 1)
        raw_map = F.relu(F.conv2d(latent_map, weight))

        # Normalizer: CAV score of ``mu_pos``. NOTE: if this is <= 0 the division
        # saturates every positive pixel to 1.
        nv_k = F.relu(torch.dot(cav_vec, mu_pos))

        return (raw_map / (nv_k + self._EPS)).clamp(max=1.0)

    def get_concept_importance_map(self, image, target_class, concept_map,
                                   concept_name, lrp_maps=None):
        """Weight a concept map by LRP relevance projected onto the concept.

        Args:
            image: Input image batch passed to the LRP helper.
            target_class: Class index whose relevance is explained.
            concept_map: Output of ``generate_concept_map`` (latent resolution).
            concept_name: Key into ``self.cavs``.
            lrp_maps: Optional precomputed ``self.lrp.compute_relevance(image,
                target_class)``. The relevance does not depend on the concept, so
                callers iterating over concepts can compute it once.

        Returns:
            ``concept_map * sum_c(relu(cav)_c * lrp_maps_c)``, channel-summed with
            the channel dim kept.
        """
        if lrp_maps is None:
            lrp_maps = self.lrp.compute_relevance(image, target_class)

        cav_vec = self._like(self.cavs[concept_name]['vector'], lrp_maps).reshape(-1)
        relu_cwv = F.relu(cav_vec).view(1, -1, 1, 1)
        relevance_component = (relu_cwv * lrp_maps).sum(dim=1, keepdim=True)

        return concept_map * relevance_component

    def compute_class_specific_attribution(self, image, concept_map, target_class,
                                           prob_orig=None):
        """Relative drop of the target-class probability when the concept is masked.

        Pixels whose concept activation exceeds the map's mean are zeroed in the
        image, and ``(p_orig - p_masked) / p_orig`` is returned. Only batch
        element 0 is read (``[0, target_class]``).

        Args:
            image: Input image batch.
            concept_map: Concept map at image resolution (broadcast over channels).
            target_class: Class index.
            prob_orig: Optional precomputed softmax probability of
                ``target_class`` for the unmasked image; saves a forward pass.

        Returns:
            0-d tensor with the relative probability drop.
        """
        if prob_orig is None:
            with torch.no_grad():
                logits_orig, _ = self.model(image)
                prob_orig = F.softmax(logits_orig, dim=1)[0, target_class]

        binary_mask = (concept_map > concept_map.mean()).float()
        masked_img = image * (1.0 - binary_mask)

        with torch.no_grad():
            logits_masked, _ = self.model(masked_img)
            prob_masked = F.softmax(logits_masked, dim=1)[0, target_class]

        return (prob_orig - prob_masked) / (prob_orig + self._EPS)

    def explain(self, image, target_class, concepts_list):
        """Build the combined concept explanation for ``target_class``.

        Args:
            image: Input image batch; moved to ``self.config.DEVICE``.
            target_class: Class index to explain.
            concepts_list: Concept names; names missing from ``self.cavs`` are
                skipped silently.

        Returns:
            ``(final_map, concept_contributions)``. ``final_map`` is the sum of
            ``w_k * cim_k`` at latent resolution. ``concept_contributions`` maps
            each used concept to its float ``w_k``.
        """
        image = image.to(self.config.DEVICE)

        with torch.no_grad():
            logits, latent_map = self.model(image)
            # Reuse this forward pass for the unmasked probability rather than
            # recomputing it once per concept.
            prob_orig = F.softmax(logits, dim=1)[0, target_class]

        # Accumulator lives at the latent resolution, like the concept maps/CIMs.
        final_map = image.new_zeros((image.shape[0], 1, *latent_map.shape[-2:]))
        concept_contributions = {}
        lrp_maps = None  # concept-independent: computed lazily, at most once

        for concept in concepts_list:
            if concept not in self.cavs:
                continue

            m_k = self.generate_concept_map(image, concept, latent_map)

            if lrp_maps is None:
                lrp_maps = self.lrp.compute_relevance(image, target_class)
            cim = self.get_concept_importance_map(
                image, target_class, m_k, concept, lrp_maps=lrp_maps
            )

            m_k_upsampled = F.interpolate(
                m_k, size=image.shape[-2:], mode='bilinear', align_corners=False
            )
            w_k = self.compute_class_specific_attribution(
                image, m_k_upsampled, target_class, prob_orig=prob_orig
            )

            final_map += w_k * cim
            concept_contributions[concept] = w_k.item()

        return final_map, concept_contributions