import torch
import torch.nn.functional as F
from torch import nn


class BBLossFunction(nn.BCEWithLogitsLoss):
    """Binary cross-entropy on logits that defaults to per-element output.

    Thin wrapper around ``nn.BCEWithLogitsLoss``. It differs only in its
    ``reduction="none"`` default and in a ``forward`` that tolerates (and
    ignores) extra keyword arguments.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="none",
        pos_weight=None,
    ):
        """Pass every option through unchanged to ``nn.BCEWithLogitsLoss``."""
        super().__init__(
            weight=weight,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
            pos_weight=pos_weight,
        )

    def forward(self, pred, label, **kwargs):
        """Return the BCE-with-logits loss of ``pred`` against ``label``.

        Args:
            pred: Logits handed to the parent loss as its input.
            label: Targets handed to the parent loss.
            **kwargs: Accepted and ignored.

        Returns:
            Whatever the parent loss returns for the configured reduction.
        """
        return super().forward(pred, label)


class BBRankingLossFunction(nn.MarginRankingLoss):
    """Pairwise margin ranking loss taking its two inputs as one pair.

    Wraps ``nn.MarginRankingLoss`` so the two score tensors arrive packed
    together in ``pred`` rather than as separate arguments.
    """

    def __init__(self, margin=0, size_average=None, reduce=None, reduction="mean"):
        """Pass every option through unchanged to ``nn.MarginRankingLoss``."""
        super().__init__(
            margin=margin,
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    def forward(self, pred, label, **kwargs):
        """Return the margin ranking loss for a pair of score tensors.

        Args:
            pred: Two-element iterable ``(first, second)`` of score tensors.
            label: Target passed to the parent loss as its third argument.
            **kwargs: Accepted and ignored.
        """
        first_scores, second_scores = pred
        return super().forward(first_scores, second_scores, label)


class BBListRankingLossFunction(nn.CrossEntropyLoss):
    """List-wise soft cross-entropy over the second prediction head.

    Although it subclasses ``nn.CrossEntropyLoss`` and forwards the
    constructor options to it, ``forward`` computes the loss by hand and
    never calls the parent's ``forward``.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        ignore_index=-100,
        reduce=None,
        reduction="none",
        label_smoothing=0,
    ):
        """Pass every option through unchanged to ``nn.CrossEntropyLoss``."""
        super().__init__(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
            label_smoothing=label_smoothing,
        )

    def forward(self, pred, scoring_label, scoring_mask, **kwargs):
        """Return the batch mean of the masked, count-normalised NLL.

        Args:
            pred: Two-element iterable; only the second element is used.
            scoring_label: Per-slot target weights multiplied into the
                log-probabilities.
            scoring_mask: Per-slot multiplier; zero entries drop a slot
                from the sum.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor: mean over rows of ``nll / (mask_count + 1)``.
        """
        decimation_pred = pred[1]
        # Softmax runs over dim 1 of the raw scores; the mask is applied only
        # afterwards, so masked slots still take part in the normalisation.
        slot_log_probs = torch.log_softmax(decimation_pred, dim=1)
        weighted = slot_log_probs * scoring_label * scoring_mask
        row_nll = -weighted.sum(dim=-1)
        # NOTE(review): the "+ 1" keeps the denominator non-zero for fully
        # masked rows but also shrinks every other row's loss - confirm this
        # is intended rather than a clamp(min=1).
        denominators = scoring_mask.sum(dim=-1) + 1
        per_row = row_nll / denominators
        return per_row.mean()


