import torch
import math
import torch.nn.functional as F
from torch.autograd import Variable
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn

import torch



def _first_tensor_device(maps):
    """Return the device of the first tensor in a (possibly nested) list, or None."""
    if torch.is_tensor(maps):
        return maps.device
    if isinstance(maps, (list, tuple)):
        for entry in maps:
            device = _first_tensor_device(entry)
            if device is not None:
                return device
    return None


def _resize_token_maps(attn_map, size, max_res=None):
    """Turn a (batch, res*res, tokens) map into per-token (tokens, 1, size, size) maps.

    The map is averaged over its first dimension and bilinearly resized.
    Returns None when ``max_res`` is given and the map's side length exceeds it.
    """
    batch, pixels, tokens = attn_map.shape
    res = int(math.sqrt(pixels))
    if max_res is not None and res > max_res:
        return None
    per_token = (
        attn_map.reshape(batch, res, res, tokens)
        .permute(3, 0, 1, 2)
        .mean(dim=1, keepdim=True)
    )
    return F.interpolate(per_token, size, mode='bilinear')


def compute_zest(attn_maps_mid, attn_maps_up, attn_maps_down, attn_self, bboxes, object_positions, iter=None, attn_weight=None):
    """Compute a BCE loss between per-object attention maps and box masks.

    All attention maps are averaged over their first dimension, resized to
    64x64 and averaged together. For every object the maps of its tokens are
    summed; both the summed map and its min-max normalised version are
    compared against a binary mask of the object's boxes with binary
    cross-entropy. The per-object losses are averaged.

    Args:
        attn_maps_mid: list of tensors shaped (b, res*res, tokens).
        attn_maps_up: list of lists of tensors shaped (b, res*res, tokens);
            maps with res > 64 are ignored.
        attn_maps_down: same layout and filtering as ``attn_maps_up``.
        attn_self: unused; kept for interface compatibility.
        bboxes: per object, a list of boxes ``(x_min, y_min, x_max, y_max)``
            that are scaled by 64 to index the mask (so presumably
            normalised to [0, 1] -- confirm against the caller).
        object_positions: per object, the token indices (along the last
            dimension of the attention maps) belonging to that object.
            Index 0 is dropped from the maps before lookup, so valid
            indices start at 1.
        iter: unused; kept for interface compatibility.
        attn_weight: unused; kept for interface compatibility.

    Returns:
        ``(loss, 0, 0)`` where ``loss`` is a scalar tensor. NOTE: when
        ``bboxes`` is empty a bare zero scalar tensor is returned instead of
        a tuple; this asymmetry is preserved for existing callers.

    Raises:
        ValueError: if an object has no token positions (the previous
            object's map would otherwise be silently reused).
    """
    object_number = len(bboxes)

    if object_number == 0:
        # Prefer the device of the attention maps; fall back to the historic
        # "CUDA if available" choice when no tensor is at hand.
        device = _first_tensor_device([attn_maps_mid, attn_maps_up, attn_maps_down])
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.zeros((), dtype=torch.float32, device=device)

    H = W = 64

    # Collect every usable map at a common 64x64 resolution (up, down, mid).
    resized = []
    for group in (attn_maps_up, attn_maps_down):
        for sub_list in group:
            for item in sub_list:
                maps = _resize_token_maps(item, H, max_res=64)
                if maps is not None:
                    resized.append(maps)
    for item in attn_maps_mid:
        resized.append(_resize_token_maps(item, H))

    # (tokens, n_maps, H, W) -> mean over maps -> (H, W, tokens); token 0 is
    # dropped (presumably the start-of-text token -- confirm with the caller).
    attn_raw = torch.cat(resized, dim=1).mean(dim=1).permute(1, 2, 0)[:, :, 1:]

    obj_loss = 0
    for obj_idx in range(object_number):
        att_map_obj_raw = None
        for obj_position in object_positions[obj_idx]:
            token_map = attn_raw[:, :, obj_position - 1]
            if att_map_obj_raw is None:
                att_map_obj_raw = token_map
            else:
                att_map_obj_raw = att_map_obj_raw + token_map

        if att_map_obj_raw is None:
            raise ValueError(
                "object %d has no token positions in object_positions" % obj_idx
            )

        # Min-max normalise; clamp the range so a spatially constant map
        # yields zeros instead of NaN (0 / 0).
        lowest = att_map_obj_raw.min()
        value_range = att_map_obj_raw.max() - lowest
        tiny = torch.finfo(att_map_obj_raw.dtype).tiny
        a_norm = (att_map_obj_raw - lowest) / value_range.clamp_min(tiny)

        # Build the target mask on the same device/dtype as the attention map
        # so CPU tensors work even when CUDA is available.
        mask = torch.zeros(
            (H, W), device=att_map_obj_raw.device, dtype=att_map_obj_raw.dtype
        )
        for obj_box in bboxes[obj_idx]:
            x_min = int(obj_box[0] * W)
            y_min = int(obj_box[1] * H)
            x_max = int(obj_box[2] * W)
            y_max = int(obj_box[3] * H)
            mask[y_min:y_max, x_min:x_max] = 1

        target = mask.reshape(-1)
        obj_loss = (
            obj_loss
            + F.binary_cross_entropy(att_map_obj_raw.reshape(-1), target)
            + F.binary_cross_entropy(a_norm.reshape(-1), target)
        )

    loss = obj_loss / object_number

    return loss, 0, 0