import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import cv2
import numpy as np
from scipy import misc
from torchvision import transforms

class GradCam(object):
    """Grad-CAM heat maps for a model that also returns intermediate activations.

    The wrapped ``model`` is used as its own feature extractor: calling it on an
    input must return ``(features, output)``, where ``features`` is an indexable
    sequence of 4-D activation tensors and ``output`` holds class scores along
    dim 1 (presumably ``(batch, num_classes)`` logits -- TODO confirm). The model
    must also expose ``get_gradients()``, returning the gradients recorded for
    those activations during the backward pass.
    """

    def __init__(self, model, device, hw=32, ng=1, bs=32, layer=-3):
        """
        Args:
            model: network described in the class docstring; put in eval mode.
            device: stored as ``self.device``; not used by this class itself.
            hw: side length (pixels) of the square maps returned by ``__call__``.
            ng: number of groups the batch is split into. When > 1, activations
                and gradients are averaged over the groups before building maps.
            bs: nominal per-group batch size, kept for backward compatibility.
                The actual group size is inferred from the batch, so a final,
                smaller batch also works.
            layer: index into ``features`` and ``get_gradients()`` selecting the
                activation to explain. The default ``-3`` reproduces the original
                hard-coded choice.
        """
        self.model = model.eval()
        self.device = device
        self.hw = hw
        self.ng = ng
        self.bs = bs
        self.layer = layer
        self.extractor = self.model
        # Not used inside this class; presumably kept for external callers.
        self.resize = transforms.Resize(self.hw)

    def __call__(self, x):
        """Compute one Grad-CAM map per (group-averaged) sample of ``x``.

        Each sample's top-scoring class (argmax of ``output`` along dim 1) is
        explained. Parameter gradients of ``self.model`` are zeroed both before
        and after the backward pass.

        Args:
            x: input batch passed straight to the model. When ``ng > 1`` its
                batch size must be a multiple of ``ng``.

        Returns:
            list of ``(hw, hw)`` float32 numpy arrays min-max scaled to
            ``[0, 1]``. A map with no variation is returned as all zeros
            rather than NaNs.
        """
        features, output = self.extractor(x)

        # Back-propagating the sum of every sample's top score yields each
        # sample's gradient in a single pass (assuming samples do not interact
        # inside the model). This equals summing one_hot(argmax) * output.
        score = output.max(1).values.sum()

        self.model.zero_grad()
        score.backward()

        # .float() keeps accumulation in float32 and lets half/bfloat16
        # activations convert to numpy (bfloat16 has no numpy dtype).
        # NOTE(review): gradient hooks usually fire in reverse forward order, so
        # the same index may not select matching entries in both lists; this
        # depends on the extractor's implementation -- confirm.
        grads = self.extractor.get_gradients()[self.layer].detach().float().cpu().numpy()
        targets = features[self.layer].detach().float().cpu().numpy()
        if self.ng > 1:
            grads = self._group_mean(grads)
            targets = self._group_mean(targets)

        # Channel importance = spatially averaged gradient, shape (N, C).
        weights = grads.mean(axis=(2, 3))

        cams = []
        for activation, w in zip(targets, weights):
            # Weighted sum of the C activation maps -> (H, W).
            # NOTE(review): canonical Grad-CAM starts from 0; the original
            # implementation used a baseline of 1, which is preserved here.
            cam = 1.0 + np.tensordot(w, activation, axes=1)
            cam = np.maximum(cam, 0)
            cam = cv2.resize(cam, (self.hw, self.hw))
            cam = cam - cam.min()
            peak = cam.max()
            if peak > 0:  # a flat map would otherwise become all-NaN (0 / 0)
                cam = cam / peak
            cams.append(cam)
        self.model.zero_grad()
        return cams

    def _group_mean(self, arr):
        """Average an ``(ng * k, ...)`` array over its ``ng`` leading groups -> ``(k, ...)``."""
        return arr.reshape(self.ng, -1, *arr.shape[1:]).mean(0)
