from abc import ABC
from typing import Any, Dict

import torch
from torch import Tensor, nn
from torch.nn.functional import binary_cross_entropy_with_logits

# Result type of every loss in this module: a name -> tensor mapping that always
# contains the key 'loss' (the value to optimize) and may hold extra diagnostics.
TensorDict = Dict[str, Tensor]


class PuLoss(nn.Module, ABC):
    """
    Abstract base class for a PU loss.

    Subclasses override :meth:`forward` and return a dict whose ``'loss'`` entry
    is the value to optimize; additional entries may carry diagnostics.
    """

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """
        Compute the loss using
         - phat: the predictions (logits)
         - y: the pseudo labels
         - p: True for samples that were the original positives
         - l: True for samples that are pseudo-labeled
         - w: per-sample weights (each subclass decides whether to use them)
         - kwargs: other, possibly irrelevant, entries of the batch

        Returns:
            A dict containing at least the key ``'loss'``.

        Raises:
            NotImplementedError: always; concrete losses must override this.
        """
        # Without this, the base implementation would silently return None and the
        # caller would only fail later when indexing the result with ['loss'].
        raise NotImplementedError(
            f'{type(self).__name__} must implement forward()'
        )


class EmiLoss(PuLoss):
    """
    Binary cross-entropy on temperature-scaled logits, optionally re-weighted per group.

    When ``weighted`` is set, a sample that is neither an original positive (``p``)
    nor pseudo-labeled (``l``) gets ``unlab_weight``. Any other sample gets
    ``pos_weight`` if its label ``y`` is >= 0.5, and weight 1 if ``y`` is < 0.5.
    The per-sample weights ``w`` passed to :meth:`forward` are not used.
    """

    def __init__(self, weighted: bool = True, pos_weight: float = 0.5,
                 unlab_weight: float = 0.5, temp: float = 1.0) -> None:
        super().__init__()
        self.weighted = weighted
        self.pos_weight = pos_weight
        self.unlab_weight = unlab_weight
        self.temp = temp

    def _get_weights(self, y: Tensor, is_unlab: Tensor, weighted: bool) -> Tensor:
        """Return the per-sample BCE weights (all ones when ``weighted`` is False)."""
        if not weighted:
            return torch.ones_like(y)

        unlab_mask = is_unlab > 0.5
        lab_mask = ~unlab_mask
        pos_label = y >= 0.5
        neg_label = y < 0.5
        return (
            unlab_mask * self.unlab_weight
            + (lab_mask & pos_label) * self.pos_weight
            + (lab_mask & neg_label).float()
        )

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """Return ``{'loss': mean weighted BCE of phat / temp against y}``."""
        unlabeled = ~p & ~l
        sample_weights = self._get_weights(y=y, is_unlab=unlabeled,
                                           weighted=self.weighted)
        scaled_logits = phat / self.temp
        bce = binary_cross_entropy_with_logits(
            scaled_logits,
            y.float(),
            weight=sample_weights,
            reduction='mean',
        )
        return {'loss': bce}


# code inspired by
# https://github.com/kiryor/nnPUlearning
# https://github.com/cimeister/pu-learning
def _lsig(y: Tensor, phat: Tensor) -> Tensor:
    """
    Element-wise sigmoid surrogate loss ``sigmoid(-t * phat)``.

    The margin sign ``t`` is +1 where ``y > 0.5`` and -1 everywhere else, so the
    loss is small when ``phat`` points in the direction of the label.
    """
    # torch.sigmoid(x) = 1 / (1 + exp(-x)), while the sigmoid loss is
    # l(z) = 1 / (1 + exp(z)) = sigmoid(-z); hence the negated margin below.
    sign = torch.where(y > 0.5, 1, -1)
    margin = phat * sign
    return torch.sigmoid(-margin)


class BCEWithLogitsLoss(PuLoss):
    """
    Plain binary cross-entropy on logits against the pseudo labels ``y``.

    ``p``, ``l`` and the per-sample weights ``w`` are ignored. ``pos_weight`` is
    forwarded to ``binary_cross_entropy_with_logits`` to rescale the positive term.
    """
    def __init__(self, pos_weight: float = 1.0) -> None:
        super().__init__()
        self.pos_weight = torch.tensor(pos_weight)

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """Return ``{'loss': mean BCE of phat against y}``."""
        # BCE needs a floating target matching the logits' dtype: bool/integer
        # pseudo labels would otherwise fail inside the kernel (EmiLoss likewise
        # casts its target before calling BCE). No-op when dtypes already match.
        target = y.to(phat.dtype)
        loss = binary_cross_entropy_with_logits(
            phat, target, pos_weight=self.pos_weight  # type: ignore[arg-type]
        )
        return {'loss': loss}


