import torch.nn as nn
import torch.nn.functional as F
import torch

class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM).

    Calling an instance temporarily hooks one layer of ``model``, runs a
    forward pass on ``img``, backpropagates the score of each sample's
    arg-max class to that layer's output, and returns a per-sample saliency
    map resized to the input's spatial size and min-max normalised per sample
    (values in [0, 1), up to the 1e-6 stabiliser).
    """

    def __init__(self, **kwargs):
        # No configuration is used; ``kwargs`` is accepted for compatibility.
        pass

    @staticmethod
    def _resolve_layer(model: nn.Module, layer_name: str) -> nn.Module:
        """Return the sub-module of ``model`` addressed by ``layer_name``.

        Supports attribute paths with optional indexing, e.g. ``"layer4"``,
        ``"features.3"``, ``"features[3]"``, ``"layer4[-1].conv2"`` or
        ``"blocks['a']"``. This replaces an ``eval`` call so that
        ``layer_name`` can no longer execute arbitrary Python expressions.

        Raises:
            AttributeError: if a path component names no attribute and is
                not an integer index.
        """
        module = model
        # Normalise "a[0]['k']" -> "a.0.'k'" and walk it component by component.
        path = layer_name.replace("[", ".").replace("]", "")
        for raw_part in path.split("."):
            part = raw_part.strip("'\" ")
            if not part:
                continue
            try:
                # Named children (including Sequential's "0", "1", ... and
                # ModuleDict keys) are reachable as attributes.
                module = getattr(module, part)
            except AttributeError:
                # Fall back to positional indexing, e.g. ``[-1]`` on
                # nn.Sequential / nn.ModuleList.
                try:
                    index = int(part)
                except ValueError:
                    raise AttributeError(
                        f"{type(module).__name__} has no sub-module {part!r} "
                        f"(layer_name={layer_name!r})"
                    ) from None
                module = module[index]
        return module

    def _register_model(self, model: nn.Module, layer_name: str):
        """Attach a forward hook recording ``layer_name``'s output.

        The hook's handle is stored in ``self.handler``; the caller is
        responsible for removing it.
        """
        self.data_hidden = []

        def hook(module, input, output):
            # Each execution of the layer appends its output; only the first
            # one is used by ``__call__``.
            self.data_hidden.append(output)

        layer: nn.Module = self._resolve_layer(model, layer_name)
        self.handler = layer.register_forward_hook(hook)

    def __call__(self, model, img, layer_name):
        """Compute Grad-CAM maps for a batch of images.

        Args:
            model: classifier whose output's last dimension holds class
                scores (arg-max over ``dim=-1`` picks the target class).
            img: input batch unpacked as ``(batch, channel, height, width)``.
            layer_name: path of the layer to explain (see ``_resolve_layer``).
                Its output must be 4-D ``(batch, C, H', W')``, as required by
                the spatial mean and the einsum below.

        Returns:
            A detached tensor of shape ``(batch, height, width)`` with each
            sample's map min-max normalised.

        Raises:
            RuntimeError: if the selected layer did not run in the forward pass.
        """
        batch_size, _, height, width = img.size()
        self._register_model(model, layer_name)
        try:
            output = model(img)
            if not self.data_hidden:
                raise RuntimeError(
                    f"layer {layer_name!r} was not executed during the forward pass"
                )
            data_hidden = self.data_hidden[0]
            pred = torch.argmax(output, dim=-1)
            # Index on the logits' device to avoid a host/device mismatch.
            rows = torch.arange(batch_size, device=output.device)
            score = output[rows, pred]
            grads = torch.autograd.grad(score.sum(), data_hidden)[0]
        finally:
            # Always detach the hook and drop cached activations, otherwise
            # every later forward pass of ``model`` keeps recording outputs.
            self.handler.remove()
            self.data_hidden = []

        # The graph was freed by autograd.grad; work on plain tensors so the
        # returned map is not tied to it.
        data_hidden = data_hidden.detach()
        weights = grads.mean(dim=(2, 3))  # per-channel importance
        cam = torch.einsum("ij,ijml->iml", weights, data_hidden)
        cam = cam.clamp(min=0).unsqueeze(1)
        grad_cam = F.interpolate(
            cam, (height, width), mode="bilinear", align_corners=False
        ).squeeze(1)
        grad_cam = grad_cam.view(batch_size, height * width)
        max_value = grad_cam.max(dim=-1)[0].view(batch_size, 1)
        min_value = grad_cam.min(dim=-1)[0].view(batch_size, 1)
        # Small epsilon guards against division by zero for constant maps.
        grad_cam = (grad_cam - min_value) / (max_value - min_value + 1e-6)
        grad_cam = grad_cam.view(batch_size, height, width)
        return grad_cam
