import torch
import torch.nn.functional as F
from cam.basecam import BaseCAM
from sklearn.cluster import KMeans
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

class ClusterScoreCAM(BaseCAM):
    """Score-CAM with K-means++ clustering, cluster drop-out and temperature softmax.

    The activation maps of the hooked layer are upsampled to the input size,
    min-max normalised, and grouped with K-means++. Each cluster centre is
    used as a soft mask on the input; the change of the target logit under
    that mask is the cluster's score. The lowest-scoring clusters are
    dropped, the remaining scores go through a temperature softmax, and the
    saliency map is the weighted sum of the cluster centres.

    Args:
        model_dict: forwarded unchanged to ``BaseCAM.__init__``; its
            ``'arch'`` entry is additionally stored as ``self.model``.
        num_clusters: requested number of clusters K. At run time the
            effective number is capped at the number of activation maps,
            because K-means cannot fit more clusters than samples.
        zero_ratio: fraction of clusters (lowest score first) whose weight is
            forced to zero. At least one cluster is always kept.
        temperature_dict: optional mapping ``class_idx -> temperature`` that
            overrides ``temperature`` for specific classes.
        temperature: softmax temperature used when ``temperature_dict`` has
            no entry for the class. Expected to be positive.
    """

    def __init__(
        self,
        model_dict,
        num_clusters=10,  # default value
        zero_ratio=0.5,  # default value
        temperature_dict=None,
        temperature=0.5,  # default value
    ):
        super().__init__(model_dict)
        self.K = num_clusters
        self.zero_ratio = zero_ratio
        self.temperature_dict = temperature_dict or {}
        self.temperature = temperature
        self.model = model_dict['arch']

    @staticmethod
    def _upsample_and_normalize(activations, size):
        """Bilinearly resize ``(nc, u, v)`` maps to ``size`` and min-max scale each map.

        Maps whose minimum equals their maximum are returned un-normalised
        (there is no range to divide by).
        """
        # Detach first: the maps are only fed to K-means, so no autograd
        # graph is needed for the interpolation.
        up = F.interpolate(
            activations.detach().unsqueeze(0),  # (1, nc, u, v)
            size=size,
            mode='bilinear',
            align_corners=False,
        )[0]  # (nc, h, w)
        lo = up.amin(dim=(1, 2), keepdim=True)
        hi = up.amax(dim=(1, 2), keepdim=True)
        span = hi - lo
        has_range = span != 0
        # Divide by 1 where the map is constant; that result is discarded below.
        safe_span = torch.where(has_range, span, torch.ones_like(span))
        return torch.where(has_range, (up - lo) / safe_span, up)

    @staticmethod
    def _cluster_weights(diffs, zero_ratio, temperature):
        """Drop the lowest-scoring clusters and softmax the rest with a temperature.

        At most ``len(diffs) - 1`` clusters are dropped: if every score were
        set to ``-inf`` the softmax would be NaN everywhere.
        """
        count = diffs.numel()
        num_zero = min(int(zero_ratio * count), count - 1)
        if num_zero > 0:
            diffs = diffs.clone()
            diffs[torch.argsort(diffs)[:num_zero]] = float('-inf')
        return F.softmax(diffs / temperature, dim=0)

    def forward(self, input, class_idx=None, retain_graph=False):
        """Compute the saliency map for one image.

        Args:
            input: image tensor of shape ``(1, C, H, W)``.
            class_idx: target class. ``None`` selects the arg-max class of
                the model output; a tensor is converted with ``int()``.
            retain_graph: forwarded to ``backward`` on the target logit.

        Returns:
            A ``(1, 1, H, W)`` tensor scaled to ``[0, 1]``, or ``None`` when
            the map is constant (min equals max) and cannot be normalised.

        Side effects:
            Sets ``self.rep_maps`` (cluster centres, ``(K', H, W)``),
            ``self.base_score`` (target logit) and, on success,
            ``self.last_saliency``. Prints one progress line before K-means.
        """
        # Unpacking also enforces a 4-D input.
        _, _, h, w = input.size()

        # Forward pass + select class.
        # NOTE(review): `self.model_arch` and `self.activations` are not
        # defined in this class; presumably BaseCAM sets them up via hooks.
        logits = self.model_arch(input)
        if class_idx is None:
            class_idx = logits.argmax(dim=1).item()
        elif isinstance(class_idx, torch.Tensor):
            class_idx = int(class_idx)
        base_score = logits[0, class_idx]

        # Backward pass on the target logit, then read the recorded
        # low-resolution activation maps of the first (only) sample.
        self.model_arch.zero_grad()
        base_score.backward(retain_graph=retain_graph)
        activations = self.activations['value'][0]  # (nc, u, v)
        nc, _, _ = activations.shape
        device = activations.device

        up_maps = self._upsample_and_normalize(activations, (h, w))  # (nc, h, w)

        # K-means cannot fit more clusters than there are maps.
        n_clusters = min(self.K, nc)
        flat_maps = up_maps.reshape(nc, -1).detach().cpu().numpy()  # (nc, h*w)
        kmeans = KMeans(n_clusters=n_clusters, init='k-means++', random_state=0)
        print(f"[ClusterScoreCAM] Running KMeans++ with {n_clusters} clusters...")
        kmeans.fit(flat_maps)
        # Cast to the input dtype so `input * mask` keeps the dtype the model
        # expects, whatever dtype the fitted centres come back in.
        rep_maps = torch.from_numpy(
            kmeans.cluster_centers_.reshape(n_clusters, h, w)
        ).to(device=device, dtype=input.dtype)

        self.rep_maps = rep_maps  # for visualisation, tensor (K', h, w)
        self.base_score = base_score  # for debugging

        # Score of each cluster = change of the target logit when the input
        # is multiplied by the cluster-centre mask.
        diffs = torch.zeros(n_clusters, device=device)
        with torch.no_grad():
            for k, centre in enumerate(rep_maps):
                mask = centre.view(1, 1, h, w)
                out = self.model_arch(input * mask)
                diffs[k] = out[0, class_idx] - base_score

        T = self.temperature_dict.get(class_idx, self.temperature)
        weights = self._cluster_weights(diffs, self.zero_ratio, T)

        # Weighted sum of the cluster centres -> (1, 1, h, w).
        sal = (weights.view(-1, 1, 1) * rep_maps.to(weights.dtype)).sum(dim=0)
        sal = sal.view(1, 1, h, w)

        # Post-process + normalise.
        sal = F.relu(sal)
        mn, mx = sal.min(), sal.max()
        if mn == mx:
            return None
        sal = (sal - mn) / (mx - mn)

        self.last_saliency = sal  # tensor (1, 1, h, w)

        return sal

    def __call__(self,
                 input_tensor: torch.Tensor,
                 targets: list[ClassifierOutputTarget] | None = None,
                 class_idx: int | None = None,
                 retain_graph: bool = False):
        """Run ``forward``, accepting a pytorch-grad-cam style ``targets`` list.

        When ``targets`` is non-empty and its first element is a
        ``ClassifierOutputTarget``, that target's ``category`` replaces
        ``class_idx``. Only the first target is used.
        """
        if (
            targets is not None
            and len(targets) > 0
            and isinstance(targets[0], ClassifierOutputTarget)
        ):
            class_idx = targets[0].category

        return self.forward(input_tensor, class_idx, retain_graph)
