import time
import torch
import torch.nn.functional as F


def find_yolo_layer(model, layer_name):
    """Return the sub-module of ``model.model`` addressed by ``layer_name``.

    ``layer_name`` is a path of child-module keys joined with underscores, e.g.
    ``'1_1'`` resolves to ``model.model._modules['1']._modules['1']``. Because
    ``'_'`` is the separator, a child whose own key contains an underscore
    cannot be addressed this way. An unknown key raises ``KeyError``.
    """
    root_key, *child_keys = layer_name.split('_')
    module = model.model._modules[root_key]
    for key in child_keys:
        module = module._modules[key]
    return module

class YOLOV3Attention:
    """Grad-CAM style saliency maps and an attention loss for selected detections.

    Forward and full-backward hooks are registered on every layer listed in
    ``target_layers`` (resolved with ``find_yolo_layer``). For each detection
    that passes the target filter (class name and box-center region), the
    detection confidence is back-propagated. For every hooked layer that
    received a non-zero gradient, a normalized saliency map is produced, and
    ``sum(|gradient * activation|)`` is accumulated into a differentiable loss.

    The hooks stay on the model until :meth:`remove_hooks` is called.
    """

    def __init__(self, model, target_layers, device, img_size=(640, 640),
                 target_class='car', target_region=(280, 280, 360, 360)):
        """
        Args:
            model: object whose ``model`` attribute is searched by
                ``find_yolo_layer``. Calling it must return
                ``(preds, logits, train_out, feature)`` as unpacked in
                :meth:`forward`.
            target_layers: underscore-separated module paths
                (see ``find_yolo_layer``) to hook.
            device: device for the warm-up input and the accumulated loss.
            img_size: (height, width) of the dummy warm-up image.
            target_class: only detections whose class name equals this value
                are analysed; ``None`` disables the class filter.
            target_region: ``(x_min, y_min, x_max, y_max)`` inclusive bounds
                on the box center; ``None`` disables the region filter.
        """
        self.device = device
        self.model = model
        self.gradients = []
        self.activations = []
        self.saliency_maps = []
        self.inf_result = []
        self.target_layer = dict()
        self.target_class = target_class
        self.target_region = target_region
        # Keep the handles so the hooks can be detached later (see remove_hooks).
        self._hook_handles = []

        for layer_name in target_layers:
            layer = find_yolo_layer(self.model, layer_name)
            self.target_layer[layer_name] = layer
            self._hook_handles.append(layer.register_forward_hook(self._save_activation))
            self._hook_handles.append(layer.register_full_backward_hook(self._save_gradient))

        # Warm-up pass with a dummy image.
        self.model(torch.zeros(1, 3, *img_size, device=self.device))
        # forward() clears these anyway; drop them now so the warm-up graph is not kept alive.
        self.activations.clear()

    def _save_activation(self, module, inputs, output):
        """Forward hook: record the hooked layer's output."""
        self.activations.append(output)
        return None

    def _save_gradient(self, module, grad_input, grad_output):
        """Full backward hook: record the gradient w.r.t. the layer's first output."""
        self.gradients.append(grad_output[0])
        return None

    def remove_hooks(self):
        """Detach every hook this object registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def _is_target(self, box, cls_name):
        """Return True when a detection matches the configured class and center region."""
        # NOTE(review): assumes boxes are (x1, y1, x2, y2) so this is the box center.
        cx, cy = (box[0] + box[2]) // 2, (box[1] + box[3]) // 2
        if self.target_class is not None and cls_name != self.target_class:
            return False
        if self.target_region is None:
            return True
        x_min, y_min, x_max, y_max = self.target_region
        return bool(x_min <= cx <= x_max and y_min <= cy <= y_max)

    @staticmethod
    def _saliency_map(gradients, activations, size):
        """Return one layer's Grad-CAM map, upsampled to ``size`` and min-max normalized.

        Returns a detached (b, 1, *size) tensor with values in [0, 1]; a
        constant map (zero value range) is returned as all zeros.
        """
        b, k, _, _ = gradients.size()
        # Channel weights are the spatial mean of each channel's gradient.
        weights = gradients.reshape(b, k, -1).mean(2).reshape(b, k, 1, 1)
        # Activations are detached: the map is a visualisation, not part of the loss graph.
        saliency_map = (weights * activations.detach()).sum(1, keepdim=True)
        saliency_map = F.relu(saliency_map)
        saliency_map = F.interpolate(saliency_map, size=size, mode='bilinear', align_corners=False)
        lo, hi = saliency_map.min(), saliency_map.max()
        value_range = hi - lo
        # Dividing a constant map by its zero range would give NaNs. Divide by 1
        # instead: (map - lo) is already all zeros in that case.
        value_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
        return ((saliency_map - lo) / value_range).detach()

    def forward(self, input_img, class_idx=True):
        """Run the model and build saliency maps for every targeted detection.

        Args:
            input_img: 4-D image batch (unpacked as ``b, c, h, w``).
            class_idx: unused; kept for backward compatibility.

        Returns:
            ``(saliency_maps, inf_result, train_out, feature, loss)``, where:

            * ``saliency_maps`` is a list of (b, 1, h, w) maps in [0, 1].
            * ``inf_result`` holds a matching ``[box, cls_name, conf,
              layer_index]`` entry per map.
            * ``train_out`` and ``feature`` are passed through from the model.
            * ``loss`` is the summed ``|gradient * activation|`` over all
              contributing layers and detections.

            Both lists are this object's own lists and are cleared on the
            next call.
        """
        self.saliency_maps.clear()
        self.inf_result.clear()
        _, _, h, w = input_img.size()
        loss = torch.zeros((), dtype=torch.float32, device=self.device)

        self.activations.clear()
        preds, logits, train_out, feature = self.model(input_img)
        # NOTE(review): assumes preds[0..3][0] are the per-detection boxes, classes,
        # class names and confidences of the first image — confirm against the model.
        detections = zip(logits[0], preds[0][0], preds[1][0], preds[2][0], preds[3][0])
        for _logit, box, _cls, cls_name, conf in detections:
            # Keep only the targeted (attacked) objects.
            if not self._is_target(box, cls_name):
                continue

            self.model.zero_grad()
            self.gradients.clear()
            conf.backward(retain_graph=True)

            # Backward hooks fire in reverse layer order, so reverse the gradients
            # to pair them with the activations (recorded in forward order).
            # Assumes each hooked layer runs exactly once per forward pass.
            pairs = zip(reversed(self.gradients), self.activations)
            for layer_index, (gradients, activations) in enumerate(pairs):
                # Skip layers that received no gradient from this detection.
                if not (gradients != 0).any():
                    continue
                loss = loss + torch.sum(torch.abs(gradients * activations)).to(self.device)
                self.saliency_maps.append(self._saliency_map(gradients, activations, (h, w)))
                self.inf_result.append([box, cls_name, conf, layer_index])

        return self.saliency_maps, self.inf_result, train_out, feature, loss

    def __call__(self, input_img):
        return self.forward(input_img)

 