import abc

import torch
import torch.nn as nn


class KLDivergenceLoss(nn.Module):
    """Forward KL objective, evaluated as a negative mean log-likelihood."""

    def __init__(self):
        super().__init__()

    def forward(self, logq: torch.Tensor) -> torch.Tensor:
        """
        Return the negated average of ``logq`` over all of its elements.

        logq: log probability of samples from the true distribution according to the normalizing flow.
        """
        mean_log_likelihood = logq.mean()
        return -mean_log_likelihood


class ReverseKLDivergenceLoss(nn.Module):
    """
    Reverse KL objective KL(q || p), estimated from samples of the normalizing flow q.

    With samples x ~ q, the divergence is E_q[log q(x) - log p(x)], so the loss is the
    mean of ``logq - logp``. It is zero when the two log-densities agree on the samples.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
        """
        Return the mean of ``logq - logp`` over all elements.

        logp: log probability of samples from the normalizing flow according to the true distribution.
        logq: log probability of samples from the normalizing flow according to the normalizing flow.
        """
        # The sign matters: minimizing mean(logq - logp) pulls the flow towards the target,
        # whereas the negated expression would be minimized by pushing the two apart.
        return torch.mean(logq - logp)


class SquaredHellingerDistanceLoss(nn.Module):
    """
    Squared-Hellinger-style objective evaluated on samples of the normalizing flow.

    Computes the mean of (sqrt(p / q) - 1)^2, working entirely from log-densities.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
        """
        Return the mean of ``expm1(0.5 * logp - 0.5 * logq) ** 2``.

        logp: log probability of samples from the normalizing flow according to the true distribution.
        logq: log probability of samples from the normalizing flow according to the normalizing flow.
        """
        half_log_ratio = 0.5 * logp - 0.5 * logq
        # expm1 yields sqrt(p / q) - 1 directly from the halved log-ratio.
        sqrt_ratio_minus_one = torch.expm1(half_log_ratio)
        return (sqrt_ratio_minus_one ** 2).mean()


class TotalVariationDistanceLoss(nn.Module):
    """
    Total-variation-style objective evaluated on samples of the normalizing flow.

    Computes the mean of |p / q - 1|, working entirely from log-densities.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
        """
        Return the mean of ``|expm1(logp - logq)|``.

        logp: log probability of samples from the normalizing flow according to the true distribution.
        logq: log probability of samples from the normalizing flow according to the normalizing flow.
        """
        log_ratio = logp - logq
        # expm1 yields p / q - 1 directly from the log-ratio.
        ratio_minus_one = torch.expm1(log_ratio)
        return ratio_minus_one.abs().mean()


class ChiSquaredDivergenceLoss(nn.Module):
    """
    Chi-squared objective evaluated on samples of the normalizing flow.

    Computes the mean of (p / q - 1)^2, working entirely from log-densities.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
        """
        Return the mean of ``expm1(logp - logq) ** 2``.

        logp: log probability of samples from the normalizing flow according to the true distribution.
        logq: log probability of samples from the normalizing flow according to the normalizing flow.
        """
        log_ratio = logp - logq
        # expm1 yields p / q - 1 directly from the log-ratio.
        ratio_minus_one = torch.expm1(log_ratio)
        return (ratio_minus_one ** 2).mean()


class AlphaDivergenceLoss(nn.Module):
    """
    Alpha-divergence-style objective evaluated on samples of the normalizing flow.

    Only values of ``alpha`` strictly between -1 and 1 are accepted; the two endpoints
    are rejected with a pointer to the dedicated KL losses.
    """

    def __init__(self, alpha=0.0):
        super().__init__()
        self.alpha = alpha

        # The endpoints are handled first so that they get their own explanatory message.
        if alpha == 1:
            raise NotImplementedError("Use KLDivergenceLoss instead.")
        if alpha == -1:
            raise NotImplementedError("Use ReverseKLDivergenceLoss instead.")

        inside_range = -1 <= alpha <= 1
        if not inside_range:
            raise NotImplementedError

    def forward(self, logp: torch.Tensor, logq: torch.Tensor) -> torch.Tensor:
        """
        Return the negated mean of ``(p / q) ** ((1 + alpha) / 2)``, computed from log-densities.

        logp: log probability of samples from the normalizing flow according to the true distribution.
        logq: log probability of samples from the normalizing flow according to the normalizing flow.
        """
        log_ratio = logp - logq
        scaled_log_ratio = log_ratio * (1 + self.alpha) / 2
        powered_ratio = torch.exp(scaled_log_ratio)
        return -powered_ratio.mean()


class JensenShannonDivergenceLoss(nn.Module):
    """
    Jensen-Shannon objective between the true distribution p and the normalizing flow q.

    Using log(p + q) = log q + log(1 + p / q), twice the Jensen-Shannon divergence equals

        E_p[log p] - E_p[log q] - E_p[log(1 + p / q)] - E_q[log(1 + p / q)] + 2 log 2.

    The returned loss keeps the three middle terms and drops E_p[log p] and 2 log 2.
    """

    def __init__(self):
        super().__init__()

    def forward(self,
                logp_from_p: torch.Tensor,
                logq_from_p: torch.Tensor,
                logp_from_q: torch.Tensor,
                logq_from_q: torch.Tensor) -> torch.Tensor:
        """
        Return ``-mean(logq_from_p) - mean(log(1 + p/q) on p-samples) - mean(log(1 + p/q) on q-samples)``.

        logp_from_p: log probability of samples from the true distribution according to the true distribution.
        logq_from_p: log probability of samples from the true distribution according to the normalizing flow.
        logp_from_q: log probability of samples from the normalizing flow according to the true distribution.
        logq_from_q: log probability of samples from the normalizing flow according to the normalizing flow.
        """
        softplus = nn.functional.softplus

        t0 = -torch.mean(logq_from_p)
        # log(1 + p / q) == softplus(log p - log q). Working on the log-difference avoids
        # exponentiating each log-density separately, which underflows to 0 / 0 = NaN for
        # very negative log-densities and overflows to inf for large gaps.
        t1 = -torch.mean(softplus(logp_from_p - logq_from_p))
        t2 = -torch.mean(softplus(logp_from_q - logq_from_q))
        return t0 + t1 + t2
