import torch
import torch.nn.functional as F
import numpy as np
import random

def find_yolo_layer(model, layer_name):
    """Resolve a nested sub-module of a YOLOv5 wrapper from a path string.

    The returned layer is the one whose activations and gradients are used
    to compute GradCAM and GradCAM++.

    Args:
        model: object exposing the network as ``model.model``.
        layer_name (str): keys into successive ``_modules`` dictionaries,
            joined by ``'_'`` (for example ``'17_cv3'``). The first key is
            looked up on ``model.model``, each following key on the module
            found by the previous one.

    Return:
        The module reached after following every key of ``layer_name``.
    """
    node = model.model
    for key in layer_name.split('_'):
        node = node._modules[key]
    return node

def get_2d_projection(activation_batch):
    """Project every item of a batch onto its first principal direction.

    Each item is flattened to a (positions, components) matrix, where the
    components are the item's leading axis (the channel axis for CNN
    activations), centred per component, and multiplied by its first
    right-singular vector. The result has one value per position.

    Args:
        activation_batch: array-like whose first axis is the batch and whose
            second axis is the component axis. It is copied before NaNs are
            replaced by 0, so the caller's data is left untouched.

    Returns:
        ``np.float32`` array with the batch axis followed by the remaining
        (non-component) axes of each item.
    """
    # TODO: use a batched PyTorch SVD implementation.
    batch = np.array(activation_batch)
    batch[np.isnan(batch)] = 0
    projections = []
    for activations in batch:
        # Rows are positions, columns are components.
        flat = activations.reshape(activations.shape[0], -1).transpose()
        # Centering before the SVD seems to be important here,
        # otherwise the image returned is negative.
        flat = flat - flat.mean(axis=0)
        # Only the first right-singular vector is needed. The reduced SVD
        # yields the same vector without materialising the square
        # (positions x positions) U matrix that full_matrices=True builds.
        _, _, vt = np.linalg.svd(flat, full_matrices=False)
        projection = flat @ vt[0, :]
        projections.append(projection.reshape(activations.shape[1:]))
    return np.float32(projections)

class YOLOV5LayerCAM:
    """Gradient-weighted saliency maps for one layer of a YOLOv5 wrapper.

    A forward hook and a full backward hook on the layer named by
    ``layer_name`` record that layer's output and the gradient with respect
    to that output. ``forward`` back-propagates one score per detection and
    combines the two recordings into a map: channel-averaged gradients are
    used as weights by default (Grad-CAM); with ``layercam=True`` the
    element-wise ReLU of the gradients is used instead (LayerCAM).

    The hook handles are stored in ``self.fh`` and ``self.bh``; nothing in
    this class removes them.
    """

    def __init__(self, model, layer_name, layercam=False):
        """Locate the target layer and attach the recording hooks.

        Args:
            model: detector wrapper; it is called as ``model(input_img)`` and
                must expose the network as ``model.model``.
            layer_name (str): underscore-separated path understood by
                ``find_yolo_layer``.
            layercam (bool): use LayerCAM weights instead of Grad-CAM ones.
        """
        self.model = model
        self.gradients = dict()
        self.activations = dict()
        self.layercam = layercam

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self.activations['value'] = output
            return None

        target_layer = find_yolo_layer(self.model, layer_name)
        # Record the layer's output during the forward pass and the gradient
        # reaching that output during the backward pass.
        self.fh = target_layer.register_forward_hook(forward_hook)
        self.bh = target_layer.register_full_backward_hook(backward_hook)

    def forward(self, input_img, class_idx=True):
        """Compute one saliency map per detection.

        ``self.model(input_img)`` must return ``(preds, logits)``. Only
        ``logits[0]``, ``preds[1][0]`` and ``preds[2][0]`` are read, and they
        are iterated in lockstep, one entry per detection (presumably the
        first image of the batch; confirm against the model wrapper).

        Args:
            input_img: 4-D image tensor, documented upstream as
                ``(1, 3, H, W)``. Its last two dimensions set the map size.
            class_idx (bool): when true the back-propagated score is
                ``logit[cls]`` with ``cls`` taken from ``preds[1][0]``;
                otherwise it is ``logit.max()``.

        Return:
            saliency_maps: list with one detached tensor per detection,
                resized to ``(H, W)`` and min-max normalised.
            logits: second output of the model, unchanged.
            preds: first output of the model, unchanged.

        Calling this method zeroes and then repopulates the model's
        parameter gradients as a side effect of the backward passes.
        """
        saliency_maps = []
        _, _, h, w = input_img.size()
        preds, logits = self.model(input_img)
        # None means "no detection was processed". Truth-testing the score
        # tensor itself would wrongly treat a score of exactly 0 as missing.
        score = None
        for logit, cls, _ in zip(logits[0], preds[1][0], preds[2][0]):
            score = logit[cls] if class_idx else logit.max()
            self.model.zero_grad()
            # Back-propagate to fill self.gradients through the hook. The
            # graph is kept because later detections reuse it.
            score.backward(retain_graph=True)
            gradients = self.gradients['value']
            activations = self.activations['value']

            if self.layercam:
                # LayerCAM: positive gradients weight each location.
                weights = F.relu(gradients)
            else:
                # Grad-CAM: one weight per channel, the spatial mean of
                # that channel's gradient.
                n, k, _, _ = gradients.size()
                alpha = gradients.view(n, k, -1).mean(2)
                weights = alpha.view(n, k, 1, 1)
            saliency_map = (weights * activations).sum(1, keepdim=True)
            saliency_map = F.relu(saliency_map)
            saliency_map = F.interpolate(
                saliency_map, size=(h, w), mode='bilinear', align_corners=False)
            lowest, highest = saliency_map.min(), saliency_map.max()
            # Floor the range so a constant map divides by 1e-10, not by 0.
            value_range = max(highest - lowest, 1e-10)
            saliency_maps.append((saliency_map - lowest).div(value_range).detach())
        if score is not None:
            # One more backward pass without retain_graph lets autograd
            # release the buffers that retain_graph=True kept alive above.
            torch.autograd.backward([score], retain_graph=False)
        return saliency_maps, logits, preds

    def __call__(self, input_img, class_idx=True):
        """Shorthand for ``forward(input_img, class_idx)``."""
        return self.forward(input_img, class_idx)



