import torch
import torch.nn.functional as F
from torch import nn

from util import box_ops
from util.misc import (NestedTensor, get_world_size, is_dist_avail_and_initialized)

from .backbone import build_backbone

from .transformer import build_visual_encoder
from .decoder import build_vg_decoder
from pytorch_pretrained_bert.modeling import BertModel


class Highlighter(nn.Module):
    """Visual-grounding network combining an image encoder, a BERT text encoder and a decoder.

    The image branch is ``backbone -> input_proj -> trans_encoder``; the text
    branch is ``bert -> bert_proj`` (plus ``aligned_proj``, whose output is
    handed to the image branch). ``trans_decoder`` fuses both and
    ``bbox_embed`` regresses four box values per decoder output.
    """

    def __init__(self, pretrained_weights, args=None):
        """Build all sub-modules from ``args``.

        Parameters:
            pretrained_weights: checkpoint path. NOTE: it is accepted but NOT
                loaded here; call ``load_pretrained_weights`` explicitly.
            args: configuration namespace; this constructor reads
                ``bert_model``, ``bert_output_dim``, ``hidden_dim`` and
                ``bert_output_layers`` and forwards ``args`` to the builders.
        """
        super().__init__()

        # Image feature encoder (CNN + Transformer encoder)
        self.backbone = build_backbone(args)
        self.trans_encoder = build_visual_encoder(args)
        self.input_proj = nn.Conv2d(self.backbone.num_channels, self.trans_encoder.d_model, kernel_size=1)

        # Text feature encoder (BERT)
        self.bert = BertModel.from_pretrained(args.bert_model)
        self.bert_proj = nn.Linear(args.bert_output_dim, args.hidden_dim)
        self.aligned_proj = nn.Linear(args.bert_output_dim, args.hidden_dim)
        self.bert_output_layers = args.bert_output_layers
        # The pooler parameters are frozen (forward discards the pooled output).
        for v in self.bert.pooler.parameters():
            v.requires_grad_(False)

        # visual grounding
        self.trans_decoder = build_vg_decoder(args)

        hidden_dim = self.trans_encoder.d_model
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)

    def load_pretrained_weights(self, weights_path):
        """Initialise the image branch from the ``'model'`` entry of a checkpoint.

        Checkpoint keys are expected to look like ``'<prefix>.<state_dict key>'``
        with prefixes ``backbone``, ``transformer`` (loaded into
        ``trans_encoder``) and ``input_proj``. Parameters without a matching
        checkpoint entry are reported on stdout and left untouched.

        Parameters:
            weights_path: path passed to ``torch.load``.
        """
        def load_weights(module, prefix, weights):
            update_weights = {}
            for k in module.state_dict():
                prefix_k = prefix + '.' + k
                # Direct dict lookup; no need to scan a filtered key list.
                if prefix_k in weights:
                    update_weights[k] = weights[prefix_k]
                else:
                    print(f"Weights of {k} are not pre-loaded.")
            # strict=False: keys missing from update_weights keep their current values.
            module.load_state_dict(update_weights, strict=False)

        weights = torch.load(weights_path, map_location='cpu')['model']
        load_weights(self.backbone, prefix='backbone', weights=weights)
        load_weights(self.trans_encoder, prefix='transformer', weights=weights)
        load_weights(self.input_proj, prefix='input_proj', weights=weights)

    def forward(self, image, image_mask, word_id, word_mask):
        """Predict boxes for a batch of image/expression pairs.

        Parameters:
            image, image_mask: wrapped together in a ``NestedTensor`` for the backbone.
            word_id: token ids fed to BERT.
            word_mask: BERT attention mask; it is inverted with ``~`` below, so
                it must be a boolean (or integer) tensor.

        Returns:
            dict with ``'pred_boxes'`` (sigmoid output of the last decoder
            stage). In training mode it also holds ``'aux_outputs'`` (boxes of
            the earlier stages) and ``'cor_maps'`` (maps returned by the
            backbone followed by those returned by the encoder).
        """
        # Text features: average the last `bert_output_layers` BERT layer outputs.
        word_feat, _ = self.bert(word_id, token_type_ids=None, attention_mask=word_mask)
        word_feat = torch.stack(word_feat[-self.bert_output_layers:], 1).mean(1)
        aligned_word_embedding = self.aligned_proj(word_feat).permute(1, 0, 2)  # BxLxC -> LxBxC
        word_feat = self.bert_proj(word_feat)
        word_feat = word_feat.permute(1, 0, 2)  # BxLxC -> LxBxC
        # Inverted mask is what the backbone/encoder/decoder receive
        # (presumably True marks padding afterwards -- TODO confirm).
        word_mask = ~word_mask
        aligned_info = (aligned_word_embedding, word_mask)

        # Image features: only the last backbone level is used further on.
        features, pos, cor_map = self.backbone(NestedTensor(image, image_mask), aligned_info)
        src, mask = features[-1].decompose()
        assert mask is not None
        img_feat, mask, pos_embed, cor_map_tr = self.trans_encoder(self.input_proj(src), mask, pos[-1], aligned_info)

        hs = self.trans_decoder(img_feat, mask, pos_embed, word_feat, word_mask)

        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_boxes': outputs_coord[-1]}

        if self.training:
            # `+` joins the two map collections (assumed to be sequences -- TODO confirm).
            cor_map_all = cor_map + cor_map_tr
            out['aux_outputs'] = [{'pred_boxes': b} for b in outputs_coord[:-1]]
            out['cor_maps'] = [{'maps': b} for b in cor_map_all]
        return out




