# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
HETR (DETR-style line detector) model, loss criterion, helper modules and builder.
"""
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import copy

import matplotlib.pyplot as plt

from .backbone import build_backbone
from .ranking_losses import RankSort
from .rank_loss import RankLoss
from .sort_loss import SortLoss
from .position_encoding import build_position_encoding
from .postprocess import PostProcess

from .transformer import build_transformer
from .deformable_transformer import build_deformable_transformer


class HETR(nn.Module):
    """ DETR-style line detector (HETR).

    A backbone produces multi-level feature maps. Each level is projected to the
    transformer width, fed to shared 1x1 line / junction heat-map heads, and passed
    (with a position encoding) through the deformable transformer encoder. The encoder
    returns sequences of line classification logits and line coordinates whose last
    entry is the final prediction. When ``phaseII`` is set, the transformer decoder
    produces a second set of predictions from the encoder outputs.
    """
    def __init__(self, backbone, pe, transformer, num_queries, aux_loss=False, args=None):
        """ Initializes the model.
        Parameters:
            backbone: torch module of the backbone to be used; must expose ``num_channels``
                      (one entry per returned feature level). See backbone.py
            pe: position-encoding module, called on every projected feature map.
            transformer: torch module of the transformer architecture; must expose ``d_model``.
                         See transformer.py / deformable_transformer.py
            num_queries: number of line queries (stored on the module).
            aux_loss: True if auxiliary losses on all but the last transformer output are used.
            args: config namespace. ``args.phaseII`` enables the phase-II decoder; when ``args``
                  is None or has no ``phaseII`` attribute, phase II is disabled.
        """
        super().__init__()
        self.phaseII = getattr(args, 'phaseII', False)
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = self.transformer.d_model
        self.hidden_dim = hidden_dim

        # 1x1 heads, shared across feature levels, producing one heat-map logit channel.
        self.junction_head = nn.Conv2d(hidden_dim, 1, 1)
        self.line_head = nn.Conv2d(hidden_dim, 1, 1)

        # One 1x1 projection per backbone level, mapping that level's channels to hidden_dim.
        self.input_proj = nn.ModuleList(
            [nn.Conv2d(c, hidden_dim, kernel_size=1) for c in backbone.num_channels])

        self.register_buffer('cls_mask', torch.zeros((1, 1), dtype=bool))
        self.backbone = backbone
        self.aux_loss = aux_loss
        self.encoder_pos = pe

        # NOTE(review): the embeddings below (and ``cls_mask`` above) are not used by this
        # class's forward path; presumably kept so existing checkpoints still load.
        self.class_embed = nn.Linear(self.transformer.d_model, 1)
        self.lines_embed = MLP(self.transformer.d_model, self.transformer.d_model, 4, 3)

        self.class_embed_p2 = nn.Linear(self.transformer.d_model, 1)
        self.lines_embed_p2 = MLP(self.transformer.d_model, self.transformer.d_model, 4, 3)

    def forward(self, imgs, score_lines=None):
        """Run backbone + heads + transformer on ``imgs``.

        ``score_lines`` is forwarded to :meth:`hetr`, which currently ignores it.
        Returns the prediction dict built by :meth:`hetr`.
        """
        features = self.feature_extract(imgs)
        return self.hetr(features, score_lines)

    def feature_extract(self, imgs):
        """Return the backbone output for ``imgs``.

        Presumably a list of feature maps, one per level (see ``backbone.num_channels``).
        """
        return self.backbone(imgs)

    def hetr(self, features, score_lines):
        """Turn backbone features into line predictions.

        Args:
            features: list of backbone feature maps, each (B, C_i, H_i, W_i).
            score_lines: unused; kept for interface compatibility.

        Returns:
            dict with 'pred_cls' / 'pred_lines' (last transformer output) and per-level
            'pred_line_map' / 'pred_junction_map' logits (channel dim squeezed). It also
            holds 'aux_outputs' when ``aux_loss`` is set, and the phase-II entries when
            ``phaseII`` is set.

        Raises:
            NotImplementedError: if the transformer is not a ``DeformableTransformer``.
                The plain DETR ``Transformer`` path never produced class/line outputs,
                so it previously crashed with ``UnboundLocalError``.
        """
        if self.transformer.__class__.__name__ != "DeformableTransformer":
            raise NotImplementedError(
                "HETR requires a DeformableTransformer (forward_encoder); got "
                f"{self.transformer.__class__.__name__}")

        features = [self.input_proj[i](f) for i, f in enumerate(features)]
        pred_line_map = [self.line_head(f).squeeze(1) for f in features]
        pred_junction_map = [self.junction_head(f).squeeze(1) for f in features]

        srcs = []
        masks = []
        pos_embeds = []
        for f in features:
            # Conv2d outputs are NCHW, so the mask must be (B, H, W). Unpacking the shape
            # as (bs, c, w, h) would transpose it for non-square maps.
            bs, _, h, w = f.shape
            # All-False mask: every position is treated as valid (no padding mask is built).
            masks.append(torch.zeros([bs, h, w], device=f.device, dtype=torch.bool))
            pos_embeds.append(self.encoder_pos(f))
            srcs.append(f)

        outputs_class, outputs_lines, memory = self.transformer.forward_encoder(srcs, masks, pos_embeds)

        if self.phaseII:
            outputs_class_p2, outputs_lines_p2, outputs_lines_p1, inter_references = \
                self.transformer.forward_decoder(outputs_class, outputs_lines, srcs, masks, pos_embeds, memory)

        out = {'pred_cls': outputs_class[-1], 'pred_lines': outputs_lines[-1],
               'pred_line_map': pred_line_map, 'pred_junction_map': pred_junction_map}

        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_lines)

        if self.phaseII:
            out['inter_references'] = inter_references
            out['pred_cls_p2'] = outputs_class_p2[-1]
            # Author's note: phase II only re-ranks, it does not re-predict lines.
            out['pred_lines_p2'] = outputs_lines_p2[-1]
            out['pred_lines_p1'] = outputs_lines_p1
            if self.aux_loss:
                out['aux_outputs_p2'] = self._set_aux_loss(outputs_class_p2, outputs_lines_p2)

        return out

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_lines):
        """Pack every output except the last into a list of {'pred_cls', 'pred_lines'} dicts."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_cls': a, 'pred_lines': b}
                for a, b in zip(outputs_class[:-1], outputs_lines[:-1])]