class YOLOV5CALayerCAM:
    """LayerCAM-style maps for the classes a detection was NOT assigned to.

    A forward hook and a full backward hook on the layer named by
    ``layer_name`` record that layer's output and the gradient with respect
    to that output. For every detection, ``forward`` back-propagates the
    negated logit of each candidate class other than the detected one and
    turns the recordings into one saliency map per such class.

    The hook handles are stored in ``self.fh`` and ``self.bh``; nothing in
    this class removes them.
    """

    def __init__(self, model, layer_name, num_class=80, target_cls=()):
        """Locate the target layer and attach the recording hooks.

        Args:
            model: detector wrapper; it is called as ``model(input_img)`` and
                must expose the network as ``model.model``.
            layer_name (str): underscore-separated path understood by
                ``find_yolo_layer``.
            num_class (int): number of classes; when ``target_cls`` is empty
                the candidates are ``range(num_class)``.
            target_cls: iterable of candidate class indices. Empty (the
                default) means "all classes".
        """
        self.model = model
        self.gradients = dict()
        self.activations = dict()
        self.num_class = num_class
        # Copied into a fresh set, so the caller's container is never shared.
        self.target_cls = set(target_cls or ())

        def backward_hook(module, grad_input, grad_output):
            self.gradients['value'] = grad_output[0]
            return None

        def forward_hook(module, input, output):
            self.activations['value'] = output
            return None

        target_layer = find_yolo_layer(self.model, layer_name)
        # Record the layer's output during the forward pass and the gradient
        # reaching that output during the backward pass.
        self.fh = target_layer.register_forward_hook(forward_hook)
        self.bh = target_layer.register_full_backward_hook(backward_hook)

    def forward(self, input_img):
        """Compute, per detection, one saliency map for every other class.

        ``self.model(input_img)`` must return ``(preds, logits)``. Only
        ``logits[0]``, ``preds[1][0]`` and ``preds[2][0]`` are read, and they
        are iterated in lockstep, one entry per detection (presumably the
        first image of the batch; confirm against the model wrapper).
        Removing the detected class from the candidates relies on
        ``preds[1][0]`` holding hashable values equal to the class indices
        (assumed to be plain ints; TODO confirm).

        Args:
            input_img: 4-D image tensor, documented upstream as
                ``(1, 3, H, W)``. Its last two dimensions set the map size.

        Return:
            saliency_maps: list with one dict per detection, mapping a class
                index to a detached map resized to ``(H, W)`` and min-max
                normalised. Classes are visited in shuffled order (module
                level ``random``), which fixes the dict's insertion order.
            logits: second output of the model, unchanged.
            preds: first output of the model, unchanged.

        Calling this method zeroes and then repopulates the model's
        parameter gradients as a side effect of the backward passes.
        """
        saliency_maps = []
        _, _, h, w = input_img.size()
        preds, logits = self.model(input_img)
        # The candidate pool does not depend on the detection: build it once.
        candidates = self.target_cls if self.target_cls else set(range(self.num_class))
        # None means "no class was processed". Truth-testing the score
        # tensor itself would wrongly treat a score of exactly 0 as missing.
        score = None
        for logit, detected_cls, _ in zip(logits[0], preds[1][0], preds[2][0]):
            other_classes = list(candidates - {detected_cls})
            random.shuffle(other_classes)
            per_class = dict()
            for other in other_classes:
                score = -logit[other]
                self.model.zero_grad()
                # Back-propagate to fill self.gradients through the hook.
                # The graph is kept because later iterations reuse it.
                score.backward(retain_graph=True)
                gradients = self.gradients['value']
                activations = self.activations['value']

                # LayerCAM: positive gradients weight each location.
                weights = F.relu(gradients)
                saliency_map = (weights * activations).sum(1, keepdim=True)
                saliency_map = F.relu(saliency_map)
                saliency_map = F.interpolate(
                    saliency_map, size=(h, w), mode='bilinear', align_corners=False)
                lowest, highest = saliency_map.min(), saliency_map.max()
                # Floor the range so a constant map divides by 1e-10, not 0.
                value_range = max(highest - lowest, 1e-10)
                per_class[other] = (saliency_map - lowest).div(value_range).detach()
            saliency_maps.append(per_class)
        if score is not None:
            # One more backward pass without retain_graph lets autograd
            # release the buffers that retain_graph=True kept alive above.
            torch.autograd.backward([score], retain_graph=False)
        return saliency_maps, logits, preds

    def __call__(self, input_img):
        """Shorthand for ``forward(input_img)``."""
        return self.forward(input_img)