class BBListRankingLossFunctionV2(nn.KLDivLoss):
    """Masked list-wise KL-divergence loss over the second prediction head.

    The KL term is always evaluated element-wise so that ``scoring_mask`` can
    exclude individual slots. The constructor's ``reduction`` is still
    forwarded to ``nn.KLDivLoss`` for interface compatibility, but ``forward``
    performs its own masked reduction and does not consult it.
    """

    def __init__(
        self, size_average=None, reduce=None, reduction="mean", log_target=False
    ):
        """Pass every option through unchanged to ``nn.KLDivLoss``."""
        super().__init__(size_average, reduce, reduction, log_target)

    def forward(self, pred, scoring_label, scoring_mask, **kwargs):
        """Return the masked KL divergence averaged over rows with valid slots.

        Args:
            pred: Two-element iterable; only the second element is used.
            scoring_label: Target distribution (log-space when the loss was
                built with ``log_target=True``).
            scoring_mask: Per-slot multiplier; zero entries are excluded.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor. Each row's masked KL sum is divided by its number
            of valid slots, and the row losses are averaged over the rows
            that have at least one non-zero mask entry. Both denominators
            are clamped to 1, so a fully masked batch yields 0.
        """
        _, decimation_pred = pred
        log_probs = torch.log_softmax(decimation_pred, dim=1)
        # Force an element-wise result: with a reducing mode such as the
        # default "mean", the parent forward returns a scalar averaged over
        # every slot (masked ones included), which makes the mask a no-op.
        elementwise_kl = F.kl_div(
            log_probs,
            scoring_label,
            reduction="none",
            log_target=self.log_target,
        )
        masked_kl_div = elementwise_kl * scoring_mask
        loss = masked_kl_div.sum(dim=1) / scoring_mask.sum(dim=1).clamp(min=1)
        return loss.sum() / torch.any(scoring_mask != 0, dim=1).sum().clamp(min=1)