class uPU(PuLoss):
    """
    Implementation of uPU following Kiryo 2017

    Unbiased PU risk with the sigmoid surrogate loss,
    ``prior * R_p^+ + (R_u^- - prior * R_p^-)``. The "positive" group is every
    sample that is an original positive or pseudo-labeled (``u == 0``), and the
    "unlabeled" group is everything else. The per-sample weights ``w`` are not used.
    """
    def __init__(self, prior: float = 0.5) -> None:
        super().__init__()
        self.prior = prior
        self.lsig = _lsig

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """Return ``{'loss': unbiased PU risk}``; an empty group contributes zero."""
        u = ~p & ~l
        is_pos = u == 0
        is_unl = u == 1

        # separate data in positive (+1) and unlabelled (-1) samples based on u
        y_p = y[is_pos]
        phat_p = phat[is_pos]
        phat_u = phat[is_unl]
        # Targets of -1 for every unlabeled sample. They are built from phat (a
        # floating tensor) rather than from y: filling a bool/unsigned y with -1
        # cannot represent the negative sign, which would flip these samples to +1.
        y_u = torch.full_like(phat_u, -1)

        # calculate the three losses; the zero default lives on phat's
        # device/dtype so an empty group does not introduce a stray CPU tensor
        zero = phat.new_zeros(())
        R_p_plus = R_p_minus = R_u_minus = zero
        if len(y_p) > 0:
            R_p_plus = torch.mean(self.lsig(y=y_p, phat=phat_p))
            R_p_minus = torch.mean(self.lsig(y=y_p, phat=-phat_p))
        if len(y_u) > 0:
            R_u_minus = torch.mean(self.lsig(y=y_u, phat=phat_u))

        # eq 2 from kiryo2017, preparation for nnPU
        positive_risk = self.prior * R_p_plus
        negative_risk = R_u_minus - self.prior * R_p_minus
        R_pu = positive_risk + negative_risk

        return {'loss': R_pu}


class nnPU(PuLoss):
    """
    Implementation of nnPU following Kiryo 2017

    Same risk decomposition as uPU, except that each sigmoid-loss term is
    multiplied by the per-sample weight ``w`` before averaging. When the negative
    risk falls below ``-beta``, the returned loss is ``-negative_risk`` (the
    non-negative correction) instead of the full PU risk.
    """
    def __init__(self, prior: float = 0.5, beta: float = 0.0) -> None:
        """
        Args:
            prior: class prior of the positive class.
            beta: tolerance of the non-negativity constraint on the negative risk.
        """
        super().__init__()
        self.prior = prior
        self.lsig = _lsig
        self.beta = beta

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """
        Return ``{'loss': ..., 'correction': ...}``, where ``correction`` is the
        boolean tensor telling whether the non-negative correction was applied.
        """
        u = ~p & ~l
        is_pos = u == 0
        is_unl = u == 1

        # separate data in positive (+1) and unlabelled (-1) samples based on u
        y_p = y[is_pos]
        phat_p = phat[is_pos]
        w_p = w[is_pos]
        phat_u = phat[is_unl]
        w_u = w[is_unl]
        # Targets of -1 for every unlabeled sample. They are built from phat (a
        # floating tensor) rather than from y: filling a bool/unsigned y with -1
        # cannot represent the negative sign, which would flip these samples to +1.
        y_u = torch.full_like(phat_u, -1)

        # calculate the three losses; the zero default lives on phat's
        # device/dtype so an empty group does not introduce a stray CPU tensor
        zero = phat.new_zeros(())
        R_p_plus = R_p_minus = R_u_minus = zero
        if len(y_p) > 0:
            R_p_plus = torch.mean(w_p * self.lsig(y=y_p, phat=phat_p))
            R_p_minus = torch.mean(w_p * self.lsig(y=y_p, phat=-phat_p))
        if len(y_u) > 0:
            R_u_minus = torch.mean(w_u * self.lsig(y=y_u, phat=phat_u))

        # preparation for nnPU
        positive_risk = self.prior * R_p_plus
        negative_risk = R_u_minus - self.prior * R_p_minus
        R_pu = positive_risk + negative_risk

        # nnPU constraint from kiryo2017: push the negative risk back up
        # instead of optimizing the (overfitting) full risk
        correction = negative_risk < -self.beta
        loss = -negative_risk if correction else R_pu

        return {'loss': loss, 'correction': correction}


