""" Loss Function for Self-Ensembling Semi-Supervised 3D Object Detection
Author: Zhao Na, 2019
"""

import torch
import torch.nn.functional as F

from pcdet.utils.nn_distance import nn_distance, huber_loss
from pcdet.ops.iou3d_nms import iou3d_nms_utils

from pcdet.models.dense_heads.anchor_head_template import AnchorHeadTemplate


def compute_center_idx(pred_dict, aug_gt_pred_dict, gt_boxes):
    """Look up nearest-neighbour indices between box centers and GT centers.

    Every box tensor gets a leading axis of size 1, and its first three
    columns are used as the center coordinates.

    Args:
        pred_dict: mapping holding a ``"pred_boxes"`` tensor.
        aug_gt_pred_dict: mapping holding a ``"pred_boxes"`` tensor for the
            second (augmented) branch.
        gt_boxes: ground-truth box tensor; only its first three columns
            are read.

    Returns:
        Tuple ``(pred_idx, aug_gt_pred_idx)``: the fourth output of
        ``nn_distance(centers, gt_centers)`` for each branch.
        NOTE(review): presumably this is, for every ground-truth center,
        the index of the closest predicted center -- confirm against
        ``pcdet.utils.nn_distance``.
    """

    def _batched_centers(box_tensor):
        # Prepend a batch axis of size 1, then keep the first three columns.
        return box_tensor.unsqueeze(0)[:, :, :3]

    pred_center = _batched_centers(pred_dict["pred_boxes"])
    aug_gt_center = _batched_centers(aug_gt_pred_dict["pred_boxes"])
    gt_center = _batched_centers(gt_boxes)

    # Only the last of the four nn_distance outputs is needed here.
    _d1, _i1, _d2, pred_idx = nn_distance(pred_center, gt_center)
    _d1, _i1, _d2, aug_gt_pred_idx = nn_distance(aug_gt_center, gt_center)

    return pred_idx, aug_gt_pred_idx


def compute_center_consistency_loss(
    pred_dict, aug_gt_pred_dict, pred_idx, aug_gt_pred_idx
):
    """Size-normalised nearest-neighbour distance between matched centers.

    The boxes of each branch are gathered with the given indices. Columns
    0:3 are used as centers and columns 3:6 as box dimensions. Each box
    gets a "radius" equal to half the Euclidean norm of its dimensions.
    The third output of ``nn_distance(center, aug_gt_center,
    l1smooth=True)`` is divided by the sum of both radii and averaged.

    Args:
        pred_dict: mapping holding a ``"pred_boxes"`` tensor.
        aug_gt_pred_dict: mapping holding a ``"pred_boxes"`` tensor for the
            second (augmented) branch.
        pred_idx: row indices into ``pred_dict["pred_boxes"]``, with a
            leading axis of size 1.
        aug_gt_pred_idx: row indices into
            ``aug_gt_pred_dict["pred_boxes"]``, with a leading axis of
            size 1.

    Returns:
        Scalar tensor: mean of the normalised distances.
    """

    def _gather(box_tensor, index):
        # Prepend a batch axis of size 1 and pick the indexed rows per batch entry.
        batched = box_tensor.unsqueeze(0)
        return torch.stack(
            [rows.index_select(0, picks) for rows, picks in zip(batched, index)]
        )

    def _center_and_radius(selected):
        # Radius = half the diagonal spanned by the three box dimensions.
        half_diagonal = selected[:, :, 3:6].pow(2).sum(dim=-1).sqrt() / 2
        return selected[:, :, :3], half_diagonal

    center, radius = _center_and_radius(_gather(pred_dict["pred_boxes"], pred_idx))
    aug_gt_center, aug_gt_radius = _center_and_radius(
        _gather(aug_gt_pred_dict["pred_boxes"], aug_gt_pred_idx)
    )

    # Only the third output (distances for the second point set) is used.
    _, _, dist2, _ = nn_distance(center, aug_gt_center, l1smooth=True)

    return torch.mean(dist2 / (radius + aug_gt_radius))


def compute_class_consistency_loss(
    pred_dict, aug_gt_pred_dict, pred_idx, aug_gt_pred_idx
):
    """Mean-squared error between the matched scores of the two branches.

    Args:
        pred_dict: mapping holding a ``"pred_scores"`` tensor.
        aug_gt_pred_dict: mapping holding a ``"pred_scores"`` tensor for
            the second (augmented) branch.
        pred_idx: row indices into ``pred_dict["pred_scores"]``, with a
            leading axis of size 1.
        aug_gt_pred_idx: row indices into
            ``aug_gt_pred_dict["pred_scores"]``, with a leading axis of
            size 1.

    Returns:
        Scalar tensor: ``F.mse_loss`` of the two gathered score tensors.
    """

    def _select_scores(score_tensor, index):
        # Prepend a batch axis of size 1 and pick the indexed rows per batch entry.
        batched = score_tensor.unsqueeze(0)
        return torch.stack(
            [rows.index_select(0, picks) for rows, picks in zip(batched, index)]
        )

    matched_scores = _select_scores(pred_dict["pred_scores"], pred_idx)
    aug_gt_matched_scores = _select_scores(
        aug_gt_pred_dict["pred_scores"], aug_gt_pred_idx
    )

    return F.mse_loss(matched_scores, aug_gt_matched_scores)