class VGCriterion(nn.Module):
    """Computes the training losses for Highlighter.

    Produces box-regression losses for the final and auxiliary decoder
    outputs, plus one distillation loss per correlation map found under
    ``outputs['cor_maps']``.
    """

    def __init__(self, weight_dict, loss_loc, box_xyxy):
        """Create the criterion.

        Parameters:
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
                It is only stored here; this class does not apply the weights.
            loss_loc: name of the localisation loss; must be a key of ``loss_map``
                (currently only ``'loss_boxes'``), otherwise ``KeyError`` is raised.
            box_xyxy: True if boxes are (x0, y0, x1, y1), False if (cx, cy, w, h).
        """
        super().__init__()
        self.weight_dict = weight_dict

        self.box_xyxy = box_xyxy

        self.loss_map = {'loss_boxes': self.loss_boxes}

        self.loss_loc = self.loss_map[loss_loc]

    def loss_boxes(self, outputs, target_boxes, num_pos):
        """Compute the losses related to the bounding boxes (the L1 regression loss and the GIoU loss).

        Parameters:
            outputs: dict with ``'pred_boxes'`` of shape [B, #query, 4].
            target_boxes: ground-truth boxes; one box per sample is broadcast to every query.
            num_pos: normaliser both summed losses are divided by.

        Returns:
            dict with keys ``'l1'`` and ``'giou'``.
        """
        assert 'pred_boxes' in outputs
        src_boxes = outputs['pred_boxes']  # [B, #query, 4]
        target_boxes = target_boxes[:, None].expand_as(src_boxes)

        src_boxes = src_boxes.reshape(-1, 4)  # [B*#query, 4]
        target_boxes = target_boxes.reshape(-1, 4)  # [B*#query, 4]

        losses = {}
        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
        losses['l1'] = loss_bbox.sum() / num_pos

        # The pairwise GIoU helper is given corner-format boxes.
        if not self.box_xyxy:
            src_boxes = box_ops.box_cxcywh_to_xyxy(src_boxes)
            target_boxes = box_ops.box_cxcywh_to_xyxy(target_boxes)
        loss_giou = 1 - box_ops.box_pair_giou(src_boxes, target_boxes)
        losses['giou'] = (loss_giou[:, None]).sum() / num_pos
        return losses

    def loss_distillation(self, aux_maps_outputs, target_map, num_pos):
        """Compute the KL divergence between a target map and a predicted correlation map.

        Both maps are flattened per sample, multiplied by 5 and turned into
        distributions with a softmax; the loss is KL(target || prediction)
        summed over positions, averaged over the batch and divided by ``num_pos``.

        Parameters:
            aux_maps_outputs: dict with ``'maps'``, a tensor whose first dimension is the batch.
            target_map: tensor with the same number of elements per sample as the map.
            num_pos: normaliser applied after the batch mean.

        Returns:
            dict with the single key ``'kld_distillation'``.
        """
        assert 'maps' in aux_maps_outputs

        losses = {}
        aux_maps = aux_maps_outputs['maps']
        b = aux_maps.size(0)
        aux_maps = aux_maps.reshape(b, -1)
        target_map = target_map.reshape(b, -1)

        aux_maps = torch.softmax(aux_maps * 5, dim=1)
        target_map = torch.softmax(target_map * 5, dim=1)
        # The 1e-12 keeps log() finite when a probability underflows to zero.
        loss_distillation = target_map * torch.log(target_map + 1e-12) - target_map * torch.log(aux_maps + 1e-12)
        loss_distillation = loss_distillation.sum(1).mean()

        losses['kld_distillation'] = loss_distillation.sum() / num_pos
        return losses

    def _box_target_map(self, maps, gt_boxes):
        """Rasterise normalised ground-truth boxes into a 0/1 map shaped like ``maps``.

        Parameters:
            maps: tensor of shape [B, C, H, W]; only its shape/dtype/device are used.
            gt_boxes: [B, 4] boxes in [0, 1] coordinates, in the format
                selected by ``self.box_xyxy``.

        Returns:
            tensor like ``maps`` that is 1 inside each sample's box and 0 elsewhere.
        """
        map_h, map_w = maps.shape[-2:]
        # x values scale with the map width and y values with its height.
        scaled = gt_boxes * gt_boxes.new_tensor([map_w, map_h, map_w, map_h])
        if self.box_xyxy:
            left, top = scaled[:, 0], scaled[:, 1]
            width, height = scaled[:, 2] - left, scaled[:, 3] - top
        else:
            width, height = scaled[:, 2], scaled[:, 3]
            left, top = scaled[:, 0] - 0.5 * width, scaled[:, 1] - 0.5 * height
        # Truncate to integer cells once and move to Python ints for slicing.
        cells = torch.stack([left, top, width, height], dim=1).long().tolist()

        target = torch.zeros_like(maps.detach())
        for b, (x0, y0, box_w, box_h) in enumerate(cells):
            x1, y1 = x0 + box_w, y0 + box_h
            # Clamp at 0 so a box reaching past the top/left edge is clipped
            # instead of being read as a negative (wrap-around) index.
            target[b, :, max(y0, 0):max(y1, 0), max(x0, 0):max(x1, 0)] = 1
        return target

    def forward(self, outputs, targets):
        """Compute all losses for one batch.

        Parameters:
            outputs: dict with ``'pred_boxes'`` ([B, Q, 4]) and optionally
                ``'aux_outputs'`` and ``'cor_maps'`` (lists of dicts).
            targets: dict with ``'bbox'``, the ground-truth boxes.

        Returns:
            dict of losses: the localisation losses, the same losses
            suffixed ``_<i>`` for each auxiliary output, and
            ``kld_distillation_<i>`` for each correlation map.
        """
        gt_boxes = targets['bbox']
        pred_boxes = outputs['pred_boxes']

        losses = {}
        B, Q, _ = pred_boxes.shape
        num_pos = avg_across_gpus(pred_boxes.new_tensor(B*Q))
        loss = self.loss_loc(outputs, gt_boxes, num_pos)
        losses.update(loss)

        # Apply the loss function to the outputs from all the stages
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                l_dict = self.loss_loc(aux_outputs, gt_boxes, num_pos)
                l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                losses.update(l_dict)

        # Supervise every correlation map with the rasterised ground-truth box.
        if 'cor_maps' in outputs:
            for i, aux_maps_outputs in enumerate(outputs['cor_maps']):
                target_map = self._box_target_map(aux_maps_outputs['maps'], gt_boxes)
                l_dict = self.loss_distillation(aux_maps_outputs, target_map, num_pos)
                l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                losses.update(l_dict)

        return losses


