import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np

from torch.distributions import Normal, Categorical, Dirichlet, kl_divergence


class EntropyWeightedMAELoss(torch.nn.Module):
    """MAE-style loss on class probabilities, re-weighted by prediction entropy.

    For each sample ``i`` with target ``y_i`` the per-sample loss is::

        (1 + ent_weight * exp(-H_i)) * (2 - p[i, y_i])

    where ``p = exp(log_probs)`` and ``H_i`` is the Shannon entropy of ``p[i]``
    (detached, so it only acts as a weight). ``exp(-H_i)`` equals 1 for a
    one-hot distribution and ``1/C`` for a uniform one over ``C`` classes, so
    confident samples receive a larger weight.

    Args:
        ent_weight: non-negative scale of the entropy-based weight.

    Raises:
        ValueError: if ``ent_weight`` is negative (at construction or when set).
    """

    def __init__(self, ent_weight=1.0):
        super(EntropyWeightedMAELoss, self).__init__()
        # Route through the property setter so the constructor validates too.
        self.ent_weight = ent_weight

    @property
    def ent_weight(self):
        return self.__ent_weight

    @ent_weight.setter
    def ent_weight(self, ent_weight):
        # Validate the incoming value, not the currently stored one.
        if ent_weight < 0.0:
            raise ValueError("entropy weight must be a positive value.")
        self.__ent_weight = ent_weight

    def forward(self, log_probs, targets, reduction='mean'):
        """Compute the loss.

        Args:
            log_probs: log-probabilities; assumed shape (B, C) with classes on
                dim 1 — TODO confirm against callers.
            targets: integer class indices, shape (B,).
            reduction: ``'mean'`` returns the batch mean; any other value
                returns the per-sample losses of shape (B,).
        """
        p = log_probs.exp()
        # Entropy is only used as a weight, so compute it on detached probs.
        # xlogy defines 0 * log(0) = 0, avoiding NaN when a probability
        # underflows to exactly zero.
        p_const = p.detach()
        H = -torch.xlogy(p_const, p_const).sum(dim=-1)  # (B,)
        # squeeze(1) keeps f_j at (B,) so it multiplies the (B,) weight
        # element-wise instead of broadcasting to a (B, B) matrix.
        f_j = torch.gather(p, 1, torch.unsqueeze(targets, 1)).squeeze(1)

        loss = (1 + torch.exp(-H) * self.__ent_weight) * (2 - f_j)
        if reduction == 'mean':
            return loss.mean()
        return loss


class MAELoss(torch.nn.Module):
    """Per-sample ``1 - p_target + eps`` loss on softmax probabilities.

    ``p_target`` is the softmax probability assigned to the target class.
    """

    def __init__(self):
        super().__init__()

    def forward(self, logit, targets, reduction='mean', eps=1e-3):
        """Compute the loss.

        Args:
            logit: unnormalized scores; softmax is taken over the last dim and
                the target is gathered along dim 1 (so presumably (B, C)).
            targets: integer class indices, shape (B,).
            reduction: ``'mean'`` returns a scalar; anything else returns the
                unreduced losses of shape (B, 1).
            eps: constant added to every per-sample loss.
        """
        probs = logit.softmax(dim=-1)
        target_prob = probs.gather(1, targets.unsqueeze(1))
        per_sample = 1 - target_prob + eps
        return per_sample.mean() if reduction == 'mean' else per_sample


class GeneralizedCELoss(torch.nn.Module):
    """Generalized cross-entropy implemented as a re-weighted cross-entropy.

    Each sample's cross-entropy is multiplied by ``q * p_y ** q``, where
    ``p_y`` is the softmax probability of the target class. ``p_y`` is
    detached, so the gradient is the CE gradient scaled by that weight.

    Args:
        q: exponent (and scale) of the weight.
        reduction: ``'mean'`` averages over the batch; any other value returns
            per-sample losses.

    Raises:
        NameError: ``'GCE_p'`` / ``'GCE_Yg'`` if NaNs appear in the
            probabilities or the gathered target probabilities.
    """

    def __init__(self, q=0.4, reduction='none'):
        super().__init__()
        self.q = q
        self.reduction = reduction

    def forward(self, logits, targets):
        probs = logits.softmax(dim=1)
        if np.isnan(probs.mean().item()):
            raise NameError('GCE_p')

        target_probs = probs.gather(1, targets.unsqueeze(1))
        if np.isnan(target_probs.mean().item()):
            raise NameError('GCE_Yg')

        # Weight depends on p_y but carries no gradient of its own.
        weights = self.q * target_probs.squeeze().detach() ** self.q
        per_sample = F.cross_entropy(logits, targets, reduction='none') * weights
        return per_sample.mean() if self.reduction == 'mean' else per_sample