class nnPUSBloss(PuLoss):
    """
    Loss function for PUSB learning, adapted from
    https://github.com/Scottdyt/nnPUSB/blob/master/nnPU_loss.py

    Log-loss surrogates on clamped sigmoid outputs plus the non-negative risk
    correction. Samples with ``p`` form the positive set. Samples that are
    neither ``p`` nor ``l`` form the unlabeled set, so the two sets are disjoint.
    Pseudo-labeled non-positives are rejected. ``w`` is not used.
    """

    def __init__(self, prior: float, gamma: float = 1.0, beta: float = 0.0):
        """
        Args:
            prior: class prior of the positive class, must lie in (0, 1).
            gamma: scale applied to the negative risk when the correction triggers.
            beta: the correction triggers when the negative risk is below -beta.

        Raises:
            ValueError: if ``prior`` is not in (0, 1).
        """
        super().__init__()
        if not 0 < prior < 1:
            raise ValueError("The class prior should be in (0, 1)")

        self.prior = prior
        self.gamma = gamma
        self.beta = beta
        self.eps = 1e-7

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """
        Return the training loss together with the positive and negative risks.

        Raises:
            RuntimeError: if some sample is pseudo-labeled but not an original positive.
        """
        # clip the predict value to make the following optimization problem well-defined.
        yhat = torch.clamp(torch.sigmoid(phat),
                           min=self.eps, max=1 - self.eps)

        positive = p
        # Unlabeled = neither an original positive nor pseudo-labeled (the same
        # ``~p & ~l`` mask as uPU/nnPU in this module). Using ``~l`` alone would
        # also count positives whose ``l`` is False, so they would enter both the
        # positive and the unlabeled risk; the reference code keeps P and U disjoint.
        unlabeled = ~p & ~l
        if not torch.all(positive | unlabeled):
            raise RuntimeError('loss is only defined for PU data, got something labeled')

        # at least 1 to avoid dividing by zero for an empty group
        n_positive = max(1., positive.sum().item())
        n_unlabeled = max(1., unlabeled.sum().item())

        f_positive = -torch.log(yhat)
        f_unlabeled = -torch.log(1 - yhat)

        positive_risk = torch.sum(self.prior * positive / n_positive * f_positive)
        negative_risk = torch.sum(
            (unlabeled / n_unlabeled - self.prior * positive / n_positive) * f_unlabeled
        )

        # nnPU learning
        if negative_risk.item() < -self.beta:
            loss = -self.gamma * negative_risk
        else:
            loss = positive_risk + negative_risk

        return {
            'loss': loss,
            'positive_risk': positive_risk,
            'negative_risk': negative_risk
        }


class SelfPu(PuLoss):
    """
    Self-PU-inspired loss, which uses cross-entropy for pseudo-labeled samples
    and nnPU loss for the original positives and the remaining unlabeled samples

    The two parts are mixed as ``xe_weight * xe + (1 - xe_weight) * nnpu``. A part
    whose sample group is empty contributes a zero tensor.
    """

    def __init__(self, prior: float = 0.5, beta: float = 0.0, xe_weight: float = 0.5):
        super().__init__()
        self._nnpu = nnPU(prior, beta)
        self.xe_weight = xe_weight

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """Return the mixed loss and both of its components."""
        in_pu = p | ~l      # original positives plus samples still unlabeled
        in_xe = ~in_pu      # pseudo-labeled samples that are not original positives

        nnpu_loss = torch.tensor(0.0)
        if in_pu.any():
            nnpu_out = self._nnpu(phat[in_pu], y[in_pu], p[in_pu], l[in_pu], w[in_pu])
            nnpu_loss = nnpu_out['loss']
            assert torch.isfinite(nnpu_loss)

        xe_loss = torch.tensor(0.0)
        if not in_pu.all():
            # third positional argument of BCE is the per-sample weight
            xe_loss = binary_cross_entropy_with_logits(phat[in_xe], y[in_xe], w[in_xe])
            assert torch.isfinite(xe_loss)

        mix = self.xe_weight
        loss = mix * xe_loss + (1 - mix) * nnpu_loss
        return {'loss': loss, 'nnpu_loss': nnpu_loss, 'xe_loss': xe_loss}