class PostProcess(nn.Module):
    """Turns normalised box predictions into pixel boxes of the original image."""

    def __init__(self, box_xyxy=False):
        """Remember whether predictions are already in (x0, y0, x1, y1) form."""
        super().__init__()
        self.bbox_xyxy = box_xyxy

    @staticmethod
    def _per_coordinate(hw_pairs):
        """Expand rows of (h, w) into rows of (w, h, w, h), matching xyxy order."""
        h, w = hw_pairs.unbind(1)
        return torch.stack([w, h, w, h], dim=1)

    @torch.no_grad()
    def forward(self, outputs, target_dict):
        """Rescale ``outputs['pred_boxes']`` to original-image coordinates.

        ``target_dict`` must provide ``'size'``, ``'ratio'`` and
        ``'orig_size'`` (rows are read as (h, w) pairs; ``'orig_size'`` may be
        None) and may provide ``'dxdy'``, an offset subtracted from both
        corners. Returns the resulting xyxy boxes.
        """
        rsz_sizes = target_dict['size']
        ratios = target_dict['ratio']
        orig_sizes = target_dict['orig_size']
        dxdy = target_dict['dxdy'] if 'dxdy' in target_dict else None

        boxes = outputs['pred_boxes']

        assert len(boxes) == len(rsz_sizes)
        assert rsz_sizes.shape[1] == 2

        # Drop the query dimension when it has size 1.
        boxes = boxes.squeeze(1)

        # Convert to absolute coordinates in the original image
        if not self.bbox_xyxy:
            boxes = box_ops.box_cxcywh_to_xyxy(boxes)

        # Normalised -> pixels of the resized image.
        boxes = boxes * self._per_coordinate(rsz_sizes)

        # Remove the offset (if any), then clip at the top/left border.
        if dxdy is not None:
            boxes = boxes - torch.cat([dxdy, dxdy], dim=1)
        boxes = boxes.clamp(min=0)

        # Undo the resize ratio.
        boxes = boxes / self._per_coordinate(ratios)

        # Clip at the bottom/right border of the original image.
        if orig_sizes is not None:
            boxes = torch.min(boxes, self._per_coordinate(orig_sizes))

        return boxes