class SetCriterion(nn.Module):
    """Loss computation for HETR-style outputs.

    ``forward`` returns unweighted losses:
      * BCE classification and L1 line regression (only on slots with a target line);
      * pos-weighted BCE on the per-level junction / line heat maps;
      * ``_<i>``-suffixed copies of the first two for each entry of 'aux_outputs';
      * when ``phaseII`` is set, rank+sort classification and L1 line losses against
        soft targets from :meth:`get_contiguous_labels`.
    ``weight_dict`` is stored for the caller; it is not applied here.
    """

    def __init__(self, weight_dict, cls_weight, score_weight, line_weight, endpoint_weight, phaseII):

        super().__init__()
        # NOTE(review): cls_weight / score_weight and the modules rank_loss, class_lossII,
        # line_lossII, angle_loss and dis_loss are not used by this class's forward.
        self.register_buffer("cls_weight", torch.Tensor([cls_weight]))
        self.register_buffer("score_weight", torch.Tensor([score_weight]))
        # Per-level positive weights for the heat-map BCE (indexed by feature level).
        self.register_buffer("line_weight", torch.Tensor(line_weight))
        self.register_buffer("endpoint_weight", torch.Tensor(endpoint_weight))
        self.cls_loss = nn.BCEWithLogitsLoss()
        self.line_loss = nn.L1Loss()
        self.class_lossII = nn.BCEWithLogitsLoss()
        self.line_lossII = nn.L1Loss()
        self.angle_loss = nn.MSELoss()
        self.dis_loss = nn.MSELoss()
        self.rank_loss = RankSort()

        self.weight_dict = weight_dict
        self.phaseII = phaseII

    def get_map_loss(self, pred_junction_map, pred_line_map, label_junction_map, label_line_map):
        """Sum over levels of BCE-with-logits on the junction and line heat maps.

        Level ``i`` uses ``endpoint_weight[i]`` / ``line_weight[i]`` as ``pos_weight``.
        Returns ``(loss_junction_map, loss_line_map)``.
        """
        loss_junction_map = sum(
            [F.binary_cross_entropy_with_logits(pj, lj, pos_weight=self.endpoint_weight[i]) for i, (pj, lj) in
             enumerate(zip(pred_junction_map, label_junction_map))])
        loss_line_map = sum([F.binary_cross_entropy_with_logits(pl, ll, pos_weight=self.line_weight[i]) for i, (pl, ll) in
                             enumerate(zip(pred_line_map, label_line_map))])
        return loss_junction_map, loss_line_map

    def _masked_line_loss(self, pred_lines, target_lines):
        """L1 loss restricted to slots whose target coordinates sum to > 0.

        When no slot qualifies, ``L1Loss`` would average an empty tensor and return NaN.
        In that case a zero that stays attached to ``pred_lines``' graph is returned instead.
        """
        keep = target_lines.sum(-1) > 0
        if not keep.any():
            return pred_lines.sum() * 0.
        return self.line_loss(pred_lines[keep], target_lines[keep])

    def get_contiguous_labels(self, pred_lines, label_lines, scale=128., dist_thresh=10.):
        """Build soft phase-II targets by matching each predicted line to its nearest GT line.

        Args:
            pred_lines: (B, Q, 4) predicted lines, each read as two (x, y) endpoints.
            label_lines: (B, N, 4) target lines; only rows whose coordinates sum to > 0
                are treated as GT lines.
            scale: factor applied to coordinate distances before thresholding
                (default 128, the value previously hard-coded).
            dist_thresh: scaled squared distance below which a prediction is matched
                (default 10, the value previously hard-coded).

        Returns:
            out_scores: (B, Q). Each entry is ``clip(1 - d * scale**2 / dist_thresh, 0)``,
                where ``d`` is the endpoint-order-invariant sum of squared endpoint
                distances to the nearest GT line. It is 0 for images without GT lines.
            out_lines: (B, Q, 4). The nearest GT line for matched predictions, zeros otherwise.
        """
        out_scores = torch.zeros_like(pred_lines[:, :, 0])
        out_lines = torch.zeros_like(pred_lines)
        inv_thresh = 1. / dist_thresh
        scale_sq = scale ** 2
        for i, (pred_line, label_line) in enumerate(zip(pred_lines, label_lines)):
            label_line = label_line[label_line.sum(-1) > 0]
            if label_line.shape[0] == 0:
                # No GT lines: torch.min over an empty dim would raise, and every
                # prediction is unmatched anyway, so keep the zero targets.
                continue
            pred_line = pred_line.reshape(-1, 2, 2)
            label_line = label_line.reshape(-1, 2, 2)
            # diff[q, n, a, b]: squared distance between endpoint a of prediction q
            # and endpoint b of GT line n.
            diff = ((pred_line[:, None, :, None] - label_line[:, None]) ** 2).sum(-1)
            # Endpoint order is arbitrary, so take the cheaper of the two pairings.
            diff = torch.minimum(
                diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
            )
            value, index = torch.min(diff, 1)  # nearest GT line per prediction
            matched = value < dist_thresh / scale_sq
            out_lines[i][matched] = label_line[index[matched]].flatten(-2, -1)
            # NOTE(review): if pred_lines requires grad, these scores stay attached to the
            # graph; confirm RankLoss/SortLoss do not backprop through their targets.
            out_scores[i] = torch.clip(1 - inv_thresh * value * scale_sq, 0)
        return out_scores, out_lines

    def forward(self, outputs, label_cls, label_lines, label_junction_map, label_line_map, label_cls64=None, label_lines64=None):
        """Compute all unweighted losses.

        Args:
            outputs: model output dict with 'pred_cls', 'pred_lines', 'pred_junction_map',
                'pred_line_map' and optionally 'aux_outputs'. When ``phaseII`` is set it
                must also hold 'pred_lines_p1', 'pred_cls_p2' and 'pred_lines_p2'.
            label_cls: classification targets for BCE-with-logits (same shape as 'pred_cls').
            label_lines: target lines; slots whose coordinates sum to <= 0 carry no line.
            label_junction_map, label_line_map: per-level heat-map targets.
            label_cls64, label_lines64: unused.

        Returns:
            dict of loss tensors. Auxiliary copies use the suffix ``_<i>``.
        """
        losses = {
            'loss_cls': self.cls_loss(outputs['pred_cls'], label_cls),
            'loss_line': self._masked_line_loss(outputs['pred_lines'], label_lines),
        }
        losses['loss_junction_map'], losses['loss_line_map'] = \
            self.get_map_loss(outputs['pred_junction_map'], outputs['pred_line_map'],
                              label_junction_map, label_line_map)

        for i, aux_outputs in enumerate(outputs.get('aux_outputs', [])):
            losses[f'loss_cls_{i}'] = self.cls_loss(aux_outputs['pred_cls'], label_cls)
            losses[f'loss_line_{i}'] = self._masked_line_loss(aux_outputs['pred_lines'], label_lines)

        if self.phaseII:
            pred_lines_p1 = outputs['pred_lines_p1']
            pred_cls_p2 = outputs['pred_cls_p2']
            pred_lines_p2 = outputs['pred_lines_p2']
            label_scores_p2, label_lines_p2 = self.get_contiguous_labels(pred_lines_p1, label_lines)
            losses['p2_loss_cls'] = (RankLoss.apply(pred_cls_p2, label_scores_p2)
                                     + SortLoss.apply(pred_cls_p2, label_scores_p2))
            losses['p2_loss_line'] = self._masked_line_loss(pred_lines_p2, label_lines_p2)
            # NOTE(review): phase-II auxiliary losses ('aux_outputs_p2') are not computed here.
        return losses