class CombinedLoss(PuLoss):
    """
    Weighted sum of two configurable sub-losses.

    The pseudo-labeled loss is applied to samples that are pseudo-labeled and not
    original positives. The positive-unlabeled loss is applied to all remaining
    samples (``p | ~l``). The result is
    ``weight * pl_loss + (1 - weight) * pu_loss``; a sub-loss whose sample group
    is empty contributes zero. Sub-loss outputs are re-exported with the
    prefixes ``pu_`` and ``pl_``.
    """
    def __init__(self, pseudo_labeled_loss_weight: float,
                 pseudo_labeled_loss: Dict[str, Any],
                 positive_unlabeled_loss: Dict[str, Any]):
        """
        Args:
            pseudo_labeled_loss_weight: weight of the pseudo-labeled term; the
                PU term gets ``1 - pseudo_labeled_loss_weight``.
            pseudo_labeled_loss: config passed to ``get_loss`` for the PL term.
            positive_unlabeled_loss: config passed to ``get_loss`` for the PU term.
        """
        super().__init__()
        self.pl_loss = get_loss(pseudo_labeled_loss)
        self.pu_loss = get_loss(positive_unlabeled_loss)
        self.weight = pseudo_labeled_loss_weight

    def forward(self, phat: Tensor, y: Tensor, p: Tensor, l: Tensor, w: Tensor,
                **kwargs: Any) -> TensorDict:
        """Return the combined loss plus every entry of both sub-loss outputs."""
        pu = p | ~l
        pu_loss: TensorDict = {'loss': torch.tensor(0.0)}
        if pu.any():
            pu_loss = self.pu_loss(phat[pu], y[pu], p[pu], l[pu], w[pu])
            assert torch.isfinite(pu_loss['loss'])

        pl_loss: TensorDict = {'loss': torch.tensor(0.0)}
        if not pu.all():
            pl = ~pu
            # Arguments follow the PuLoss signature (phat, y, p, l, w): the third
            # one is the positive mask, not the weights.
            pl_loss = self.pl_loss(phat[pl], y[pl], p[pl], l[pl], w[pl])
            assert torch.isfinite(pl_loss['loss'])

        loss = self.weight * pl_loss['loss'] + (1 - self.weight) * pu_loss['loss']
        return {
            'loss': loss,
            **{f'pu_{k}': v for k, v in pu_loss.items()},
            **{f'pl_{k}': v for k, v in pl_loss.items()},
        }


class ECELoss(nn.Module):
    """
    Expected Calibration Error (ECE) of a binary classifier.

    ``forward`` takes raw logits of the positive class (NOT probabilities) and
    0/1 labels. Each sample's confidence is ``max(sigmoid, 1 - sigmoid)`` and its
    prediction is the matching class index. Confidences are split into
    ``n_bins`` equal-width intervals ``(lower, upper]`` over [0, 1]. Each
    non-empty bin contributes
    ``fraction_of_samples_in_bin * |avg_confidence - accuracy| ** order``.
    With ``order=1`` this is the ECE of Naeini, Cooper and Hauskrecht,
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning", AAAI 2015.
    It is only a diagnostic metric, not needed for temperature scaling.
    """
    def __init__(self, n_bins: int = 30, order: int = 1) -> None:
        """
        n_bins (int): number of confidence interval bins
        order (int): exponent applied to each bin's |confidence - accuracy| gap
        """
        super().__init__()
        edges = torch.linspace(0, 1, n_bins + 1)
        self.bin_lowers = edges[:-1]
        self.bin_uppers = edges[1:]
        self.order = order

    def forward(self, logits: Tensor, labels: Tensor) -> Tensor:
        """Return the calibration error as a tensor of shape ``(1,)``."""
        pos_prob = torch.sigmoid(logits)
        two_class = torch.stack([1 - pos_prob, pos_prob], dim=1)
        confidences, predictions = two_class.max(dim=1)
        correct = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        for lower, upper in zip(self.bin_lowers, self.bin_uppers):
            # samples whose confidence falls in (lower, upper]
            members = confidences.gt(lower.item()) & confidences.le(upper.item())
            share = members.float().mean()
            if share.item() > 0:
                bin_accuracy = correct[members].float().mean()
                bin_confidence = confidences[members].mean()
                gap = torch.abs(bin_confidence - bin_accuracy)
                ece += share * gap**self.order

        return ece


# Registry consulted by get_loss: lower-case config name -> loss class.
# 'nn_pusb' and 'nnpusb' are two spellings of the same loss.
losses_dict = dict(
    xe=EmiLoss,
    nnpu=nnPU,
    upu=uPU,
    selfpu=SelfPu,
    combined_loss=CombinedLoss,
    binary_cross_entropy=BCEWithLogitsLoss,
    nn_pusb=nnPUSBloss,
    nnpusb=nnPUSBloss,
)


def get_loss(config: Dict[str, Any]) -> PuLoss:
    """
    Build a loss module from a config mapping.

    ``config['class']`` selects the loss in ``losses_dict`` (case-insensitive).
    Constructor keyword arguments come from ``config['params']`` when that key is
    present (``None`` or empty means "use the defaults"). Otherwise, every
    remaining key of ``config`` is passed as a keyword argument.

    The caller's mapping is not modified, so the same config can be reused.

    Raises:
        KeyError: if ``'class'`` is missing or does not name a registered loss.
    """
    # work on a shallow copy: popping 'class' from the caller's dict would make
    # a second call with the same config fail
    options = dict(config)
    name = options.pop('class').lower()
    if name not in losses_dict:
        raise KeyError(
            f"unknown loss class {name!r}; available: {sorted(losses_dict)}"
        )
    cls = losses_dict[name]

    if 'params' in options:
        params = options['params'] or {}
    else:
        params = options
    return cls(**params)