class EntropyWeightedGCELoss(torch.nn.Module):
    """Generalized cross-entropy re-weighted by prediction entropy.

    The per-sample GCE term is cross-entropy times ``q * p_y ** q`` (with
    ``p_y`` detached). It is then multiplied by
    ``1 + ent_weight * exp(-H)``, where ``H`` is the detached Shannon entropy
    of the softmax distribution. This mirrors the weighting used by
    ``EntropyWeightedMAELoss``.

    NOTE(review): the intended entropy weighting was not recoverable from the
    previous code, which discarded the loss. Confirm this form with the
    author.

    Args:
        q: GCE exponent (and scale) of the confidence weight.
        ent_weight: non-negative scale of the entropy-based weight.

    Raises:
        ValueError: if ``ent_weight`` is negative (at construction or when set).
    """

    def __init__(self, q=0.2, ent_weight=1.0):
        super(EntropyWeightedGCELoss, self).__init__()
        self.q = q
        # Goes through the property setter, which validates and stores it.
        self.ent_weight = ent_weight

    @property
    def ent_weight(self):
        return self.__ent_weight

    @ent_weight.setter
    def ent_weight(self, ent_weight):
        # Validate the incoming value (the stored one may not exist yet).
        if ent_weight < 0.0:
            raise ValueError("entropy weight must be a positive value.")
        self.__ent_weight = ent_weight

    def forward(self, logits, targets, reduction='mean'):
        """Compute the loss.

        Args:
            logits: unnormalized scores, classes along dim 1 (presumably (B, C)).
            targets: integer class indices, shape (B,).
            reduction: ``'mean'`` returns a scalar; any other value returns the
                per-sample losses of shape (B,).

        Raises:
            NameError: ``'GCE_p'`` / ``'GCE_Yg'`` if NaNs appear in the
                probabilities or the gathered target probabilities.
        """
        p = F.softmax(logits, dim=1)
        if np.isnan(p.mean().item()):
            raise NameError('GCE_p')

        # Entropy only acts as a weight. xlogy treats 0 * log(0) as 0, so a
        # saturated softmax does not produce NaN.
        p_const = p.detach()
        H = -torch.xlogy(p_const, p_const).sum(dim=-1)  # (B,)

        Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1)).squeeze(1)  # (B,)
        if np.isnan(Yg.mean().item()):
            raise NameError('GCE_Yg')

        # Modify the gradient of cross entropy: weight is detached p_y^q * q.
        loss_weight = (Yg.detach() ** self.q) * self.q
        loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight
        loss = (1 + torch.exp(-H) * self.__ent_weight) * loss

        if reduction == 'mean':
            loss = loss.mean()
        return loss


