from typing import Dict

import torch
from torch import Tensor, nn


class OccupancyDensityLoss(nn.Module):
    def __init__(
        self, occ_weight: float = 1.0, field_weight: float = 1.0, field_loss_occ_only: bool = False
    ):
        """Loss for the OccupancyDensityAutoencoder. Assumes that the non-occupancy class is at
        index 0 of dim 1 of the occupancy target.

        :param occ_weight: Weight for the occupancy loss, defaults to 1.0
        :type occ_weight: float, optional
        :param field_weight: Weight for the field loss, defaults to 1.0
        :type field_weight: float, optional
        :param field_loss_occ_only: If True, compute the field loss only at locations whose
            occupancy target is not the non-occupancy class (``occ[:, 0] == 0``); if False,
            compute it over all locations. Defaults to False
        :type field_loss_occ_only: bool, optional
        """
        super().__init__()
        self.occ_weight: float = occ_weight
        self.field_weight: float = field_weight
        self.field_loss_occ_only: bool = field_loss_occ_only

        self.occ_loss_function: torch.nn.Module = nn.CrossEntropyLoss()
        self.field_loss_function: torch.nn.Module = nn.MSELoss()

    @property
    def field(self) -> float:
        """Backward-compatible alias for ``field_weight`` (the attribute's former name)."""
        return self.field_weight

    @field.setter
    def field(self, value: float) -> None:
        self.field_weight = value

    def forward(
        self, preds_occ: Tensor, preds_field: Tensor, occ: Tensor, field: Tensor
    ) -> Dict[str, Tensor]:
        """Compute the weighted sum of the occupancy loss and the field loss.

        :param preds_occ: Occupancy logits, passed as ``input`` to ``nn.CrossEntropyLoss``
        :type preds_occ: Tensor
        :param preds_field: Predicted field values
        :type preds_field: Tensor
        :param occ: Occupancy target, passed as ``target`` to ``nn.CrossEntropyLoss``. When
            ``field_loss_occ_only`` is set, ``occ[:, 0]`` is used to build the mask, so this
            presumably is a one-hot / class-probability target with classes on dim 1 -- TODO
            confirm against callers
        :type occ: Tensor
        :param field: Target field values; when masking, must be indexable by the
            ``occ[:, 0]``-shaped boolean mask
        :type field: Tensor
        :return: Dict with the total ``"loss"`` plus the unweighted ``"occ_loss"`` and
            ``"field_loss"`` components
        :rtype: Dict[str, Tensor]
        """
        occ_loss = self.occ_loss_function(preds_occ, occ)

        if self.field_loss_occ_only:
            # Keep only locations whose target is not the non-occupancy class (index 0).
            # NOTE(review): with soft (e.g. label-smoothed) targets occ[:, 0] is rarely exactly 0.
            mask = occ[:, 0] == 0
            preds_field = preds_field[mask]
            field = field[mask]

        if self.field_loss_occ_only and preds_field.numel() == 0:
            # No occupied locations in the batch: a mean over zero elements is NaN and would
            # poison the total loss. Use a zero that stays attached to the autograd graph.
            field_loss = preds_field.sum() * 0.0
        else:
            field_loss = self.field_loss_function(preds_field, field)

        loss = self.occ_weight * occ_loss + self.field_weight * field_loss

        return {"loss": loss, "occ_loss": occ_loss, "field_loss": field_loss}