class MLP(nn.Module):
    """Feed-forward network (FFN) of stacked ``nn.Linear`` layers.

    Layer widths go ``input_dim -> hidden_dim -> ... -> hidden_dim -> output_dim``.
    ReLU is applied after every layer except the last, whose output is returned raw.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        inner = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + inner
        out_dims = inner + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(in_dims, out_dims)]
        )

    def forward(self, x):
        *hidden_layers, last_layer = self.layers
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return last_layer(x)


class ConvBnRelu(nn.Module):
    """``nn.Conv2d`` optionally followed by a normalization layer and a ReLU.

    ``norm_layer`` is constructed as ``norm_layer(out_planes, eps=bn_eps)`` when ``has_bn``.
    The convolution has no bias unless ``has_bias`` is set.
    """

    def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
                 groups=1, has_bn=True, norm_layer=nn.SyncBatchNorm, bn_eps=1e-5,
                 has_relu=True, inplace=True, has_bias=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=ksize,
            stride=stride,
            padding=pad,
            dilation=dilation,
            groups=groups,
            bias=has_bias,
        )
        self.has_bn = has_bn
        if has_bn:
            self.bn = norm_layer(out_planes, eps=bn_eps)
        self.has_relu = has_relu
        if has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y) if self.has_bn else y
        return self.relu(y) if self.has_relu else y


def build(args):
    """Construct the HETR model, its loss criterion and its post-processor from ``args``.

    Loss-weight layout of the criterion's ``weight_dict``, in insertion order:
      * ``loss_cls``, ``loss_line``, ``loss_junction_map``, ``loss_line_map`` from the
        matching ``args.*_coef`` values (halved when ``args.phaseII`` is set);
      * with ``args.aux_loss``: ``<key>_<i>`` for ``i in range(args.enc_layers - 1)``,
        scaled by ``args.aux_loss_coef``, excluding the two heat-map keys;
      * with ``args.phaseII``: ``p2_loss_cls`` / ``p2_loss_line`` and, with
        ``args.aux_loss``, their ``_<i>`` copies scaled the same way.

    Returns:
        ``(model, criterion, postprocessors)``. The criterion and postprocessors are
        moved to ``args.device``.
    """
    device = torch.device(args.device)

    backbone = build_backbone(args)
    make_transformer = build_deformable_transformer if args.deformable else build_transformer
    transformer = make_transformer(args)
    pe = build_position_encoding(args)

    model = HETR(
        backbone,
        pe,
        transformer,
        num_queries=args.num_queries,
        aux_loss=args.aux_loss,
        args=args,
    )

    def with_layer_suffix(weights, skip=()):
        # '<key>_<i>' copies scaled by aux_loss_coef, one per index in range(enc_layers - 1).
        return {
            f'{key}_{layer}': weight * args.aux_loss_coef
            for layer in range(args.enc_layers - 1)
            for key, weight in weights.items()
            if key not in skip
        }

    heatmap_keys = ('loss_junction_map', 'loss_line_map')
    base_weights = {
        'loss_cls': args.loss_cls_coef,
        'loss_line': args.loss_line_coef,
        'loss_junction_map': args.loss_junction_map_coef,
        'loss_line_map': args.loss_line_map_coef,
    }
    if args.phaseII:
        # In phase II the phase-I loss weights are halved.
        base_weights = {key: weight * 0.5 for key, weight in base_weights.items()}

    weight_dict = dict(base_weights)
    if args.aux_loss:
        weight_dict.update(with_layer_suffix(base_weights, skip=heatmap_keys))
    if args.phaseII:
        p2_weights = {'p2_loss_cls': args.loss_cls_coef, 'p2_loss_line': args.loss_line_coef}
        weight_dict.update(p2_weights)
        if args.aux_loss:
            weight_dict.update(with_layer_suffix(p2_weights))

    criterion = SetCriterion(
        weight_dict=weight_dict,
        cls_weight=args.cls_weight,
        score_weight=args.score_weight,
        line_weight=args.line_weight,
        endpoint_weight=args.endpoint_weight,
        phaseII=args.phaseII,
    )
    criterion.to(device)
    postprocessors = PostProcess(args)
    postprocessors.to(device)
    return model, criterion, postprocessors