from typing import Dict

import torch
import torch.nn.functional as F
from navsim.agents.gaussianfusion.transfuser_config import TransfuserConfig
from navsim.agents.gaussianfusion.transfuser_features import BoundingBox2DIndex
from navsim.agents.gaussianfusion.modules.lovasz_softmax import lovasz_softmax


def transfuser_loss(
    targets: Dict[str, torch.Tensor],
    predictions: Dict[str, torch.Tensor],
    config: TransfuserConfig,
) -> Dict[str, torch.Tensor]:
    """
    Helper function calculating the complete loss of Transfuser.

    :param targets: dictionary of name tensor pairings; only ``"trajectory"`` is
        read, and only when ``predictions`` carries no precomputed
        ``"trajectory_loss"``
    :param predictions: dictionary of name tensor pairings; must contain
        ``"pred_bev_occ"`` and ``"gt_bev_occ"``. It may also contain
        ``"trajectory_loss"`` (used instead of the L1 fallback, which then
        needs ``"trajectory"``) and ``"trajectory_loss_dict"`` (merged into
        the result)
    :param config: global Transfuser config; ``trajectory_weight`` and
        ``bev_semantic_weight`` scale the two loss terms
    :return: dictionary holding the weighted total under ``"loss"``, the
        weighted components under ``"trajectory_loss"`` and
        ``"bev_semantic_loss"``, plus every entry of
        ``predictions["trajectory_loss_dict"]`` when present
    """
    # Prefer a loss already computed by the trajectory head; otherwise fall
    # back to a plain L1 regression against the target trajectory.
    if "trajectory_loss" in predictions:
        trajectory_loss = predictions["trajectory_loss"]
    else:
        trajectory_loss = F.l1_loss(predictions["trajectory"], targets["trajectory"])

    # Both the prediction and its ground truth are read from `predictions`.
    bev_semantic_loss = _bev_occ_loss(
        predictions["pred_bev_occ"], predictions["gt_bev_occ"]
    )

    # Weight each term once and reuse it for both the total and the report.
    weighted_trajectory_loss = config.trajectory_weight * trajectory_loss
    weighted_bev_semantic_loss = config.bev_semantic_weight * bev_semantic_loss

    loss_dict = {
        "loss": weighted_trajectory_loss + weighted_bev_semantic_loss,
        "trajectory_loss": weighted_trajectory_loss,
        "bev_semantic_loss": weighted_bev_semantic_loss,
    }
    # Extra trajectory diagnostics are merged last, so a colliding key in
    # that dict overwrites the entry set above.
    loss_dict.update(predictions.get("trajectory_loss_dict", {}))

    return loss_dict


def _bev_occ_loss(pred_occ, sampled_label, occ_mask=None, lovasz_loss_weight=0.1):
    """
    Occupancy loss averaged over a sequence of BEV occupancy predictions.

    For every entry of ``pred_occ`` the loss is ``CE_wo_softmax`` (label 255
    ignored) plus ``lovasz_loss_weight`` times ``lovasz_softmax`` (label 0
    ignored). The per-entry losses are then averaged.

    :param pred_occ: non-empty sequence of predictions. Each is presumably
        class probabilities shaped (B, C, N) -- TODO confirm against caller
    :param sampled_label: class labels, presumably shaped (B, N)
    :param occ_mask: optional boolean mask of the points to supervise. It is
        flattened from dim 1 before being used as an index
    :param lovasz_loss_weight: scale applied to the Lovasz-softmax term
    :return: mean of the per-prediction losses
    :raises ValueError: if ``pred_occ`` is empty
    """
    num_preds = len(pred_occ)
    if num_preds == 0:
        # Without this guard the final division fails with ZeroDivisionError.
        raise ValueError("pred_occ must contain at least one prediction")

    if occ_mask is not None:
        occ_mask = occ_mask.flatten(1)
        # Keep only the masked labels and re-add a leading singleton dim.
        sampled_label = sampled_label[occ_mask][None]

    # The labels are the same for every prediction, so convert them once.
    ce_target = sampled_label.long()
    lovasz_target = sampled_label.flatten()

    tot_loss = 0.0
    for semantics in pred_occ:
        if occ_mask is not None:
            # Move the class dim last so the mask selects points, then
            # restore the (1, C, K) layout.
            semantics = semantics.transpose(1, 2)[occ_mask][None].transpose(1, 2)

        ce_loss = CE_wo_softmax(semantics, ce_target, ignore_index=255)
        # lovasz_softmax receives one row per point: (B, C, N) -> (B*N, C),
        # assuming the (B, C, N) layout noted above.
        lovasz_loss = lovasz_loss_weight * lovasz_softmax(
            semantics.transpose(1, 2).flatten(0, 1),
            lovasz_target,
            ignore=0,
        )
        tot_loss = tot_loss + (ce_loss + lovasz_loss)

    return tot_loss / num_preds


def CE_wo_softmax(pred, target, class_weights=None, ignore_index=255):
    """
    Cross-entropy for inputs that are already class probabilities.

    The probabilities are clamped away from 0 and 1 so the logarithm stays
    finite. The result is then passed to ``F.nll_loss``.

    :param pred: class probabilities with the class dimension at dim 1
    :param target: integer class indices matching ``pred`` without its class dim
    :param class_weights: optional per-class rescaling weights
    :param ignore_index: target value excluded from the loss
    :return: mean negative log-likelihood over the non-ignored targets
    """
    eps = 1e-6
    log_probs = pred.clamp(min=eps, max=1.0 - eps).log()
    return F.nll_loss(
        log_probs,
        target,
        weight=class_weights,
        ignore_index=ignore_index,
    )