def compute_size_consistency_loss(
    pred_dict, aug_gt_pred_dict, pred_idx, aug_gt_pred_idx, noise_rotation, noise_scale
):
    """Smooth-L1 loss on the size/heading columns of the matched boxes.

    The boxes of each branch are gathered with the given indices. The
    augmentation noise is then applied to the first branch only:
    ``noise_rotation`` is added to column 6 and columns 3:6 are multiplied
    by ``noise_scale``. Both tensors go through
    ``AnchorHeadTemplate.add_sin_difference`` and columns 3:7 of the
    results are compared.

    Args:
        pred_dict: mapping holding a ``"pred_boxes"`` tensor.
        aug_gt_pred_dict: mapping holding a ``"pred_boxes"`` tensor for the
            second (augmented) branch.
        pred_idx: row indices into ``pred_dict["pred_boxes"]``, with a
            leading axis of size 1.
        aug_gt_pred_idx: row indices into
            ``aug_gt_pred_dict["pred_boxes"]``, with a leading axis of
            size 1.
        noise_rotation: value added to column 6 of the gathered boxes.
            NOTE(review): presumably the heading angle -- confirm.
        noise_scale: factor applied to columns 3:6 of the gathered boxes.

    Returns:
        Scalar tensor: ``F.smooth_l1_loss`` over columns 3:7.
    """

    def _gather(box_tensor, index):
        # Prepend a batch axis of size 1 and pick the indexed rows per batch entry.
        batched = box_tensor.unsqueeze(0)
        return torch.stack(
            [rows.index_select(0, picks) for rows, picks in zip(batched, index)]
        )

    matched = _gather(pred_dict["pred_boxes"], pred_idx)
    # The gathered tensor is a fresh copy, so these in-place edits leave
    # pred_dict["pred_boxes"] untouched.
    matched[:, :, 6] += noise_rotation
    matched[:, :, 3:6] *= noise_scale

    aug_gt_matched = _gather(aug_gt_pred_dict["pred_boxes"], aug_gt_pred_idx)

    # NOTE(review): add_sin_difference presumably replaces the angle column
    # with a sin/cos encoding of the angle difference -- confirm in
    # AnchorHeadTemplate.
    matched, aug_gt_matched = AnchorHeadTemplate.add_sin_difference(
        matched, aug_gt_matched
    )

    return F.smooth_l1_loss(matched[:, :, 3:7], aug_gt_matched[:, :, 3:7])


def get_consistency_loss(
    pred_dicts, aug_gt_pred_dicts, gt_boxes, noise_rotations, noise_scales
):
    """Average the center, class and size consistency losses over all samples.

    For every sample ``i`` the predictions of both branches are matched
    through ``compute_center_idx`` and the three per-sample consistency
    terms are computed. Each term is summed over the samples and divided
    by ``len(pred_dicts)``.

    Args:
        pred_dicts: non-empty sequence of per-sample prediction mappings.
            The device of the result is read from
            ``pred_dicts[0]["pred_boxes"]``.
        aug_gt_pred_dicts: per-sample prediction mappings of the second
            (augmented) branch, indexed like ``pred_dicts``.
        gt_boxes: per-sample ground-truth boxes, indexed like
            ``pred_dicts``.
        noise_rotations: per-sample rotation noise, indexed like
            ``pred_dicts``.
        noise_scales: per-sample scale noise, indexed like ``pred_dicts``.

    Returns:
        Dict with the keys ``"center_con_loss"``, ``"class_con_loss"``,
        ``"size_con_loss"`` (per-term means) and ``"con_loss"`` (their sum).

    Raises:
        IndexError: if ``pred_dicts`` is empty, or if one of the other
            sequences is shorter than ``pred_dicts``.
    """
    num_samples = len(pred_dicts)
    device = pred_dicts[0]["pred_boxes"].device

    # Running sums, one per loss term.
    center_consistency_loss = torch.tensor(0.0, dtype=torch.float32, device=device)
    class_consistency_loss = torch.tensor(0.0, dtype=torch.float32, device=device)
    size_consistency_loss = torch.tensor(0.0, dtype=torch.float32, device=device)

    for i, pred_dict in enumerate(pred_dicts):
        aug_gt_pred_dict = aug_gt_pred_dicts[i]

        pred_idx, aug_gt_pred_idx = compute_center_idx(
            pred_dict, aug_gt_pred_dict, gt_boxes[i]
        )

        noise_rotation = noise_rotations[i]
        noise_scale = noise_scales[i]

        # Per-sample terms live under their own names so the running sums
        # above are never overwritten by a single sample's value.
        sample_center_loss = compute_center_consistency_loss(
            pred_dict, aug_gt_pred_dict, pred_idx, aug_gt_pred_idx
        )
        sample_class_loss = compute_class_consistency_loss(
            pred_dict, aug_gt_pred_dict, pred_idx, aug_gt_pred_idx
        )
        sample_size_loss = compute_size_consistency_loss(
            pred_dict,
            aug_gt_pred_dict,
            pred_idx,
            aug_gt_pred_idx,
            noise_rotation,
            noise_scale,
        )

        # Out-of-place adds keep the autograd graph of every sample intact.
        center_consistency_loss = center_consistency_loss + sample_center_loss
        class_consistency_loss = class_consistency_loss + sample_class_loss
        size_consistency_loss = size_consistency_loss + sample_size_loss

    center_consistency_loss = center_consistency_loss / num_samples
    class_consistency_loss = class_consistency_loss / num_samples
    size_consistency_loss = size_consistency_loss / num_samples

    consistency_loss = (
        center_consistency_loss + size_consistency_loss + class_consistency_loss
    )

    return {
        "center_con_loss": center_consistency_loss,
        "class_con_loss": class_consistency_loss,
        "size_con_loss": size_consistency_loss,
        "con_loss": consistency_loss,
    }
