import torch, torch.nn.functional as F

class GradCAM:
    """Compute Grad-CAM style heat maps for one layer of a network.

    On construction two hooks are attached to ``target_layer``:

    * a forward hook that stores the layer's latest output in
      ``self.activations``;
    * a full backward hook that stores the gradient w.r.t. that output in
      ``self.gradients`` whenever gradients flow through the layer.

    Calling the instance with a tensor of scores (computed from a forward
    pass that went through ``target_layer``) returns the normalised map.

    The hooks stay registered until :meth:`remove` is called; the instance
    can also be used as a context manager to remove them automatically.
    """

    def __init__(self, model, target_layer):
        """Attach the capture hooks to ``target_layer``.

        Args:
            model: Accepted for interface compatibility; it is not used by
                this class.
            target_layer: Module whose output the maps are computed from.
        """
        self.activations = None
        self.gradients = None
        # Keep the handles: without them the hooks (and, through them, this
        # object and the captured activations) can never be released.
        self._handles = [
            target_layer.register_forward_hook(self._save_activations),
            target_layer.register_full_backward_hook(self._save_gradients),
        ]

    def _save_activations(self, _module, _inputs, out):
        # Forward hook: remember the layer output of the latest forward pass.
        self.activations = out

    def _save_gradients(self, _module, _grad_in, grad_out):
        # Backward hook: remember the gradient w.r.t. the layer output.
        self.gradients = grad_out[0]

    def remove(self):
        """Detach both hooks from the target layer. Safe to call repeatedly."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False

    def __call__(self, scores):
        """Return the class-activation map for ``scores``.

        Args:
            scores: Tensor produced by a forward pass through the hooked
                layer. Its sum is differentiated w.r.t. the captured
                activations.

        Returns:
            A detached tensor shaped like the activations with the channel
            dimension reduced to size 1, min-max scaled to ``[0, 1]`` per
            sample over the spatial dimensions.

        Raises:
            RuntimeError: If no forward pass has been recorded yet, or if
                autograd cannot differentiate ``scores``.
            ValueError: If the captured activations have no spatial
                dimensions (fewer than 3 dims).
        """
        activations = self.activations
        if activations is None:
            raise RuntimeError(
                "GradCAM has no recorded activations; run a forward pass "
                "through the target layer before calling it."
            )
        if activations.dim() < 3:
            raise ValueError(
                "GradCAM expects activations shaped (N, C, *spatial); got "
                f"{tuple(activations.shape)}"
            )

        # retain_graph=True so the caller can still backpropagate through
        # `scores` afterwards (e.g. for a training step).
        grad = torch.autograd.grad(
            outputs=scores.sum(),
            inputs=activations,
            retain_graph=True,
            allow_unused=True,
        )[0]

        # Every dimension after (batch, channel) is treated as spatial, so
        # 1-D, 2-D and 3-D feature maps are all handled.
        spatial = tuple(range(2, activations.dim()))

        # The map itself is never differentiated: work on detached values
        # with autograd disabled instead of building a graph to throw away.
        with torch.no_grad():
            acts = activations.detach()
            if grad is None:
                # `scores` does not depend on the activations: best-effort
                # fallback that yields an all-zero map.
                grad = torch.zeros_like(acts)

            weights = grad.mean(dim=spatial, keepdim=True)
            cam = F.relu((weights * acts).sum(dim=1, keepdim=True))

            # Per-sample min-max normalisation; the clamp avoids dividing
            # by zero on constant maps.
            cam = cam - cam.amin(dim=spatial, keepdim=True)
            cam = cam / cam.amax(dim=spatial, keepdim=True).clamp(min=1e-6)
        return cam
