#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) BaseDetection, Inc. and its affiliates. All Rights Reserved

import logging
import random
from collections import OrderedDict

import numpy as np

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F

from cvpods.layers import ShapeSpec, generalized_batched_nms
from cvpods.modeling.basenet import basenet
from cvpods.modeling.postprocessing import detector_postprocess
from cvpods.structures import Boxes, ImageList, Instances
from cvpods.utils import comm, log_first_n


@basenet
class YOLOv3(nn.Module):
    """
    YOLOv3 model. Darknet 53 is the default backbone of this model.

    Three detection branches (``out0``, ``out1``, ``out2``) are attached to the
    three backbone feature maps listed in ``cfg.MODEL.YOLO.IN_FEATURES``.
    Each branch output is handed to its own :class:`YOLOHead`, which returns
    a loss dict in training mode and decoded predictions otherwise.
    """
    def __init__(self, cfg):
        """
        Args:
            cfg: config object. This constructor reads ``MODEL.DEVICE``,
                ``MODEL.YOLO.{CLASSES, IN_FEATURES, ANCHORS, CONF_THRESHOLD,
                NMS_THRESHOLD}``, ``MODEL.PIXEL_MEAN``, ``MODEL.PIXEL_STD``,
                ``MODEL.NMS_TYPE`` and ``SOLVER.LR_SCHEDULER.MAX_ITER``, and
                calls ``cfg.build_backbone`` to create the backbone.
        """
        super(YOLOv3, self).__init__()
        self.device = torch.device(cfg.MODEL.DEVICE)

        self.num_classes = cfg.MODEL.YOLO.CLASSES

        self.backbone = cfg.build_backbone(
            cfg, input_shape=ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN)))
        backbone_shape = self.backbone.output_shape
        self.in_features = cfg.MODEL.YOLO.IN_FEATURES

        # Every branch predicts (4 box values + 1 objectness + CLASSES scores)
        # per anchor of its level.
        # out 0
        out_filter_0 = len(
            cfg.MODEL.YOLO.ANCHORS[0]) * (5 + cfg.MODEL.YOLO.CLASSES)
        self.out0 = self._make_embedding(
            [512, 1024], backbone_shape[-1], out_filter_0)

        # out 1: consumes the upsampled branch-0 feature concatenated with
        # the second-to-last backbone feature.
        out_filter_1 = len(
            cfg.MODEL.YOLO.ANCHORS[1]) * (5 + cfg.MODEL.YOLO.CLASSES)
        self.out1_cbl = self._make_cbl(512, 256, 1)
        self.out1_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.out1 = self._make_embedding(
            [256, 512], backbone_shape[-2] + 256, out_filter_1)

        # out 2: same scheme, one level further up the backbone.
        out_filter_2 = len(
            cfg.MODEL.YOLO.ANCHORS[2]) * (5 + cfg.MODEL.YOLO.CLASSES)
        self.out2_cbl = self._make_cbl(256, 128, 1)
        self.out2_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.out2 = self._make_embedding(
            [128, 256], backbone_shape[-3] + 128, out_filter_2)

        # view(-1, 1, 1) keeps the statistics consistent with the channel
        # count given to the backbone (len(PIXEL_MEAN)) instead of assuming 3.
        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(
            -1, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(
            -1, 1, 1)
        self.normalizer = lambda x: (x / 255. - pixel_mean) / pixel_std

        # One head per anchor level; kept as a plain list (not registered as
        # sub-modules), exactly as before.
        self.loss_evaluators = [
            YOLOHead(cfg, anchor, level) for level, anchor in enumerate(cfg.MODEL.YOLO.ANCHORS)]

        self.conf_threshold = cfg.MODEL.YOLO.CONF_THRESHOLD
        self.nms_threshold = cfg.MODEL.YOLO.NMS_THRESHOLD
        self.nms_type = cfg.MODEL.NMS_TYPE

        # Multi-scale training state: `size` is the current square training
        # resolution, re-drawn from `multi_size` every `change_iter` calls.
        self.size = 512
        self.multi_size = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608]
        self.change_iter = 10
        self.iter = 0
        self.max_iter = cfg.SOLVER.LR_SCHEDULER.MAX_ITER

        self.to(self.device)

    def _make_cbl(self, _in, _out, ks):
        ''' cbl = conv + batch_norm + leaky_relu

        Args:
            _in (int): number of input channels.
            _out (int): number of output channels.
            ks (int): convolution kernel size; padding is ``(ks - 1) // 2``
                (0 when ``ks`` is falsy).

        Returns:
            nn.Sequential: the named layers ``conv``, ``bn`` and ``relu``.
        '''
        pad = (ks - 1) // 2 if ks else 0
        return nn.Sequential(OrderedDict([
            ("conv", nn.Conv2d(_in, _out, kernel_size=ks,
                               stride=1, padding=pad, bias=False)),
            ("bn", nn.BatchNorm2d(_out)),
            ("relu", nn.LeakyReLU(0.1)),
        ]))

    def _make_embedding(self, filters_list, in_filters, out_filter):
        """
        Build one detection branch: six alternating 1x1 / 3x3 cbl blocks
        followed by a biased 1x1 convolution named ``conv_out``.

        Args:
            filters_list (list[int]): ``[narrow, wide]`` channel counts used
                by the 1x1 and 3x3 blocks respectively.
            in_filters (int): channels of the branch input.
            out_filter (int): channels of the final prediction map.

        Returns:
            nn.ModuleList: the seven layers in execution order.
        """
        m = nn.ModuleList([
            self._make_cbl(in_filters, filters_list[0], 1),
            self._make_cbl(filters_list[0], filters_list[1], 3),
            self._make_cbl(filters_list[1], filters_list[0], 1),
            self._make_cbl(filters_list[0], filters_list[1], 3),
            self._make_cbl(filters_list[1], filters_list[0], 1),
            self._make_cbl(filters_list[0], filters_list[1], 3)])
        m.add_module("conv_out", nn.Conv2d(filters_list[1], out_filter, kernel_size=1,
                                           stride=1, padding=0, bias=True))
        return m

    def preprocess_image(self, batched_inputs, training):
        """
        Normalize, pad and batch the input images.

        In training mode the batch is additionally resized to the current
        multi-scale resolution and the ground truth is packed into a dense
        label tensor.

        Args:
            batched_inputs (list[dict]): one dict per image with an "image"
                tensor and, for training, an "instances" entry (or the legacy
                "targets" key).
            training (bool): whether to resize the batch and build labels.

        Returns:
            tuple: ``(images, labels)``. ``images`` is the ``ImageList``;
            ``labels`` is ``None`` outside training, otherwise a float tensor
            of shape ``(batch, 100, 5)`` whose rows hold the class id followed
            by the four ``gt_boxes`` coordinates, zero-padded.

        Raises:
            ValueError: in training mode when the inputs carry neither
                "instances" nor "targets".
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        bs = len(images)
        images = [self.normalizer(x) for x in images]

        # NOTE(review): pad_ref_long presumably pads to a square of the longer
        # side -- confirm against ImageList.from_tensors.
        images = ImageList.from_tensors(
            images, size_divisibility=0, pad_ref_long=True)

        # sync image size for all gpus
        comm.synchronize()
        if training and self.iter % self.change_iter == 0:
            if self.iter < self.max_iter - 20000:
                # The main process draws the new size; everyone else receives
                # it through the broadcast below.
                meg = torch.LongTensor(1).to(self.device)
                comm.synchronize()
                if comm.is_main_process():
                    size = np.random.choice(self.multi_size)
                    meg.fill_(size)

                if comm.get_world_size() > 1:
                    comm.synchronize()
                    dist.broadcast(meg, 0)
                self.size = meg.item()

                comm.synchronize()
            else:
                # Last 20000 iterations: train at a fixed resolution.
                self.size = 608

        if training:

            # resize image inputs with a randomly chosen interpolation mode
            modes = ['bilinear', 'nearest', 'bicubic', 'area']
            mode = modes[random.randrange(4)]
            if mode == 'bilinear' or mode == 'bicubic':
                images.tensor = F.interpolate(
                    images.tensor, size=[self.size, self.size], mode=mode, align_corners=False)
            else:
                images.tensor = F.interpolate(images.tensor, size=[self.size, self.size], mode=mode)

            if "instances" in batched_inputs[0]:
                gt_instances = [
                    x["instances"].to(self.device) for x in batched_inputs
                ]
            elif "targets" in batched_inputs[0]:
                log_first_n(
                    logging.WARN,
                    "'targets' in the model inputs is now renamed to 'instances'!",
                    n=10)
                gt_instances = [
                    x["targets"].to(self.device) for x in batched_inputs
                ]
            else:
                # Previously this fell through with gt_instances = None and
                # failed with an opaque TypeError in the loop below.
                raise ValueError(
                    "YOLOv3 training requires ground truth under the "
                    "'instances' (or legacy 'targets') key of the inputs.")

            targets = [
                torch.cat(
                    [instance.gt_classes.float().unsqueeze(-1), instance.gt_boxes.tensor], dim=-1
                )
                for instance in gt_instances
            ]
            max_boxes = 100
            labels = torch.zeros((bs, max_boxes, 5))
            for i, target in enumerate(targets):
                # Keep at most `max_boxes` boxes per image; more than that
                # used to raise a shape-mismatch error.
                num_kept = min(target.shape[0], max_boxes)
                labels[i, :num_kept] = target[:num_kept]
            # NOTE(review): the constant 512 assumes the batch was 512 pixels
            # per side before the resize above -- confirm against the data
            # pipeline.
            labels[:, :, 1:] = labels[:, :, 1:] / 512. * self.size
        else:
            labels = None

        self.iter += 1
        return images, labels

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs: a list, batched outputs of :class:`DatasetMapper` .
                Each item in the list contains the inputs for one image.
                For now, each item in the list is a dict that contains:

                * image: Tensor, image in (C, H, W) format.
                * instances: Instances

                Other information that's included in the original dicts, such as:

                * "height", "width" (int): the output resolution of the model, used in inference.
                    See :meth:`postprocess` for details.
        Returns:
            In training mode, dict[str: Tensor]: mapping from a named loss
            ("loss_x", "loss_y", "loss_w", "loss_h", "loss_conf", "loss_cls")
            to the sum of that loss over the three heads.
            Otherwise, list[dict]: one ``{"instances": Instances}`` per image,
            rescaled by ``detector_postprocess``.
        """

        images, labels = self.preprocess_image(batched_inputs, self.training)

        x = images.tensor
        img_size = x.shape[-2:]

        def _branch(_embedding, _in):
            # Run a branch and also return the activation after its fifth
            # block (index 4), which feeds the next, finer branch.
            for i, e in enumerate(_embedding):
                _in = e(_in)
                if i == 4:
                    out_branch = _in
            return _in, out_branch

        #  backbone: exactly three features are expected, and the last one
        #  feeds branch 0.
        out_features = self.backbone(x)
        features = [out_features[f] for f in self.in_features]
        [x2, x1, x0] = features
        #  yolo branch 0
        out0, out0_branch = _branch(self.out0, x0)
        #  yolo branch 1
        x1_in = self.out1_cbl(out0_branch)
        x1_in = self.out1_upsample(x1_in)
        x1_in = torch.cat([x1_in, x1], 1)
        out1, out1_branch = _branch(self.out1, x1_in)
        #  yolo branch 2
        x2_in = self.out2_cbl(out1_branch)
        x2_in = self.out2_upsample(x2_in)
        x2_in = torch.cat([x2_in, x2], 1)
        out2, out2_branch = _branch(self.out2, x2_in)

        outputs = [out0, out1, out2]

        if self.training:
            losses = [
                loss_evaluator(out, labels, img_size) for out, loss_evaluator in zip(
                    outputs, self.loss_evaluators)
            ]
            keys = ["loss_x", "loss_y", "loss_w",
                    "loss_h", "loss_conf", "loss_cls"]
            losses_dict = {}
            for key in keys:
                losses_dict[key] = sum([loss[key] for loss in losses])
            return losses_dict
        else:
            # labels is None here, so every head returns decoded predictions.
            predictions_list = [loss_evaluator(out, labels, img_size) for
                                out, loss_evaluator in zip(outputs, self.loss_evaluators)]

            predictions = torch.cat(predictions_list, 1)
            detections = postprocess(predictions,
                                     self.num_classes,
                                     self.conf_threshold,
                                     self.nms_threshold,
                                     nms_type=self.nms_type)

            results = []
            for idx, out in enumerate(detections):
                if out is None:
                    # No detection survived for this image.
                    out = x.new_zeros((0, 7))
                image_size = img_size
                result = Instances(image_size)
                result.pred_boxes = Boxes(out[:, :4])
                # Final score = class confidence * objectness.
                result.scores = out[:, 5] * out[:, 4]
                result.pred_classes = out[:, -1]
                results.append(result)

            processed_results = []
            for results_per_image, input_per_image, image_size in zip(
                    results, batched_inputs, images.image_sizes):
                height = input_per_image.get("height", image_size[0])
                width = input_per_image.get("width", image_size[1])
                r = detector_postprocess(results_per_image, height, width)
                processed_results.append({"instances": r})

            return processed_results