class NoisyDirichletLoss(nn.Module):
    """Gaussian negative log-likelihood of logits under noisy Dirichlet labels.

    Each label is turned into concentrations ``alpha = one_hot(Y) + noise``.
    Every ``alpha`` is mapped to a Gaussian with variance
    ``log(1 / alpha + 1)`` and mean ``log(alpha) - variance / 2``. The loss is
    the negative log density of ``logits`` under those Gaussians, summed over
    classes, reduced over the batch, and divided by ``likelihood_temp``.

    ``params`` and ``prior_scale`` are stored (as ``theta`` / ``sigma``) but
    are not used by ``forward``. The Gaussian prior over parameters is
    currently disabled.

    Args:
        params: model parameters (kept for the disabled prior term).
        num_classes: number of classes used for the one-hot encoding.
        noise: additive concentration; must be > 0.
        prior_scale: std of the (disabled) Gaussian prior.
        reduction: ``'mean'`` averages over the batch; any other value sums.
        likelihood_temp: temperature dividing the reduced energy.
    """

    def __init__(self, params, num_classes=10, noise=1e-2, prior_scale=1,
                 reduction='mean', likelihood_temp=1):
        super().__init__()

        assert noise > 0

        self.theta = params
        self.sigma = prior_scale
        self.C = num_classes
        self.ae = noise
        self.T = likelihood_temp
        self.reduction = reduction

    def forward(self, logits, Y, N=1):
        # N would scale the disabled prior term; it is unused here.
        concentration = F.one_hot(Y, self.C) + self.ae
        variance = torch.log(1 / concentration + 1)
        mean = torch.log(concentration) - variance / 2
        likelihood = Normal(mean, variance.sqrt())

        nll_per_sample = likelihood.log_prob(logits).sum(dim=-1).neg()
        if self.reduction == 'mean':
            batch_energy = nll_per_sample.mean(dim=-1)
        else:
            batch_energy = nll_per_sample.sum(dim=-1)
        return batch_energy / self.T


def gaussian_nll_loss(
    input,
    target,
    var,
    full: bool = False,
    eps: float = 1e-6,
    reduction: str = "mean",
):
    r"""Gaussian negative log likelihood loss.

    See :class:`~torch.nn.GaussianNLLLoss` for details. Computes
    ``0.5 * (log(var) + (input - target)**2 / var)`` element-wise, optionally
    plus the constant ``0.5 * log(2*pi)``.

    Args:
        input: expectation of the Gaussian distribution.
        target: sample from the Gaussian distribution.
        var: positive variance(s). Accepted forms:

            * a tensor with the same shape as ``input`` (heteroscedastic);
            * a tensor matching ``input`` except that the last dimension is
              missing or has size 1 (one variance shared along the last dim);
            * a single variance, given as a Python number or a 0-d tensor.
        full (bool, optional): include the constant term in the loss calculation. Default: ``False``.
        eps (float, optional): lower bound ``var`` is clamped to, for stability.
            The clamp runs under ``torch.no_grad`` so autograd does not record
            it. Default: 1e-6.
        reduction (string, optional): specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
            ``'mean'``: the output is the average of all batch member losses,
            ``'sum'``: the output is the sum of all batch member losses.
            Default: ``'mean'``.

    Raises:
        ValueError: if ``var`` has an incompatible size or negative entries,
            or ``reduction`` is not a supported mode.
    """
    # A single shared variance: promote it to a tensor shaped like ``input``
    # so it takes the same-size (heteroscedastic) path below. expand_as keeps
    # autograd connected to a 0-d ``var`` tensor.
    if not torch.is_tensor(var):
        var = torch.as_tensor(var, dtype=input.dtype, device=input.device)
    if var.dim() == 0:
        var = var.expand_as(input)

    # If var.size == input.size, the case is heteroscedastic and no further checks are needed.
    if var.size() != input.size():
        if input.size()[:-1] == var.size():
            # var is one dimension short of input (homoscedastic along the
            # last dim), e.g. input (10, 2, 3), var (10, 2). Unsqueeze to
            # (10, 2, 1) so it broadcasts in the loss calculation.
            var = torch.unsqueeze(var, -1)
        elif input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1:
            # Homoscedastic along the last dim with an explicit size-1 dim,
            # e.g. input (10, 2, 3), var (10, 2, 1): already broadcastable.
            pass
        else:
            raise ValueError("var is of incorrect size")

    # Check validity of reduction mode
    if reduction != 'none' and reduction != 'mean' and reduction != 'sum':
        raise ValueError(reduction + " is not valid")

    # Entries of var must be non-negative
    if torch.any(var < 0):
        raise ValueError("var has negative entry/entries")

    # Clamp a copy (never the caller's tensor). clone() also materializes an
    # expanded var so the in-place clamp is legal.
    var = var.clone()
    with torch.no_grad():
        var.clamp_(min=eps)

    # Calculate the loss
    loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
    if full:
        loss = loss + 0.5 * math.log(2 * math.pi)

    if reduction == 'mean':
        return loss.mean()
    elif reduction == 'sum':
        return loss.sum()
    else:
        return loss
