import torch
from torch import nn
from criteria.parse_related_loss.unet import unet
from models.third_party.BiSeNet import FaceParser
from mapper.attribute_list import ATTRIBUTE_LIST

class BackgroundLoss(nn.Module):
    """Masked L2 loss that penalizes changes outside an attribute's edit regions.

    Both the source image ``x`` and the edited image ``x_hat`` are segmented
    by a face parser. Pixels whose predicted label is not in the selected
    attribute's region list (in both images) form ``bg_mask``; the subset of
    those that is also not label 0 (background) forms ``add_mask``. The loss
    is the mean squared difference under ``bg_mask`` plus the mean squared
    difference under ``add_mask``, so non-background pixels outside the edit
    regions contribute to both terms.

    ``FaceParser`` is tried first with the attribute's ``'regions'`` labels;
    if that lookup or the parser raises, the UNet loaded from
    ``opts.parsenet_weights`` is used with the attribute's ``'regions2'``.

    NOTE(review): ``parsenet2.eval()`` is only called in ``__init__``; a later
    ``.train()`` on a parent module would switch it back to training mode.
    """

    def __init__(self, opts):
        """Load both face parsers and the per-element MSE criterion.

        Args:
            opts: options object; ``parsenet_weights`` (UNet state-dict path)
                and ``face_parser_ckpt`` (FaceParser checkpoint path) are read
                here, and the object is kept as ``self.opts``.
        """
        super().__init__()
        print('Loading UNet for Background Loss')
        self.opts = opts
        self.parsenet2 = unet()
        # load_state_dict copies values into the module's existing parameters,
        # so loading the checkpoint onto CPU first avoids a temporary GPU copy
        # and lets GPU-saved weights load on CPU-only machines.
        self.parsenet2.load_state_dict(
            torch.load(opts.parsenet_weights, map_location='cpu'))
        self.parsenet2.eval()

        self.parsenet = FaceParser(model_path=opts.face_parser_ckpt)
        # reduction='none' + explicit .mean() in forward().
        self.bg_mask_l2_loss = torch.nn.MSELoss(reduction='none')

    @staticmethod
    def _label_map(class_scores):
        """Return the argmax label over dim 1, keeping a singleton dim 1."""
        return torch.max(class_scores, 1, keepdim=True)[1]

    def gen_bg_mask(self, input_image, selected_attributes):
        """Build the "should stay unchanged" masks for ``input_image``.

        Args:
            input_image: image batch given to the face parser (presumably
                ``(B, C, H, W)`` — TODO confirm against callers).
            selected_attributes: sequence of ``ATTRIBUTE_LIST`` keys; only the
                first entry is used, for the whole batch.

        Returns:
            ``(mask_512, mask_add)``, both shaped like the parser's label map
            with a singleton dim 1:

            * ``mask_512`` (int64): 1 where the predicted label is not in the
              attribute's region list, else 0.
            * ``mask_add`` (float): ``mask_512`` minus the background mask
              (label 0), i.e. 1 for non-background pixels outside the regions,
              0 elsewhere, and -1 at background pixels if label 0 is itself
              one of the regions.
        """
        attr = ATTRIBUTE_LIST[selected_attributes[0]]

        # The masks come from argmax indices and are never differentiated,
        # so skip building an autograd graph through the parser networks.
        with torch.no_grad():
            try:
                regions = attr['regions']
                labels = self._label_map(self.parsenet.batch_run(
                    input_image, pre_normalize=True, image_repr=False,
                    compact_mask=True))
            except Exception:
                # Deliberate fallback: use the UNet parser and its own label
                # set when FaceParser regions are missing or FaceParser fails.
                regions = attr['regions2']
                labels = self._label_map(self.parsenet2(input_image))

        # Only label 0 is treated as background (clothing is not included).
        mask_bg = (labels == 0).float()

        in_regions = torch.zeros_like(labels, dtype=torch.bool)
        for region in regions:
            in_regions |= labels == region
        # Keep the label map's integer dtype for the returned mask.
        mask_512 = (~in_regions).to(labels.dtype)
        mask_add = mask_512 - mask_bg
        return mask_512, mask_add

    def forward(self, x, x_hat, selected_attributes):
        """Return the background-preservation loss between ``x`` and ``x_hat``.

        Args:
            x: original image batch.
            x_hat: edited image batch; gradients reach it through the masked
                MSE terms, not through the (non-differentiable) masks.
            selected_attributes: forwarded to ``gen_bg_mask``.

        Returns:
            Scalar tensor: squared error under ``bg_mask`` plus squared error
            under ``add_mask``, each averaged over all elements (masked-out
            elements count as zeros in the average).
        """
        x_bg_mask, x_add_mask = self.gen_bg_mask(x, selected_attributes)
        x_hat_bg_mask, x_hat_add_mask = self.gen_bg_mask(x_hat, selected_attributes)
        # A pixel is kept only if it is kept (value 1) in both images' masks.
        bg_mask = ((x_bg_mask == 1) & (x_hat_bg_mask == 1)).float()
        add_mask = ((x_add_mask == 1) & (x_hat_add_mask == 1)).float()
        loss_bg = self.bg_mask_l2_loss(x * bg_mask, x_hat * bg_mask).mean()
        loss_add = self.bg_mask_l2_loss(x * add_mask, x_hat * add_mask).mean()
        return loss_bg + loss_add