class YOLOHead(nn.Module):
    """
    Decoder and loss evaluator for a single YOLOv3 output level.

    Called without targets it decodes the raw branch output into boxes in
    input-image pixels plus objectness and class scores; called with targets
    it builds the training targets for this level and returns the six YOLO
    loss terms.
    """
    def __init__(self, cfg, anchors, level):
        """
        Args:
            cfg: config object; reads ``MODEL.YOLO.ANCHORS``,
                ``MODEL.YOLO.CLASSES`` and ``MODEL.YOLO.IGNORE_THRESHOLD``.
            anchors: the ``(width, height)`` anchor pairs of this level.
            level (int): index of this level inside ``MODEL.YOLO.ANCHORS``.
        """
        super(YOLOHead, self).__init__()
        self.level = level
        self.all_anchors = np.array(cfg.MODEL.YOLO.ANCHORS).reshape([-1, 2])
        self.anchors = anchors
        # Anchors of every level as (0, 0, w, h) boxes, used to pick the best
        # matching anchor for each ground-truth box.
        self.ref_anchors = np.zeros((len(self.all_anchors), 4))
        self.ref_anchors[:, 2:] = self.all_anchors
        self.ref_anchors = torch.from_numpy(self.ref_anchors)

        self.num_anchors = len(anchors)
        self.num_classes = cfg.MODEL.YOLO.CLASSES
        self.bbox_attrs = 5 + self.num_classes

        self.ignore_threshold = cfg.MODEL.YOLO.IGNORE_THRESHOLD
        self.lambda_xy = 1.0
        self.lambda_wh = 1.0
        self.lambda_conf = 1.0
        self.lambda_cls = 1.0

        self.mse_loss = nn.MSELoss(reduction="none")
        self.l1_loss = nn.L1Loss(reduction="none")
        self.bce_loss = nn.BCELoss(reduction="none")

    def forward(self, input, targets=None, image_size=(416, 416)):
        """
        Args:
            input (Tensor): raw branch output of shape
                ``(batch, num_anchors * (5 + num_classes), in_h, in_w)``.
            targets (Tensor or None): dense labels of shape
                ``(batch, max_boxes, 5)`` with rows
                ``(class, x1, y1, x2, y2)``; ``None`` selects inference.
            image_size: ``(height, width)`` of the network input, i.e. the
                ``x.shape[-2:]`` passed by ``YOLOv3.forward``.

        Returns:
            With targets: dict with "loss_x", "loss_y", "loss_w", "loss_h",
            "loss_conf" and "loss_cls", each summed over the level and
            divided by the batch size.
            Without targets: tensor of shape
            ``(batch, num_anchors * in_h * in_w, 5 + num_classes)`` holding
            ``(xc, yc, w, h, conf, class scores...)``.
        """
        bs = input.size(0)
        in_h = input.size(2)
        in_w = input.size(3)
        # image_size is (height, width): index 0 pairs with the grid height.
        # (The axes used to be swapped, which only went unnoticed because the
        # inputs are square.)
        stride_h = image_size[0] / in_h
        stride_w = image_size[1] / in_w

        prediction = input.view(bs, self.num_anchors,
                                self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()

        # Get outputs
        x = torch.sigmoid(prediction[..., 0])          # Center x
        y = torch.sigmoid(prediction[..., 1])          # Center y
        w = prediction[..., 2]                         # Width
        h = prediction[..., 3]                         # Height
        conf = torch.sigmoid(prediction[..., 4])       # Conf
        pred_cls = torch.sigmoid(prediction[..., 5:])  # Cls pred.

        device = prediction.device

        # Cell offsets, broadcast over (batch, anchor, row, col).
        grid_x = torch.arange(in_w, dtype=torch.float32, device=device).view(1, 1, 1, in_w)
        grid_y = torch.arange(in_h, dtype=torch.float32, device=device).view(1, 1, in_h, 1)
        # Anchors are used in input-image pixels (not divided by the stride).
        anchor_sizes = torch.tensor(
            [(float(a_w), float(a_h)) for a_w, a_h in self.anchors],
            dtype=torch.float32, device=device)
        anchor_w = anchor_sizes[:, 0].view(1, self.num_anchors, 1, 1)
        anchor_h = anchor_sizes[:, 1].view(1, self.num_anchors, 1, 1)

        # Decode boxes to (xc, yc, w, h) in pixels; no gradient flows through
        # the decoded boxes, they only drive target assignment / inference.
        with torch.no_grad():
            pred_boxes = torch.stack((
                (x + grid_x) * stride_w,
                (y + grid_y) * stride_h,
                torch.exp(w) * anchor_w,
                torch.exp(h) * anchor_h,
            ), dim=-1).to(prediction.dtype)

        if targets is not None:
            #  build target
            mask, obj_mask, \
                tx, ty, tw, th, \
                tgt_scale, tcls = self.get_target(targets, pred_boxes, image_size,
                                                  in_w, in_h,
                                                  stride_w, stride_h,
                                                  self.ignore_threshold)

            # Follow the device of the network output instead of forcing CUDA.
            mask, obj_mask = mask.to(device), obj_mask.to(device)
            tx, ty, tw, th = tx.to(device), ty.to(device), tw.to(device), th.to(device)
            tgt_scale, tcls = tgt_scale.to(device), tcls.to(device)

            loss_x = (mask * tgt_scale * self.bce_loss(x * mask, tx * mask)).sum() / bs
            loss_y = (mask * tgt_scale * self.bce_loss(y * mask, ty * mask)).sum() / bs
            loss_w = (mask * tgt_scale * self.l1_loss(w * mask, tw * mask)).sum() / bs
            loss_h = (mask * tgt_scale * self.l1_loss(h * mask, th * mask)).sum() / bs

            # obj_mask zeroes the cells whose prediction already overlaps a
            # ground-truth box above the ignore threshold.
            loss_conf = (obj_mask * self.bce_loss(conf, mask)).sum() / bs

            loss_cls = self.bce_loss(pred_cls[mask == 1], tcls[mask == 1]).sum() / bs

            #  total loss = losses * weight
            loss = {
                "loss_x": loss_x * self.lambda_xy,
                "loss_y": loss_y * self.lambda_xy,
                "loss_w": loss_w * self.lambda_wh,
                "loss_h": loss_h * self.lambda_wh,
                "loss_conf": loss_conf * self.lambda_conf,
                "loss_cls": loss_cls * self.lambda_cls,
            }
            return loss
        else:
            # Results
            output = torch.cat((pred_boxes.view(bs, -1, 4),
                                conf.view(bs, -1, 1), pred_cls.view(bs, -1, self.num_classes)), -1)
            return output.data

    def get_target(self, target, pred_boxes, img_size,
                   in_w, in_h, stride_w, stride_h, ignore_threshold):
        """
        Build the dense training targets of this level.

        Args:
            target (Tensor): ``(batch, max_boxes, 5)`` labels with rows
                ``(class, x1, y1, x2, y2)``; rows whose sum is not positive
                are treated as padding and valid rows are expected first.
            pred_boxes (Tensor): decoded ``(batch, anchors, in_h, in_w, 4)``
                boxes as ``(xc, yc, w, h)``.
            img_size: two-element size of the network input (only the
                product of both entries is used).
            in_w, in_h (int): grid width and height.
            stride_w, stride_h (float): pixels per grid cell.
            ignore_threshold (float): predictions whose best IoU with a
                ground-truth box exceeds this are excluded from the
                confidence loss.

        Returns:
            tuple: ``(mask, obj_mask, tx, ty, tw, th, tgt_scale, tcls)``; all
            are ``(batch, anchors, in_h, in_w)`` tensors except ``tcls``,
            which has a trailing ``num_classes`` axis. They are allocated on
            the default device.
        """
        bs = target.size(0)
        grid_shape = (bs, self.num_anchors, in_h, in_w)

        mask = torch.zeros(grid_shape)
        obj_mask = torch.ones(grid_shape)
        tx = torch.zeros(grid_shape)
        ty = torch.zeros(grid_shape)
        tw = torch.zeros(grid_shape)
        th = torch.zeros(grid_shape)
        tgt_scale = torch.zeros(grid_shape)
        tcls = torch.zeros(grid_shape + (self.num_classes,))

        nlabel = (target.sum(dim=2) > 0).sum(dim=1)
        gx_all = (target[:, :, 1] + target[:, :, 3]) / 2.0   # center x
        gy_all = (target[:, :, 2] + target[:, :, 4]) / 2.0   # center y
        gw_all = (target[:, :, 3] - target[:, :, 1])         # width
        gh_all = (target[:, :, 4] - target[:, :, 2])         # height
        gi_all = (gx_all / stride_w).to(torch.int16)         # grid column
        gj_all = (gy_all / stride_h).to(torch.int16)         # grid row

        for b in range(bs):
            n = int(nlabel[b])
            if n == 0:
                continue

            truth_box = torch.zeros((n, 4), dtype=torch.float32, device=pred_boxes.device)
            truth_box[:, 2] = gw_all[b, :n]
            truth_box[:, 3] = gh_all[b, :n]
            truth_i = gi_all[b, :n]
            truth_j = gj_all[b, :n]

            # Shape-only IoU (both centred at the origin) against the anchors
            # of every level picks the best anchor per ground-truth box.
            truth_wh = truth_box.cpu()
            anchor_ious_all = bboxes_iou(truth_wh,
                                         self.ref_anchors.type_as(truth_wh), xyxy=False)
            best_n_all = anchor_ious_all.argmax(dim=1).tolist()
            # Global anchor index -> (level, anchor within level); assumes
            # every level owns `num_anchors` anchors (3 in the stock config).
            matched = [
                (t, best % self.num_anchors)
                for t, best in enumerate(best_n_all)
                if best // self.num_anchors == self.level
            ]

            truth_box[:n, 0] = gx_all[b, :n]
            truth_box[:n, 1] = gy_all[b, :n]
            pred_box = pred_boxes[b]

            pred_ious = bboxes_iou(pred_box.reshape(-1, 4),
                                   truth_box, xyxy=False)

            pred_best_iou, _ = pred_ious.max(dim=1)
            pred_best_iou = (pred_best_iou > ignore_threshold)
            pred_best_iou = pred_best_iou.view(pred_box.shape[:3])
            obj_mask[b] = ~pred_best_iou

            for t, a in matched:
                gi, gj = int(truth_i[t]), int(truth_j[t])
                gx, gy = gx_all[b, t], gy_all[b, t]
                gw, gh = gw_all[b, t], gh_all[b, t]

                # Masks
                mask[b, a, gj, gi] = 1
                obj_mask[b, a, gj, gi] = 1

                # Coordinates: offset of the centre inside its cell
                tx[b, a, gj, gi] = gx / stride_w - gi
                ty[b, a, gj, gi] = gy / stride_h - gj
                # Width and height: log-ratio to the matched anchor
                tw[b, a, gj, gi] = torch.log(gw / self.anchors[a][0] + 1e-16)
                th[b, a, gj, gi] = torch.log(gh / self.anchors[a][1] + 1e-16)

                # Smaller boxes get a larger box-loss weight.
                tgt_scale[b, a, gj, gi] = 2.0 - gw * gh / (img_size[0] * img_size[1])
                # One-hot encoding of label
                tcls[b, a, gj, gi, int(target[b, t, 0])] = 1

        return mask, obj_mask, tx, ty, tw, th, tgt_scale, tcls


def bboxes_iou(bboxes_a, bboxes_b, xyxy=True):
    """Compute the pairwise Intersection over Union (IoU) of two box sets.

    Args:
        bboxes_a (Tensor): boxes of shape :math:`(N, 4)`.
        bboxes_b (Tensor): boxes of shape :math:`(K, 4)`.
        xyxy (bool): when True the rows are corner coordinates
            ``(x1, y1, x2, y2)``; when False they are centre/size
            ``(xc, yc, w, h)``.

    Returns:
        Tensor: shape :math:`(N, K)`; entry :math:`(n, k)` is the IoU of the
        :math:`n` th box of ``bboxes_a`` with the :math:`k` th box of
        ``bboxes_b``. Pairs that do not overlap get an intersection of 0.

    Raises:
        IndexError: if either input does not have 4 columns.

    from: https://github.com/chainer/chainercv
    """
    if not (bboxes_a.shape[1] == 4 and bboxes_b.shape[1] == 4):
        raise IndexError

    # Bring both sets to (min corner, max corner) form and record the areas.
    if xyxy:
        a_min, a_max = bboxes_a[:, :2], bboxes_a[:, 2:]
        b_min, b_max = bboxes_b[:, :2], bboxes_b[:, 2:]
        area_a = torch.prod(a_max - a_min, 1)
        area_b = torch.prod(b_max - b_min, 1)
    else:
        a_half = bboxes_a[:, 2:] / 2
        b_half = bboxes_b[:, 2:] / 2
        a_min, a_max = bboxes_a[:, :2] - a_half, bboxes_a[:, :2] + a_half
        b_min, b_max = bboxes_b[:, :2] - b_half, bboxes_b[:, :2] + b_half
        area_a = torch.prod(bboxes_a[:, 2:], 1)
        area_b = torch.prod(bboxes_b[:, 2:], 1)

    # Broadcast (N, 1, 2) against (K, 2) to get every pair's overlap corners.
    top_left = torch.max(a_min[:, None, :], b_min)
    bottom_right = torch.min(a_max[:, None, :], b_max)

    # 1 where the overlap has positive extent on both axes, 0 otherwise.
    overlaps = (top_left < bottom_right).type(top_left.type()).prod(dim=2)
    intersection = torch.prod(bottom_right - top_left, 2) * overlaps
    union = area_a[:, None] + area_b - intersection

    return intersection / union


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45, nms_type='normal'):
    """
    Postprocess for the output of YOLO model
    perform box transformation, specify the class for each detection,
    and perform class-wise non-maximum suppression.

    The input tensor is not modified.

    Args:
        prediction (torch tensor): The shape is
            :math:`(N, B, 5 + num\\_classes)`. :math:`N` is the number of
            images and :math:`B` the number of candidate boxes per image.
            The last axis consists of :math:`xc, yc, w, h`, the objectness
            score and the per-class scores, where `xc` and `yc` represent a
            center of a bounding box.
        num_classes (int):
            number of dataset classes.
        conf_thre (float):
            confidence threshold ranging from 0 to 1,
            which is defined in the config file. A box is kept when
            objectness * best class score reaches it.
        nms_thre (float):
            IoU threshold of non-max suppression ranging from 0 to 1.
        nms_type (str):
            forwarded unchanged to ``generalized_batched_nms``.

    Returns:
        output (list of torch tensor):
            one entry per image; ``None`` when no box survives, otherwise a
            tensor whose rows are
            ``(x1, y1, x2, y2, obj_conf, class_conf, class_pred)``, grouped
            by class in the order returned by ``Tensor.unique()``.
    """
    # Convert (xc, yc, w, h) to corners in a new tensor so that the caller's
    # `prediction` is left untouched.
    centers = prediction[:, :, :2]
    half_sizes = prediction[:, :, 2:4] / 2
    prediction = torch.cat(
        (centers - half_sizes, centers + half_sizes, prediction[:, :, 4:]), dim=2)

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If none are remaining => process next image
        if not image_pred.size(0):
            continue
        # Get score and class with highest confidence
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # squeeze(1) keeps the mask one-dimensional even for a single box;
        # a bare squeeze() collapsed it to 0-dim and broke the indexing below.
        conf_mask = image_pred[:, 4] * class_conf.squeeze(1) >= conf_thre
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.size(0):
            continue

        confidence = detections[:, 4] * detections[:, 5]
        nms_out_index = generalized_batched_nms(detections[:, :4], confidence,
                                                detections[:, -1], nms_thre,
                                                nms_type=nms_type)
        # Rewrites obj_conf as confidence / class_conf (the objectness again);
        # kept so the numerical output matches the previous implementation.
        detections[:, 4] = confidence / detections[:, 5]

        detections = detections[nms_out_index]

        # Regroup the surviving detections class by class.
        unique_labels = detections[:, -1].unique()
        per_class = [detections[detections[:, -1] == c] for c in unique_labels]
        if per_class:
            output[i] = torch.cat(per_class)

    return output
