from abc import abstractmethod

import torch
from torch import Tensor, nn


class CustomLoss(nn.Module):
    """
    Base class for custom loss functions.
    """

    # Supported reduction methods
    reduction_methods: dict[str, str] = {
        "mean": "mean",
        "sum": "sum",
        "none": "none",
    }

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        assert reduction in self.reduction_methods, (
            f"Reduction method {reduction} not supported."
        )
        self.reduction = reduction

    def reduce(self, loss: Tensor) -> Tensor:
        """
        Reduce the loss according to the specified reduction method.

        Args:
            loss (Tensor): The loss tensor to be reduced.

        Returns:
            Tensor: The reduced loss tensor.
        """
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

    @abstractmethod
    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
        """
        Compute the loss between the predictions and targets.

        Args:
            predictions (Tensor): The predicted values.
            targets (Tensor): The target values.

        Returns:
            Tensor: The computed loss.
        """


class GaussianCustomLoss(CustomLoss):
    """
    Base class for custom Gaussian loss functions.

    Predictions and targets are handled independently:

    * if ``std_p`` (resp. ``std_t``) is None, the predictions (resp. targets)
      tensor is assumed to be the concatenation, along the last dimension, of
      the means and the standard deviations of the Gaussian distributions;
    * otherwise the tensor is taken as the means and the constant given to
      the constructor is used as the standard deviation.
    """

    def __init__(
        self,
        std_p: float | None = None,
        std_t: float | None = None,
        log: bool = False,
        reduction: str = "mean",
    ) -> None:
        """
        Args:
            std_p (float | None): Fixed standard deviation of the predictions,
                or None to read it from the predictions tensor.
            std_t (float | None): Fixed standard deviation of the targets,
                or None to read it from the targets tensor.
            log (bool): If True, the sigmas are exponentiated before use.
            reduction (str): Reduction method forwarded to the parent class.

        Raises:
            AssertionError: If the standard deviations are rejected by
                ``set_stds`` or the reduction by the parent constructor.
        """
        super().__init__(reduction=reduction)
        self.set_stds(std_p, std_t)
        self.log = log

    def set_stds(self, std_p: float | None = None, std_t: float | None = None) -> None:
        """Set the standard deviations of the Gaussian distributions.

        Args:
            std_p (float | None): The standard deviation of the predictions.
            std_t (float | None): The standard deviation of the targets.

        Raises:
            AssertionError: If a given standard deviation is zero, or if
                ``std_p`` is given and equals ``std_t``.
        """
        assert std_p is None or std_p != 0, (
            "Standard deviation of predictions must not be zero."
        )
        assert std_t is None or std_t != 0, (
            "Standard deviation of targets must not be zero."
        )
        assert std_p is None or std_p != std_t, (
            "If STDs are the same use the SameSTD version of the loss."
        )
        self.std_p = std_p
        self.std_t = std_t

    def _extract_mus_and_sigmas(
        self, predictions: Tensor, targets: Tensor
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        """
        Split the inputs into means and standard deviations.

        For each of predictions/targets: when the matching fixed std is None,
        the tensor is chunked in two along the last dimension (first half =
        mus, second half = sigmas); otherwise the whole tensor is the mus and
        the sigma is a one-element tensor holding the fixed std. If log is
        True, the sigmas are exponentiated.

        Args:
            predictions (Tensor): the predictions tensor.
            targets (Tensor): the targets tensor.

        Returns:
            tuple[Tensor, Tensor, Tensor, Tensor]: the means and standard
            deviations of the predictions and targets in the order
            (mu_p, sigma_p, mu_t, sigma_t).
        """
        if self.std_p is None:
            mu_p, sigma_p = predictions.chunk(2, dim=-1)
        else:
            mu_p = predictions
            # Build the constant on the input's device: a shape-[1] tensor on
            # the default device cannot be combined with inputs living on
            # another device.
            sigma_p = torch.tensor(
                [self.std_p], dtype=predictions.dtype, device=predictions.device
            )
        if self.std_t is None:
            mu_t, sigma_t = targets.chunk(2, dim=-1)
        else:
            mu_t = targets
            sigma_t = torch.tensor(
                [self.std_t], dtype=targets.dtype, device=targets.device
            )
        if self.log:
            # NOTE: the exponential is applied to both sigmas, including the
            # ones built from the fixed constructor values.
            sigma_p = sigma_p.exp()
            sigma_t = sigma_t.exp()

        return mu_p, sigma_p, mu_t, sigma_t


class GaussianRenyiAlphaDivLoss(GaussianCustomLoss):
    """
    Renyi-alpha divergence loss between two Gaussian distributions, following
    Eq. 10 in https://arxiv.org/pdf/1206.2459.

    The means and standard deviations are obtained through
    ``_extract_mus_and_sigmas``: a tensor whose fixed std (``std_p`` /
    ``std_t``) is not given is read as the concatenation of means and standard
    deviations, otherwise it is read as the means only.
    """

    def __init__(
        self,
        alpha: float,
        std_p: float | None = None,
        std_t: float | None = None,
        log: bool = False,
        reduction: str = "mean",
    ) -> None:
        """
        Args:
            alpha (float): Order of the divergence; positive and not 1.
            std_p (float | None): Fixed standard deviation of the predictions.
            std_t (float | None): Fixed standard deviation of the targets.
            log (bool): Forwarded to the parent class.
            reduction (str): Reduction method forwarded to the parent class.

        Raises:
            AssertionError: If ``alpha`` is not positive or equals 1, or if
                the parent constructor rejects its arguments.
        """
        super().__init__(std_p=std_p, std_t=std_t, log=log, reduction=reduction)
        assert alpha > 0, "Alpha must be greater than 0."
        assert alpha != 1, (
            "Alpha must not be equal to 1. Use GaussianKLDivLoss instead."
        )
        self.alpha = alpha

    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
        """
        Compute the reduced Renyi-alpha divergence.

        Args:
            predictions (Tensor): The predictions tensor.
            targets (Tensor): The targets tensor.

        Returns:
            Tensor: Sum of the quadratic and logarithmic terms, reduced with
            the configured reduction method.
        """
        mean_p, std_p, mean_t, std_t = self._extract_mus_and_sigmas(
            predictions, targets
        )
        order = self.alpha

        # Alpha-weighted combination of the two variances and its square root.
        mixed_var = (1 - order) * std_p.pow(2) + order * std_t.pow(2)
        mixed_std = mixed_var.sqrt()

        # Term driven by the squared distance between the means.
        quadratic = order * (mean_t - mean_p).pow(2) / (2 * mixed_var)

        # Term driven by the standard deviations only.
        weighted_std_product = std_p.pow(1 - order) * std_t.pow(order)
        log_ratio = torch.log(mixed_std / weighted_std_product)
        logarithmic = (1 / (1 - order)) * log_ratio

        return self.reduce(quadratic + logarithmic)


class GaussianKLDivLoss(GaussianCustomLoss):
    """
    KL divergence loss between two Gaussian distributions, following the
    unnumbered equation after Eq. 18 in https://arxiv.org/pdf/1206.2459.

    The means and standard deviations are obtained through
    ``_extract_mus_and_sigmas``: a tensor whose fixed std (``std_p`` /
    ``std_t``) is not given is read as the concatenation of means and standard
    deviations, otherwise it is read as the means only.
    """

    def __init__(
        self,
        std_p: float | None = None,
        std_t: float | None = None,
        log: bool = False,
        reduction: str = "mean",
    ) -> None:
        """
        Args:
            std_p (float | None): Fixed standard deviation of the predictions.
            std_t (float | None): Fixed standard deviation of the targets.
            log (bool): Forwarded to the parent class.
            reduction (str): Reduction method forwarded to the parent class.
        """
        super().__init__(std_p=std_p, std_t=std_t, log=log, reduction=reduction)

    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
        """
        Compute the reduced KL divergence.

        Args:
            predictions (Tensor): The predictions tensor.
            targets (Tensor): The targets tensor.

        Returns:
            Tensor: The divergence, reduced with the configured method.
        """
        mean_p, std_p, mean_t, std_t = self._extract_mus_and_sigmas(
            predictions, targets
        )
        variance_p = std_p.pow(2)
        variance_t = std_t.pow(2)

        # The three contributions, each scaled by the target variance.
        mean_term = (mean_t - mean_p).pow(2) / variance_t
        log_term = torch.log(variance_t / variance_p)
        ratio_term = variance_p / variance_t

        divergence = 0.5 * (mean_term + log_term + ratio_term - 1)
        return self.reduce(divergence)


class SameSTDGaussianRenyiAlphaDivLoss(CustomLoss):
    """
    Renyi-alpha divergence loss between two Gaussian distributions sharing one
    standard deviation. Per Eq. 10 in https://arxiv.org/pdf/1206.2459, the
    second term vanishes when sigma_1 = sigma_2, so alpha may also be 1.0, in
    which case the loss degenerates into the KL divergence loss. Predictions
    and targets are taken as the means of the Gaussian distributions.
    """

    def __init__(self, alpha: float, std: float = 1.0, reduction: str = "mean") -> None:
        """
        Args:
            alpha (float): Order of the divergence; must be positive.
            std (float): The shared standard deviation; must not be zero.
            reduction (str): Reduction method, also used for the inner MSE.

        Raises:
            AssertionError: If ``alpha`` is not positive, ``std`` is zero or
                the reduction is rejected by the parent constructor.
        """
        # Validate alpha before any module state is created.
        assert alpha > 0, "Alpha must be greater than 0."
        super().__init__(reduction=reduction)
        # The reduction is delegated to the MSE module.
        self.mse = nn.MSELoss(reduction=reduction)
        self.alpha = alpha
        self.set_std(std)

    def set_std(self, std: float) -> None:
        """Store the variance derived from the given standard deviation.

        Args:
            std (float): The standard deviation of the Gaussian distributions.

        Raises:
            AssertionError: If ``std`` is zero.
        """
        assert std != 0, "Standard deviation must not be zero."
        self.var = std**2

    def forward(self, predictions: Tensor, targets: Tensor) -> Tensor:
        """
        Compute alpha times the MSE, divided by twice the variance.

        Args:
            predictions (Tensor): The predicted means.
            targets (Tensor): The target means.

        Returns:
            Tensor: The scaled MSE, already reduced by the MSE module.
        """
        squared_error = self.mse(predictions, targets)
        return self.alpha * squared_error / (2 * self.var)


class VAEKLDivLoss(CustomLoss):
    """
    Compute the KL divergence loss in a Variational Autoencoder (VAE) setting.
    It acts as a regularizer to make the latent space distribution close to a
    standard normal distribution. The formula is derived from Appendix B of
    https://arxiv.org/pdf/1312.6114
    """

    def __init__(self, reduction: str = "mean") -> None:
        """
        Args:
            reduction (str): Reduction method forwarded to the parent class.
        """
        super().__init__(reduction=reduction)

    def forward(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Compute the reduced KL term.

        Args:
            mu (Tensor): The mean of the latent representation.
            logvar (Tensor): The log variance of the latent representation.

        Returns:
            Tensor: The KL term summed over dimension 1 and then reduced with
            the configured reduction method.
        """
        # The sum runs over dim 1, so both inputs need at least two
        # dimensions (presumably (batch, latent) — confirm with callers).
        kl_per_item = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)
        # Call through self (not super()) so an overriding reduce() is used.
        return self.reduce(kl_per_item)


class CustomVAELoss(CustomLoss):
    """
    Parent class for VAE-style losses, whose ``forward`` takes the latent
    mean and log variance in addition to predictions and targets.
    """

    def __init__(self, reduction: str = "mean") -> None:
        """
        Args:
            reduction (str): Reduction method forwarded to the parent class.
        """
        super().__init__(reduction=reduction)

    @abstractmethod
    def forward(  # type: ignore
        self,
        predictions: Tensor,
        targets: Tensor,
        mu: Tensor,
        logvar: Tensor,
    ) -> Tensor:
        """
        Evaluate the VAE-style loss.

        Concrete losses are expected to override this method; the body here
        contains nothing but this docstring.

        Args:
            predictions (Tensor): The predicted values.
            targets (Tensor): The target values.
            mu (Tensor): The mean of the latent representation.
            logvar (Tensor): The log variance of the latent representation.

        Returns:
            Tensor: The computed loss.
        """


class BetaVAELoss(CustomVAELoss):
    """
    BetaVAE loss: a reconstruction term plus a beta-weighted KL term. The
    reconstruction term is the MSE between predictions and targets; the KL
    term is ``VAEKLDivLoss`` applied to the latent mean and log variance,
    which pushes the latent distribution towards a standard normal.
    """

    def __init__(self, beta: float, reduction: str = "mean") -> None:
        """
        Args:
            beta (float): Weight of the KL term; must be positive.
            reduction (str): Reduction method shared by both inner losses.

        Raises:
            AssertionError: If ``beta`` is not positive or the reduction is
                rejected by the parent constructor.
        """
        super().__init__(reduction=reduction)
        self.set_beta(beta)
        # Both inner losses use the same reduction mode.
        self.mse = nn.MSELoss(reduction=reduction)
        self.kl = VAEKLDivLoss(reduction=reduction)

    def set_beta(self, beta: float) -> None:
        """Validate and store the weight of the KL term.

        Args:
            beta (float): The beta value for the VAE loss.

        Raises:
            AssertionError: If ``beta`` is not greater than 0.
        """
        assert beta > 0, "Beta must be greater than 0."
        self.beta = beta

    def forward(  # type: ignore
        self, predictions: Tensor, targets: Tensor, mu: Tensor, logvar: Tensor
    ) -> Tensor:
        """
        Compute the reconstruction term plus beta times the KL term.

        Args:
            predictions (Tensor): The predicted values.
            targets (Tensor): The target values.
            mu (Tensor): The mean of the latent representation.
            logvar (Tensor): The log variance of the latent representation.

        Returns:
            Tensor: The combined loss.
        """
        # NOTE(review): with reduction="none" the two terms are not reduced
        # and presumably differ in shape, so the sum relies on broadcasting —
        # confirm this is intended.
        reconstruction = self.mse(predictions, targets)
        regularizer = self.kl(mu, logvar)
        return reconstruction + self.beta * regularizer
