import torch
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.image import scale_cam_image
import numpy as np

"""
This is our implementation of the OptiCAM paper (https://arxiv.org/abs/2301.07002)
"""

class OptiCAM(BaseCAM):
    """Saliency maps obtained by optimising per-channel activation weights.

    For every image a weight vector over the channels of the target layer is
    optimised with Adam so that the model's score for the requested class is
    maximised when the input is multiplied by the resulting (min-max
    normalised, upsampled) saliency map.

    NOTE(review): the masked input is standardised with ``self.means`` /
    ``self.stds`` before it is fed to the model, so ``input_tensor`` is
    presumably expected to be an un-normalised 3-channel image batch --
    confirm against the callers.
    """

    def __init__(self, model, target_layers, use_cuda=False,
                 reshape_transform=None, max_iter=100, learning_rate=1e-1):
        """Set up the CAM.

        Args:
            model: Network to explain; passed on to ``BaseCAM``.
            target_layers: Layers whose activations are combined.
            use_cuda: Accepted for backward compatibility only; it is not
                forwarded to ``BaseCAM`` and not used by this class.
            reshape_transform: Optional activation reshape, passed on to
                ``BaseCAM``.
            max_iter: Maximum number of optimisation steps per call.
            learning_rate: Adam learning rate for the channel weights.
        """
        super(OptiCAM, self).__init__(model=model,
                                       target_layers=target_layers,
                                       reshape_transform=reshape_transform,
                                       uses_gradients=False)
        self.max_iter = max_iter
        self.learning_rate = learning_rate
        # Per-channel statistics shaped (1, 3, 1, 1) so they broadcast over a
        # (batch, 3, H, W) image batch (presumably the ImageNet statistics).
        self.means = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        self.stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        # Not used by the optimisation below; kept so the attribute remains
        # available to external code.
        self.loss = torch.nn.CrossEntropyLoss()

    @staticmethod
    def _weighted_cams(activations, w):
        """Return the softmax(w)-weighted channel sum, shape (batch, H, W)."""
        weights = w.softmax(dim=1)[:, :, None, None]
        return (activations * weights).sum(dim=1)

    @staticmethod
    def _normalized_mask(cams, size, eps=1e-7):
        """Resize ``cams`` to ``size`` and min-max scale each sample to [0, 1].

        Returns a tensor of shape (batch, 1, *size).
        """
        mask = torch.nn.functional.interpolate(cams.unsqueeze(1), size)
        low = mask.amin(dim=(1, 2, 3), keepdim=True)
        high = mask.amax(dim=(1, 2, 3), keepdim=True)
        # Clamp the range so a constant map yields zeros instead of NaN.
        return (mask - low) / (high - low).clamp_min(eps)

    def get_cam_image(self, input_tensor, target_layer, target_category,
                      activation_batch, grads, eigen_smooth):
        """Optimise the channel weights and return the weighted activations.

        Args:
            input_tensor: Image batch; its last two dimensions give the size
                the saliency mask is resized to.
            target_layer: Unused; part of the ``BaseCAM`` interface.
            target_category: Sequence of targets, each exposing a
                ``category`` attribute (int or one-element tensor) that
                indexes the model output of the corresponding sample.
            activation_batch: Activations of the target layer as a numpy
                array or tensor; indexed here as (batch, channels, H, W).
            grads: Unused; part of the ``BaseCAM`` interface.
            eigen_smooth: Unused; part of the ``BaseCAM`` interface.

        Returns:
            numpy array of shape (batch, H, W) holding the weighted sum of
            the activations (not normalised, not resized).
        """
        activations = torch.as_tensor(activation_batch).type_as(input_tensor)
        batch_size, num_channels = activations.shape[:2]
        w = torch.randn((batch_size, num_channels), requires_grad=True,
                        device=input_tensor.device)
        optim = torch.optim.Adam([w], lr=self.learning_rate, maximize=True)

        # int() accepts plain Python ints as well as one-element tensors.
        classes = torch.as_tensor(
            [int(target.category) for target in target_category],
            dtype=torch.long, device=input_tensor.device)
        rows = torch.arange(classes.shape[0], device=input_tensor.device)

        # Local casts: do not mutate the attributes on every call.
        means = self.means.type_as(input_tensor)
        stds = self.stds.type_as(input_tensor)
        mask_size = tuple(input_tensor.shape[-2:])

        # NOTE(review): every self.model(...) call below presumably also
        # triggers the activation hooks installed by BaseCAM -- verify that
        # the extra recorded activations are acceptable.
        for _ in range(self.max_iter):
            mask = self._normalized_mask(
                self._weighted_cams(activations, w), mask_size)
            mask = mask.type_as(input_tensor)
            masked_img = (input_tensor * mask - means) / stds
            output = self.model(masked_img)
            # Objective: sum of the raw target-class scores over the batch.
            loss = output[rows, classes].sum()
            # NOTE(review): the 1e-10 threshold stops the optimisation as
            # soon as the summed score is non-positive, which can happen on
            # the first step with raw logits -- confirm this is intended.
            # A non-finite objective also stops, keeping the last valid w.
            if not torch.isfinite(loss) or loss <= 1e-10:
                break
            # Differentiate w.r.t. w only, so no gradients accumulate on the
            # model parameters across iterations.
            w.grad = torch.autograd.grad(loss, w)[0]
            optim.step()

        cam = self._weighted_cams(activations, w)
        # Always detach and move to host memory, whatever the device.
        return cam.detach().cpu().numpy()

    def compute_cam_per_layer(
            self,
            input_tensor: torch.Tensor,
            targets,
            eigen_smooth: bool) -> np.ndarray:
        """Compute one scaled CAM per target layer.

        Returns:
            A list with one array per target layer; each is the output of
            ``scale_cam_image`` with a singleton axis inserted at position 1.
        """
        # Snapshot the recorded tensors first: get_cam_image runs the model
        # again afterwards.
        activations_list = [a.cpu().data.numpy()
                            for a in self.activations_and_grads.activations]
        grads_list = [g.cpu().data.numpy()
                      for g in self.activations_and_grads.gradients]
        target_size = self.get_target_width_height(input_tensor)

        cam_per_target_layer = []
        # Loop over the saliency image from every layer
        for i, target_layer in enumerate(self.target_layers):
            layer_activations = (activations_list[i]
                                 if i < len(activations_list) else None)
            layer_grads = grads_list[i] if i < len(grads_list) else None

            cam = self.get_cam_image(input_tensor,
                                     target_layer,
                                     targets,
                                     layer_activations,
                                     layer_grads,
                                     eigen_smooth)
            scaled = scale_cam_image(cam, target_size)
            cam_per_target_layer.append(scaled[:, None, :])

        return cam_per_target_layer