def avg_across_gpus(v, min=1):
    """Return ``v`` averaged over the world size, floored at ``min``, as a Python number.

    When distributed mode is active, ``v`` is all-reduced in place first, so
    the caller's tensor is modified.
    """
    if is_dist_avail_and_initialized():
        torch.distributed.all_reduce(v)
    per_process = v.float() / get_world_size()
    return per_process.clamp(min=min).item()


class MLP(nn.Module):
    """Stack of linear layers with ReLU between them (a plain feed-forward network)."""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        """Create ``num_layers`` linear layers: input -> hidden ... hidden -> output."""
        super().__init__()
        self.num_layers = num_layers
        # Consecutive entries of `widths` are the (in, out) sizes of each layer.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )

    def forward(self, x):
        """Apply every layer in order; all but the last are followed by ReLU."""
        final_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < final_index:
                x = F.relu(x)
        return x




def build_vgmodel(args):
    """Build the model, its criterion and the box post-processor from ``args``.

    Returns:
        ``(model, criterion, postprocessor)``. Only the criterion is moved to
        ``args.device`` here.
    """
    device = torch.device(args.device)

    model = Highlighter(pretrained_weights=args.load_weights_path, args=args)

    # Base loss weights; entries in `other_loss_coefs` override or extend them.
    weight_dict = {
        'loss_cls': 1,
        'l1': args.bbox_loss_coef,
        'giou': args.giou_loss_coef,
    }
    weight_dict.update(args.other_loss_coefs)

    if args.aux_loss:
        # Repeat every weight for each auxiliary decoder stage as '<name>_<i>'.
        base_weights = list(weight_dict.items())
        for stage in range(args.dec_layers - 1):
            for name, weight in base_weights:
                weight_dict[name + f'_{stage}'] = weight

    # Correlation-map weights: indices 0-2 are the cnn maps, 3-8 the tr maps.
    for map_index in range(9):
        weight_dict[f'kld_distillation_{map_index}'] = 1.0

    criterion = VGCriterion(weight_dict=weight_dict, loss_loc=args.loss_loc, box_xyxy=args.box_xyxy)
    criterion.to(device)

    postprocessor = PostProcess(args.box_xyxy)

    return model, criterion, postprocessor
