import torch
import torch.nn as nn
import torch.nn.functional as F
from .lovasz_softmax import lovasz_softmax


class BEV_occ_loss(nn.Module):
    """Occupancy loss: NLL on probabilities plus a weighted Lovasz-softmax term.

    ``pred_occ`` is consumed as per-class probabilities, not logits: it is
    passed to ``CE_wo_softmax`` (which takes ``log`` directly) and to
    ``lovasz_softmax`` without any softmax being applied here.
    """

    def __init__(self):
        super().__init__()

    def forward(self, pred_occ, sampled_label, occ_mask=None, lovasz_loss_weight=0.1):
        """Return ``CE(pred_occ, label) + lovasz_loss_weight * Lovasz(pred_occ, label)``.

        Args:
            pred_occ: class probabilities, assumed laid out as (B, C, N) given
                the transposes below (TODO confirm with callers).
            sampled_label: integer class labels, assumed (B, N).
            occ_mask: optional mask/index over the last (point) axis; when
                given, the same points are selected from predictions and labels.
            lovasz_loss_weight: scale applied to the Lovasz term.

        Returns:
            The summed loss; presumably a scalar tensor, since the CE term uses
            ``F.nll_loss``'s default mean reduction.
        """
        if occ_mask is not None:
            # Select the same points from labels (B, N) and predictions (B, C, N).
            sampled_label = sampled_label[:, occ_mask]
            pred_occ = pred_occ.transpose(1, 2)[:, occ_mask].transpose(1, 2)

        ce_loss = CE_wo_softmax(
            pred_occ,
            sampled_label.long(),
            ignore_index=255,
        )

        # NOTE(review): CE skips label 255 while Lovasz skips label 0, so
        # label-0 points contribute to CE and (unless lovasz_softmax filters it
        # itself) label-255 points reach Lovasz. Confirm this asymmetry is intended.
        lovasz_loss = lovasz_softmax(
            pred_occ.transpose(1, 2).flatten(0, 1),  # (B, C, N) -> (B*N, C)
            sampled_label.flatten(),                 # (B, N) -> (B*N,)
            ignore=0,
        )

        return ce_loss + lovasz_loss_weight * lovasz_loss


def CE_wo_softmax(pred, target, class_weights=None, ignore_index=255, eps=1e-6):
    """Negative log-likelihood loss for inputs that are already probabilities.

    Unlike ``F.cross_entropy``, no softmax is applied. ``pred`` is clamped to
    ``[eps, 1 - eps]`` so that ``log`` stays finite, and the result is passed
    to ``F.nll_loss``.

    Args:
        pred: class probabilities with the class axis at dim 1 (shape rules
            follow ``F.nll_loss``).
        target: integer class indices, shaped like ``pred`` without the class axis.
        class_weights: optional per-class rescaling weights (``F.nll_loss``'s
            ``weight`` argument).
        ignore_index: target value that is excluded from the loss and the mean.
        eps: clamp margin used to keep ``log(pred)`` finite.

    Returns:
        The (weighted) mean loss over non-ignored targets, which is
        ``F.nll_loss``'s default reduction.
    """
    # Clamp away from 0 (log -> -inf) and 1 before taking the log.
    safe_pred = torch.clamp(pred, eps, 1.0 - eps)
    return F.nll_loss(
        torch.log(safe_pred),
        target,
        weight=class_weights,
        ignore_index=ignore_index,
    )