class BBListMLClassficationLossFunction(nn.BCEWithLogitsLoss):
    """Masked multi-label BCE-with-logits over the first prediction head.

    The BCE term is always evaluated element-wise so that ``ml_class_mask``
    can zero out individual entries. The constructor's ``reduction`` is still
    forwarded to ``nn.BCEWithLogitsLoss`` for interface compatibility, but
    ``forward`` performs its own reduction and does not consult it.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        """Pass every option through unchanged to ``nn.BCEWithLogitsLoss``."""
        super().__init__(weight, size_average, reduce, reduction, pos_weight)

    def forward(self, pred, ml_class_label, ml_class_mask, **kwargs):
        """Return the mean of the element-wise BCE after masking.

        Args:
            pred: Two-element iterable; only the first element is used.
            ml_class_label: Targets for the BCE term.
            ml_class_mask: Per-element multiplier; zero entries contribute
                zero loss.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor: ``(bce * mask).mean()``. The mean runs over all
            elements, masked ones included, which matches what this class
            already produced when built with ``reduction="none"``.
        """
        optimality_pred, _ = pred
        # Force an element-wise result: with a reducing mode such as the
        # default "mean", the parent forward returns a scalar that already
        # averages masked entries in, so the mask could only rescale it.
        elementwise_bce = F.binary_cross_entropy_with_logits(
            optimality_pred,
            ml_class_label,
            self.weight,
            pos_weight=self.pos_weight,
            reduction="none",
        )
        return (elementwise_bce * ml_class_mask).mean()


class BBListBCELossFunction(nn.BCEWithLogitsLoss):
    """Masked log-sigmoid loss over the second prediction head.

    Although it subclasses ``nn.BCEWithLogitsLoss`` and forwards the
    constructor options to it, ``forward`` computes the loss by hand and
    never calls the parent's ``forward``.
    """

    def __init__(
        self,
        weight=None,
        size_average=None,
        reduce=None,
        reduction="mean",
        pos_weight=None,
    ):
        """Pass every option through unchanged to ``nn.BCEWithLogitsLoss``."""
        super().__init__(weight, size_average, reduce, reduction, pos_weight)

    def forward(self, pred, label, mask, **kwargs):
        """Return the batch mean of the masked, count-normalised NLL.

        Args:
            pred: Two-element iterable; only the second element is used.
            label: Per-slot target weights multiplied into the
                log-sigmoid scores.
            mask: Per-slot multiplier; zero entries drop a slot.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor: mean over rows of ``nll / valid_count``, where a
            row with no valid slot contributes 0.
        """
        _, decimation_pred = pred
        # Only the label-weighted log(sigmoid(x)) term is accumulated; there
        # is no (1 - label) * log(1 - sigmoid(x)) counterpart.
        log_probs = F.logsigmoid(decimation_pred)
        nll = -(log_probs * label * mask).sum(dim=-1)
        # Clamp so a fully masked row divides 0 by 1 instead of 0 by 0, which
        # would otherwise turn the whole batch mean into NaN.
        valid_counts = mask.sum(dim=-1).clamp(min=1)
        return (nll / valid_counts).mean()


class BBMSELoss(nn.MSELoss):
    """Square root of ``nn.MSELoss`` (an RMSE under the default reduction)."""

    def __init__(self, size_average=None, reduce=None, reduction="mean"):
        """Pass every option through unchanged to ``nn.MSELoss``."""
        super().__init__(
            size_average=size_average,
            reduce=reduce,
            reduction=reduction,
        )

    def forward(self, pred, child_nodes, **kwargs):
        """Return ``sqrt`` of the parent MSE between ``pred`` and ``child_nodes``.

        Args:
            pred: Predicted values.
            child_nodes: Target values.
            **kwargs: Accepted and ignored.
        """
        squared_error = super().forward(pred, child_nodes)
        return torch.sqrt(squared_error)


class BBListMLELoss(nn.Module):
    """Masked list-wise loss built from suffix log-sum-exps of ordered scores.

    NOTE(review): a textbook ListMLE term is ``logsumexp(suffix) - score``;
    this implementation sums only the log-sum-exp part and floors it at 0.
    That behaviour is preserved as-is - confirm it is intentional.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to ``nn.Module.__init__``."""
        super().__init__(*args, **kwargs)

    def forward(self, pred, scoring_indices, scoring_mask, **kwargs):
        """Return the masked loss averaged over rows with a valid slot.

        Args:
            pred: Two-element iterable; only the second element is used.
                It is not modified.
            scoring_indices: Index tensor used with ``torch.gather`` along
                dim 1 to put scores and mask into ranking order.
            scoring_mask: Per-slot validity; zero entries are excluded.
            **kwargs: Accepted and ignored.

        Returns:
            Scalar tensor. Each row's summed terms are divided by its number
            of valid slots, then averaged over the rows that have at least
            one non-zero mask entry. Both denominators are clamped to 1, so
            a fully masked batch yields 0 instead of NaN.
        """
        _, decimation_pred = pred
        # Out-of-place fill: the prediction tensor belongs to the caller and
        # may feed other losses, so it must not be overwritten with -inf.
        masked_pred = decimation_pred.masked_fill(scoring_mask == 0, float("-inf"))
        pred_sorted = torch.gather(masked_pred, 1, scoring_indices)
        mask_sorted = torch.gather(scoring_mask, 1, scoring_indices)

        log_cumulative_sum_exp = torch.zeros_like(pred_sorted)
        list_length = log_cumulative_sum_exp.size(1)
        for i in range(list_length):
            # For each position i, consider all scores from i to end.
            # Since invalid positions are already -inf, they won't contribute.
            current_slice = pred_sorted[:, i:]
            log_cumulative_sum_exp[:, i] = torch.logsumexp(current_slice, dim=1)

        # Replaces every non-positive entry with 0. This removes the -inf of
        # all-invalid suffixes (which would give NaN once multiplied by a
        # zero mask) but also zeroes legitimate negative values.
        log_cumulative_sum_exp = torch.where(
            log_cumulative_sum_exp > 0, log_cumulative_sum_exp, 0
        )
        masked_loss_terms = log_cumulative_sum_exp * mask_sorted
        listmle_loss = torch.sum(masked_loss_terms, dim=1) / mask_sorted.sum(
            dim=1
        ).clamp(min=1)
        # Clamp the row count as well so an all-masked batch divides by 1.
        valid_rows = torch.any(scoring_mask != 0, dim=1).sum().clamp(min=1)
        return listmle_loss.sum() / valid_rows
