# Copyright 2021 Toyota Research Institute.  All rights reserved.
# Adapted from AdelaiDet:
#   https://github.com/aim-uofa/AdelaiDet
import torch
from mmcv.losses import sigmoid_focal_loss
from torch import nn
from torch.nn import functional as F

from mmcv.layers import batched_nms, get_norm
from mmcv.structures import Instances, Boxes
from torch import distributed as dist
from mmcv.utils import force_fp32
from mmcv.layers import Conv2d, batched_nms, cat, get_norm

from adzoo.bevformer.mmdet3d_plugin.dd3d.layers.iou_loss import IOULoss
from adzoo.bevformer.mmdet3d_plugin.dd3d.layers.normalization import ModuleListDial, Scale
from adzoo.bevformer.mmdet3d_plugin.dd3d.utils.comm import reduce_sum

INF = 100_000_000  # large finite sentinel (10**8)

def get_world_size() -> int:
    """Return the size of the default ``torch.distributed`` process group.

    Returns 1 when this PyTorch build has no distributed support, or when
    no process group has been initialized (e.g. single-process runs).
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

def compute_ctrness_targets(reg_targets):
    """Compute FCOS centerness targets from 2D box regression targets.

    Args:
        reg_targets: 2D tensor with four columns. Columns 0/2 form one axis and
            columns 1/3 the other (presumably left/top/right/bottom distances
            from a location to its box edges -- confirm against the target builder).

    Returns:
        1D tensor with one entry per row:
        ``sqrt((min/max of cols 0,2) * (min/max of cols 1,3))``.
        An empty input yields an empty tensor on the same device/dtype.
    """
    if len(reg_targets) == 0:
        return reg_targets.new_zeros(len(reg_targets))

    horizontal = reg_targets[:, [0, 2]]
    vertical = reg_targets[:, [1, 3]]
    horizontal_ratio = horizontal.min(dim=-1).values / horizontal.max(dim=-1).values
    vertical_ratio = vertical.min(dim=-1).values / vertical.max(dim=-1).values
    return (horizontal_ratio * vertical_ratio).sqrt()

class FCOS2DHead(nn.Module):
    """FCOS-style 2D detection head applied to every FPN level.

    Two conv towers (``cls_tower`` and ``box2d_tower``) feed three 3x3
    predictors: ``cls_logits`` (``num_classes`` channels), ``box2d_reg``
    (4 channels) and ``centerness`` (1 channel). The same towers and
    predictors are applied to every level; batch/group-norm layers are
    instantiated once per level and wrapped in ``ModuleListDial``.

    Args:
        num_classes: Number of classification output channels.
        input_shape: Per-level shape specs; each item must expose ``.stride``
            and ``.channels``, and all levels must share one channel count.
        num_cls_convs: Number of conv blocks in the classification tower.
        num_box_convs: Number of conv blocks in the 2D-box tower.
        norm: Tower normalization. For "v1", only "BN" adds a norm layer;
            "GN", "NaiveGN" and "SyncBN" raise NotImplementedError, and any
            other value adds none. For "v2", the value is passed to ``get_norm``.
        use_deformable: Must be False; True raises ValueError.
        use_scale: If True, each level's box regression goes through a
            ``Scale`` initialized to ``stride * box2d_scale_init_factor``.
        box2d_scale_init_factor: Multiplier applied to each stride for that init.
        version: Tower layout, "v1" or "v2"; anything else raises ValueError.
    """

    def __init__(self,
                 num_classes,
                 input_shape,
                 num_cls_convs=4,
                 num_box_convs=4,
                 norm='BN',
                 use_deformable=False,
                 use_scale=True,
                 box2d_scale_init_factor=1.0,
                 version='v2'):
        super().__init__()

        self.num_classes = num_classes
        self.in_strides = [level.stride for level in input_shape]
        self.num_levels = len(input_shape)
        self.use_scale = use_scale
        self.box2d_scale_init_factor = box2d_scale_init_factor
        self._version = version

        level_channels = [level.channels for level in input_shape]
        assert len(set(level_channels)) == 1, "Each level must have the same channel!"
        channels = level_channels[0]

        if use_deformable:
            raise ValueError("Not supported yet.")

        def build_tower(depth):
            """Return the list of layers for one tower with ``depth`` conv blocks."""
            layers = []
            if self._version == "v1":
                for _ in range(depth):
                    # The conv is created before the norm check, as in the reference layout.
                    layers.append(nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=True))
                    if norm in ("GN", "NaiveGN", "SyncBN"):
                        raise NotImplementedError()
                    if norm == "BN":
                        layers.append(ModuleListDial([nn.BatchNorm2d(channels) for _ in range(self.num_levels)]))
                    layers.append(nn.ReLU())
            elif self._version == "v2":
                # One norm instance per FPN level for these types. NOTE(review): upstream
                # comments say "BN" may be converted to "SyncBN" by the training script.
                per_level_norm = norm in ("BN", "FrozenBN", "SyncBN", "GN")
                for _ in range(depth):
                    if per_level_norm:
                        norm_layer = ModuleListDial([get_norm(norm, channels) for _ in range(self.num_levels)])
                    else:
                        norm_layer = get_norm(norm, channels)
                    layers.append(
                        Conv2d(
                            channels,
                            channels,
                            kernel_size=3,
                            stride=1,
                            padding=1,
                            bias=norm_layer is None,  # conv bias only when there is no norm layer
                            norm=norm_layer,
                            activation=F.relu,
                        )
                    )
            else:
                raise ValueError(f"Invalid FCOS2D version: {self._version}")
            return layers

        # Classification tower first, then box tower (creation order fixes parameter order).
        self.cls_tower = nn.Sequential(*build_tower(num_cls_convs))
        self.box2d_tower = nn.Sequential(*build_tower(num_box_convs))

        self.cls_logits = nn.Conv2d(channels, self.num_classes, kernel_size=3, stride=1, padding=1)
        self.box2d_reg = nn.Conv2d(channels, 4, kernel_size=3, stride=1, padding=1)
        self.centerness = nn.Conv2d(channels, 1, kernel_size=3, stride=1, padding=1)

        if self.use_scale:
            level_scales = nn.ModuleList(
                [Scale(init_value=stride * self.box2d_scale_init_factor) for stride in self.in_strides]
            )
            # The attribute name differs between layouts ("v1" vs. later versions).
            if self._version == "v1":
                self.scales_reg = level_scales
            else:
                self.scales_box2d_reg = level_scales

        self.init_weights()

    def init_weights(self):
        """Initialize conv weights and zero every conv bias.

        Tower convs use Kaiming-normal (``fan_out``, ReLU); the three predictor
        convs use Kaiming-uniform with ``a=1``.
        """
        def zero_bias(conv):
            # v2 tower convs are built without a bias when a norm layer is present.
            if conv.bias is not None:
                torch.nn.init.constant_(conv.bias, 0)

        for tower in (self.cls_tower, self.box2d_tower):
            for conv in (m for m in tower.modules() if isinstance(m, nn.Conv2d)):
                torch.nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
                zero_bias(conv)

        for predictor in (self.cls_logits, self.box2d_reg, self.centerness):
            for conv in (m for m in predictor.modules() if isinstance(m, nn.Conv2d)):
                torch.nn.init.kaiming_uniform_(conv.weight, a=1)
                zero_bias(conv)

    def forward(self, x):
        """Run the head on each feature level in order.

        Args:
            x: Iterable of per-level feature maps (presumably in the same level
                order as ``input_shape``, since per-level scales are indexed by position).

        Returns:
            ``(logits, box2d_reg, centerness, extra_output)``. The first three are
            lists with one tensor per level; box regressions are passed through
            ReLU, so they are non-negative. ``extra_output["cls_tower_out"]`` lists
            the classification-tower features per level.
        """
        logits, box2d_reg, centerness = [], [], []
        cls_features = []

        for level, feature in enumerate(x):
            cls_feat = self.cls_tower(feature)
            box_feat = self.box2d_tower(feature)

            logits.append(self.cls_logits(cls_feat))
            centerness.append(self.centerness(box_feat))

            regression = self.box2d_reg(box_feat)
            if self.use_scale:
                # TODO: applying the scale only on foreground pixels could save runtime.
                level_scales = self.scales_reg if self._version == "v1" else self.scales_box2d_reg
                regression = level_scales[level](regression)
            # ReLU rather than exp, following the improved FCOS.
            box2d_reg.append(F.relu(regression))

            cls_features.append(cls_feat)

        return logits, box2d_reg, centerness, {"cls_tower_out": cls_features}


class FCOS2DLoss(nn.Module):
    """Training losses for the FCOS 2D head.

    Produces a sigmoid focal classification loss, an ``IOULoss`` box regression
    loss weighted by centerness targets, and a centerness BCE loss.

    Args:
        num_classes: Number of classification channels in the logits.
        focal_loss_alpha: ``alpha`` passed to ``sigmoid_focal_loss``.
        focal_loss_gamma: ``gamma`` passed to ``sigmoid_focal_loss``.
        loc_loss_type: Loss type forwarded to ``IOULoss`` (default "giou").
    """

    def __init__(self,
                 num_classes,
                 focal_loss_alpha=0.25,
                 focal_loss_gamma=2.0,
                 loc_loss_type='giou',
                 ):
        super().__init__()
        self.focal_loss_alpha = focal_loss_alpha
        self.focal_loss_gamma = focal_loss_gamma
        self.box2d_reg_loss_fn = IOULoss(loc_loss_type)
        self.num_classes = num_classes

    @force_fp32(apply_to=('logits', 'box2d_reg', 'centerness'))
    def forward(self, logits, box2d_reg, centerness, targets):
        """Compute the FCOS 2D losses.

        Args:
            logits: Per-level classification maps, flattened with
                ``permute(0, 2, 3, 1).reshape(-1, num_classes)``.
            box2d_reg: Per-level box regression maps, flattened to ``(-1, 4)``.
            centerness: Per-level centerness logit maps, flattened to 1D.
            targets: Dict with "labels", "box2d_reg_targets" and "pos_inds"
                (indices of foreground rows in the flattened predictions).

        Returns:
            ``(loss_dict, extra_info)``. ``loss_dict`` holds "loss_cls",
            "loss_box2d_reg" and "loss_centerness". ``extra_info`` holds
            "loss_denom" and "centerness_targets", or is ``{}`` when there are
            no foreground pixels on this process.

        Raises:
            ValueError: If ``labels`` and ``box2d_reg_targets`` differ in length.
        """
        labels = targets['labels']
        box2d_reg_targets = targets['box2d_reg_targets']
        pos_inds = targets["pos_inds"]

        num_labels = len(labels)
        num_reg_targets = box2d_reg_targets.shape[0]
        if num_labels != num_reg_targets:
            raise ValueError(
                f"The size of 'labels' and 'box2d_reg_targets' does not match: a={num_labels}, b={num_reg_targets}"
            )

        def flatten(per_level, *trailing):
            # Move channels last, flatten each level, then concatenate the levels.
            return cat([t.permute(0, 2, 3, 1).reshape(-1, *trailing) for t in per_level])

        flat_logits = flatten(logits, self.num_classes)
        flat_box2d_reg = flatten(box2d_reg, 4)
        flat_centerness = flatten(centerness)

        # ---- Classification loss, normalized by the mean positive count per process ----
        world_size = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
        num_pos_avg = max(total_num_pos / world_size, 1.0)

        # One-hot targets: only (foreground row, its label) entries are 1.
        cls_target = torch.zeros_like(flat_logits)
        cls_target[pos_inds, labels[pos_inds]] = 1

        loss_cls = sigmoid_focal_loss(
            flat_logits,
            cls_target,
            alpha=self.focal_loss_alpha,
            gamma=self.focal_loss_gamma,
            reduction="sum",
        ) / num_pos_avg

        # ---- The remaining losses only use foreground pixels ----
        fg_box2d_pred = flat_box2d_reg[pos_inds]
        fg_box2d_targets = box2d_reg_targets[pos_inds]
        fg_centerness_pred = flat_centerness[pos_inds]
        centerness_targets = compute_ctrness_targets(fg_box2d_targets)

        # Denominator for the box loss. Like the reduce_sum above, it runs before the
        # empty-foreground return (presumably so every process joins the collective).
        loss_denom = max(reduce_sum(centerness_targets.sum()).item() / world_size, 1e-6)

        if pos_inds.numel() == 0:
            # Zero-valued losses that stay connected to the prediction graph.
            zero_losses = {
                "loss_cls": loss_cls,
                "loss_box2d_reg": fg_box2d_pred.sum() * 0.,
                "loss_centerness": fg_centerness_pred.sum() * 0.,
            }
            return zero_losses, {}

        # ---- 2D box regression loss, weighted by centerness targets ----
        loss_box2d_reg = self.box2d_reg_loss_fn(fg_box2d_pred, fg_box2d_targets, centerness_targets) / loss_denom

        # ---- Centerness loss ----
        loss_centerness = F.binary_cross_entropy_with_logits(
            fg_centerness_pred, centerness_targets, reduction="sum"
        ) / num_pos_avg

        loss_dict = {
            "loss_cls": loss_cls,
            "loss_box2d_reg": loss_box2d_reg,
            "loss_centerness": loss_centerness,
        }
        extra_info = {"loss_denom": loss_denom, "centerness_targets": centerness_targets}
        return loss_dict, extra_info


class FCOS2DInference:
    """Decode FCOS 2D head outputs into per-image ``Instances`` and apply NMS/top-k.

    Settings are read from ``cfg.DD3D.FCOS2D.INFERENCE`` (THRESH_WITH_CTR,
    PRE_NMS_THRESH, PRE_NMS_TOPK, POST_NMS_TOPK, NMS_THRESH) and
    ``cfg.DD3D.NUM_CLASSES``.
    """

    def __init__(self, cfg):
        inference_cfg = cfg.DD3D.FCOS2D.INFERENCE
        self.thresh_with_ctr = inference_cfg.THRESH_WITH_CTR
        self.pre_nms_thresh = inference_cfg.PRE_NMS_THRESH
        self.pre_nms_topk = inference_cfg.PRE_NMS_TOPK
        self.post_nms_topk = inference_cfg.POST_NMS_TOPK
        self.nms_thresh = inference_cfg.NMS_THRESH
        self.num_classes = cfg.DD3D.NUM_CLASSES

    def __call__(self, logits, box2d_reg, centerness, locations, image_sizes):
        """Decode every level.

        Returns:
            ``(pred_instances, extra_info)``: ``pred_instances[l][b]`` is the
            ``Instances`` for level ``l`` and image ``b`` (each tagged with an
            ``fpn_levels`` field equal to ``l``); ``extra_info[l]`` is the
            per-level dict from ``forward_for_single_feature_map``.
        """
        pred_instances = []  # List[List[Instances]], shape = (L, B)
        extra_info = []
        per_level = zip(logits, box2d_reg, centerness, locations)
        for lvl, (lvl_logits, lvl_box2d_reg, lvl_centerness, lvl_locations) in enumerate(per_level):
            lvl_instances, lvl_info = self.forward_for_single_feature_map(
                lvl_logits, lvl_box2d_reg, lvl_centerness, lvl_locations, image_sizes
            )
            for image_instances in lvl_instances:
                image_instances.fpn_levels = lvl_locations.new_ones(len(image_instances), dtype=torch.long) * lvl
            pred_instances.append(lvl_instances)
            extra_info.append(lvl_info)
        return pred_instances, extra_info

    def forward_for_single_feature_map(self, logits, box2d_reg, centerness, locations, image_sizes):
        """Decode one feature level into one ``Instances`` per image.

        Args:
            logits: ``(N, C, H, W)`` classification logits.
            box2d_reg: ``(N, 4, H, W)`` box regression maps.
            centerness: ``(N, 1, H, W)`` centerness logits.
            locations: ``(H*W, 2)`` location coordinates for this level
                (presumably (x, y) -- columns 0/1 pair with regression 0,2 / 1,3).
            image_sizes: Per-image sizes passed to ``Instances``.

        Returns:
            ``(results, extra_info)``. Each result has ``pred_boxes``
            (``loc - reg[:2]``, ``loc + reg[2:]``), ``scores`` (square root of
            class probability times centerness), ``pred_classes`` and
            ``locations``. ``extra_info`` holds per-image candidate indices
            before top-k and the top-k indices (``None`` when no top-k was needed).
        """
        num_images, num_channels, _, _ = logits.shape

        # Flatten to (N, H*W, ...) to line up with ``locations``.
        scores = logits.permute(0, 2, 3, 1).reshape(num_images, -1, num_channels).sigmoid()
        box2d_reg = box2d_reg.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
        ctr = centerness.permute(0, 2, 3, 1).reshape(num_images, -1).sigmoid()

        # thresh_with_ctr decides whether centerness is folded in before or after
        # thresholding; either way the final scores include it.
        if self.thresh_with_ctr:
            scores = scores * ctr[:, :, None]
            candidate_mask = scores > self.pre_nms_thresh
        else:
            candidate_mask = scores > self.pre_nms_thresh
            scores = scores * ctr[:, :, None]

        pre_nms_topk = candidate_mask.reshape(num_images, -1).sum(1).clamp(max=self.pre_nms_topk)

        results = []
        all_fg_inds, all_topk_indices, all_class_inds = [], [], []
        for i in range(num_images):
            mask_i = candidate_mask[i]
            scores_i = scores[i][mask_i]

            # Rows of ``nonzero`` are (location index, class index), in mask order.
            candidates = mask_i.nonzero(as_tuple=False)
            fg_inds = candidates[:, 0]
            class_inds = candidates[:, 1]

            # Cache the untruncated candidate indices.
            all_fg_inds.append(fg_inds)
            all_class_inds.append(class_inds)

            reg_i = box2d_reg[i][fg_inds]
            locs_i = locations[fg_inds]

            k = pre_nms_topk[i]
            if mask_i.sum().item() > k.item():
                scores_i, topk_indices = scores_i.topk(k, sorted=False)
                class_inds = class_inds[topk_indices]
                reg_i = reg_i[topk_indices]
                locs_i = locs_i[topk_indices]
            else:
                topk_indices = None
            all_topk_indices.append(topk_indices)

            loc_x = locs_i[:, 0]
            loc_y = locs_i[:, 1]
            detections = torch.stack(
                [loc_x - reg_i[:, 0], loc_y - reg_i[:, 1], loc_x + reg_i[:, 2], loc_y + reg_i[:, 3]],
                dim=1,
            )

            instances = Instances(image_sizes[i])
            instances.pred_boxes = Boxes(detections)
            instances.scores = torch.sqrt(scores_i)
            instances.pred_classes = class_inds
            instances.locations = locs_i
            results.append(instances)

        extra_info = {
            "fg_inds_per_im": all_fg_inds,
            "class_inds_per_im": all_class_inds,
            "topk_indices": all_topk_indices,
        }
        return results, extra_info

    def nms_and_top_k(self, instances_per_im, score_key_for_nms="scores"):
        """Apply multiclass NMS, then cap the number of detections per image.

        NMS (only when ``nms_thresh > 0``) ranks boxes by the field named
        ``score_key_for_nms`` and groups them by ``pred_classes``. The cap
        (only when ``post_nms_topk > 0``) always ranks by the ``scores`` field.

        Returns:
            List with one filtered ``Instances`` per input image.
        """
        kept = []
        for instances in instances_per_im:
            if self.nms_thresh > 0:
                keep = batched_nms(
                    instances.pred_boxes.tensor,
                    instances.get(score_key_for_nms),
                    instances.pred_classes,
                    self.nms_thresh,
                )
                instances = instances[keep]

            num_detections = len(instances)
            limit = self.post_nms_topk
            # Limit to ``limit`` detections over all classes.
            if limit > 0 and num_detections > limit:
                scores = instances.scores
                # Threshold at the limit-th largest score; because of ``>=``, ties at
                # the threshold can keep more than ``limit`` detections.
                image_thresh, _ = torch.kthvalue(scores, num_detections - limit + 1)
                keep = torch.nonzero(scores >= image_thresh.item()).squeeze(1)
                instances = instances[keep]
            kept.append(instances)
        return kept
