""" Loss Function for Self-Ensembling Semi-Supervised 3D Object Detection
Author: Zhao Na, 2019
"""

import torch
import torch.nn.functional as F

from pcdet.utils.nn_distance import nn_distance, huber_loss
from pcdet.ops.iou3d_nms import iou3d_nms_utils

from pcdet.models.dense_heads.anchor_head_template import AnchorHeadTemplate


def compute_center_consistency_loss(pred_dict, ema_pred_dict):
    """Radius-normalised center consistency between student and EMA boxes.

    Both dicts are read for ``"pred_boxes"``. A leading batch axis of size 1
    is added, columns ``0:3`` are used as box centers and columns ``3:6`` as
    box sizes.

    Centers are matched with ``nn_distance(..., l1smooth=True)``. Only its
    third and fourth outputs are used. NOTE(review): presumably ``dist2[b, j]``
    is the distance from EMA center ``j`` to its nearest student center, and
    ``ind2[b, j]`` is that student box's index. Confirm against
    ``pcdet.utils.nn_distance``.

    Each matched distance is divided by the sum of the "radius" (half the
    Euclidean norm of columns 3:6) of the EMA box and of its matched student
    box.

    Returns:
        tuple: (scalar tensor, the mean normalised distance;
                ``ind2``, reused by callers as ``map_ind`` to align the other
                consistency terms).
    """

    def half_diagonal(boxes):
        # Half the Euclidean norm of the size columns 3:6 -> (B, num_boxes).
        return torch.sqrt(torch.sum(torch.pow(boxes[:, :, 3:6], 2), dim=-1)) / 2

    student_boxes = pred_dict["pred_boxes"].unsqueeze(0)
    teacher_boxes = ema_pred_dict["pred_boxes"].unsqueeze(0)

    student_radius = half_diagonal(student_boxes)
    teacher_radius = half_diagonal(teacher_boxes)

    _, _, teacher_to_student_dist, teacher_to_student_idx = nn_distance(
        student_boxes[:, :, :3], teacher_boxes[:, :, :3], l1smooth=True
    )

    # Pick, per batch row, the radius of the student box matched to each EMA box.
    matched_student_radius = torch.stack(
        [
            row.index_select(0, idx)
            for row, idx in zip(student_radius, teacher_to_student_idx)
        ]
    )

    normaliser = matched_student_radius + teacher_radius
    normalised_dist = teacher_to_student_dist / normaliser

    return torch.mean(normalised_dist), teacher_to_student_idx


def compute_class_consistency_loss(pred_dict, ema_pred_dict, map_ind):
    # cls_scores = pred_dict["pred_raw_scores"].unsqueeze(0) # (B, num_proposal, num_class)
    # ema_cls_scores = ema_pred_dict["pred_raw_scores"].unsqueeze(0) # (B, num_proposal, num_class)

    # cls_log_prob = F.log_softmax(cls_scores, dim=2)  # (B, num_proposal, num_class)
    # ema_cls_prob = F.softmax(ema_cls_scores, dim=2)  # (B, num_proposal, num_class)

    cls_scores = pred_dict["pred_scores"].unsqueeze(0)  # (B, num_proposal, num_class)
    ema_cls_scores = ema_pred_dict["pred_scores"].unsqueeze(
        0
    )  # (B, num_proposal, num_class)

    cls_scores_aligned = torch.cat(
        [torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(cls_scores, map_ind)]
    )

    # class_consistency_loss = F.kl_div(cls_log_prob_aligned, ema_cls_prob)
    # class_consistency_loss = F.smooth_l1_loss(cls_scores_aligned, ema_cls_scores)
    class_consistency_loss = F.mse_loss(cls_scores_aligned, ema_cls_scores)
    return class_consistency_loss


def compute_size_consistency_loss(pred_dict, ema_pred_dict, map_ind):
    """Smooth-L1 consistency on box columns 3:7 of matched student/EMA boxes.

    Args:
        pred_dict: student predictions; ``"pred_boxes"`` is read and given a
            leading batch axis of size 1.
        ema_pred_dict: EMA (teacher) predictions; ``"pred_boxes"`` is read the
            same way.
        map_ind: index tensor with a leading batch axis. Each row selects the
            student box compared with each EMA box (callers pass the index
            returned by ``compute_center_consistency_loss``).

    Both aligned box sets go through ``AnchorHeadTemplate.add_sin_difference``
    before columns 3..6 (four values) are compared. NOTE(review): presumably
    that helper sin/cos-encodes the heading column (index 6) so that the
    heading is part of this "size" term. Confirm against pcdet.

    Returns:
        Scalar tensor, ``F.smooth_l1_loss`` with its default mean reduction.
    """
    student_boxes = pred_dict["pred_boxes"].unsqueeze(0)
    teacher_boxes = ema_pred_dict["pred_boxes"].unsqueeze(0)

    # Reorder student boxes so that box j lines up with EMA box j.
    matched_student_boxes = torch.stack(
        [boxes.index_select(0, idx) for boxes, idx in zip(student_boxes, map_ind)]
    )

    matched_student_boxes, teacher_boxes = AnchorHeadTemplate.add_sin_difference(
        matched_student_boxes, teacher_boxes
    )

    # Columns 3, 4, 5 and 6 of each box.
    student_part = matched_student_boxes[:, :, 3:7]
    teacher_part = teacher_boxes[:, :, 3:7]

    return F.smooth_l1_loss(student_part, teacher_part)


def get_consistency_loss(pred_dicts, ema_pred_dicts):
    """Batch-averaged center/class/size consistency losses (student vs. EMA).

    Args:
        pred_dicts: sequence of per-sample student prediction dicts. Every
            entry needs a ``"pred_boxes"`` tensor whose first dimension is the
            number of predicted boxes. The per-term helpers also read
            ``"pred_scores"``.
        ema_pred_dicts: sequence of per-sample EMA (teacher) prediction dicts,
            index-aligned with ``pred_dicts`` and holding the same keys.

    Returns:
        dict with
            "center_con_loss", "class_con_loss", "size_con_loss": each term
                summed over the samples where both the student and the EMA
                model predicted at least one box, then divided by
                ``len(pred_dicts)``. Skipped samples still count in the
                denominator.
            "con_loss": the sum of the three terms.
        A skipped sample contributes nothing and prints "ema_empty" and/or
        "student_empty".
    """
    device = pred_dicts[0]["pred_boxes"].device

    def _zero():
        # Fresh scalar accumulator on the same device as the predictions.
        return torch.tensor(
            0.0, requires_grad=False, dtype=torch.float32, device=device
        )

    center_consistency_loss = _zero()
    class_consistency_loss = _zero()
    size_consistency_loss = _zero()

    for i, pred_dict in enumerate(pred_dicts):
        ema_pred_dict = ema_pred_dicts[i]
        ema_empty = ema_pred_dict["pred_boxes"].shape[0] == 0
        student_empty = pred_dict["pred_boxes"].shape[0] == 0
        if ema_empty or student_empty:
            if ema_empty:
                print("ema_empty")
            if student_empty:
                print("student_empty")
            continue

        # Per-sample terms get their own names. Assigning them to the
        # accumulator names (as before) overwrote the running totals, so only
        # the last sample survived, and it was doubled by `x += x`.
        center_loss, map_ind = compute_center_consistency_loss(
            pred_dict, ema_pred_dict
        )
        class_loss = compute_class_consistency_loss(pred_dict, ema_pred_dict, map_ind)
        size_loss = compute_size_consistency_loss(pred_dict, ema_pred_dict, map_ind)

        # Rebind (out-of-place add) instead of mutating the accumulators.
        center_consistency_loss = center_consistency_loss + center_loss
        class_consistency_loss = class_consistency_loss + class_loss
        size_consistency_loss = size_consistency_loss + size_loss

    num_samples = len(pred_dicts)
    center_consistency_loss = center_consistency_loss / num_samples
    class_consistency_loss = class_consistency_loss / num_samples
    size_consistency_loss = size_consistency_loss / num_samples

    consistency_loss = (
        center_consistency_loss + size_consistency_loss + class_consistency_loss
    )

    return {
        "center_con_loss": center_consistency_loss,
        "class_con_loss": class_consistency_loss,
        "size_con_loss": size_consistency_loss,
        "con_loss": consistency_loss,
    }
