import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# Default lower bound applied to the shifted probabilities before the log.
eps = 1e-7


class VCELoss(nn.Module):
    """Negative log-likelihood of softmax probabilities shifted by a constant.

    The class scores are turned into probabilities with a softmax over dim 1,
    the constant ``a`` is added to every probability, the result is clamped
    from below at ``eps`` and its logarithm is passed to ``F.nll_loss``.

    Args:
        a: Constant added to every softmax probability.
        scale: Multiplier applied to the final loss value.
        eps: Lower clamp applied before the logarithm, so the log argument
            is never smaller than this value.
    """

    def __init__(self, a, scale=1, eps=eps):
        super().__init__()
        self.a = a
        self.scale = scale
        self.eps = eps

    def forward(self, input, target):
        """Return the scaled loss for class scores ``input`` and ``target``.

        ``input`` carries the classes along dim 1; ``target`` holds the class
        indices consumed by ``F.nll_loss``.
        """
        shifted = F.softmax(input, dim=1) + self.a
        # Clamp before the log: a negative ``a`` can push values to <= 0.
        log_shifted = shifted.clamp(min=self.eps).log()
        nll = F.nll_loss(log_shifted, target)
        return self.scale * nll.mean()

class VELoss(nn.Module):
    """Exponential loss ``a ** (-p_target)`` on the softmax probability.

    ``p_target`` is the softmax probability assigned to the labelled class.
    The power is evaluated as ``exp(-p_target * log(a))``.

    Args:
        a: Base of the exponential; its logarithm is taken, so it is
            expected to be positive.
        scale: Multiplier applied to the batch-mean loss.
        num_classes: Width of the one-hot encoding built from ``target``.
    """

    def __init__(self, a, scale=1, num_classes=10):
        super().__init__()
        self.a = a
        self.scale = scale
        self.num_classes = num_classes

    def forward(self, input, target):
        """Return the scaled batch mean of ``a ** (-p_target)``."""
        probs = F.softmax(input, dim=1)
        mask = F.one_hot(target, self.num_classes).float().to(probs.device)
        # Negated probability of the labelled class, one value per sample.
        neg_target_prob = -(mask * probs).sum(dim=1)
        log_base = torch.log(torch.tensor(self.a, dtype=neg_target_prob.dtype))
        per_sample = torch.exp(neg_target_prob * log_base)
        return per_sample.mean() * self.scale

class VMSELoss(nn.Module):
    """Mean squared error between softmax probabilities and a scaled one-hot.

    The target vector is ``a * one_hot(labels)``. With the default ``a=0``
    the target is all zeros, so the loss reduces to the mean squared softmax
    probability.

    Args:
        a: Value placed at the labelled class in the target vector.
        num_classes: Width of the one-hot encoding built from ``labels``.
        scale: Multiplier applied to the final loss value.
    """

    def __init__(self, a=0, num_classes=10, scale=1):
        super(VMSELoss, self).__init__()
        self.num_classes = num_classes
        self.a = a
        self.scale = scale

    def forward(self, input, labels):
        """Return the scaled mean squared error over all samples and classes.

        ``input`` carries the classes along dim 1; ``labels`` holds integer
        class indices accepted by ``F.one_hot``.
        """
        probs = F.softmax(input, dim=1)
        # Move the one-hot target onto the predictions' device, as the other
        # losses in this file do, so labels living on another device work.
        label_one_hot = F.one_hot(labels, self.num_classes).float().to(probs.device)
        scaled_target = label_one_hot * self.a
        # The mean runs over every element (batch and class dimensions).
        loss = (probs - scaled_target) ** 2
        return self.scale * loss.mean()
    

class NCEandVCE(nn.Module):
    """Sum of an ``NCELoss`` term and a ``VCELoss`` term.

    Args:
        alpha: Scale handed to the ``NCELoss`` term.
        beta: Scale handed to the ``VCELoss`` term.
        a: Constant ``a`` handed to ``VCELoss``.
        num_classes: Number of classes handed to ``NCELoss``.
    """

    def __init__(self, alpha=1, beta=1, a=0, num_classes=10):
        super().__init__()
        self.nce = NCELoss(num_classes, scale=alpha)
        self.vce = VCELoss(a, scale=beta)

    def forward(self, input, target):
        """Evaluate both terms on the same inputs and return their sum."""
        nce_term = self.nce(input, target)
        vce_term = self.vce(input, target)
        return nce_term + vce_term
class NCEandVEL(nn.Module):
    """Sum of an ``NCELoss`` term and a ``VELoss`` term.

    Args:
        alpha: Scale handed to the ``NCELoss`` term.
        beta: Scale handed to the ``VELoss`` term.
        a: Base ``a`` handed to ``VELoss``.
        num_classes: Number of classes handed to both terms.
    """

    def __init__(self, alpha=1, beta=1, a=2.7, num_classes=10):
        super().__init__()
        self.nce = NCELoss(num_classes, scale=alpha)
        self.vel = VELoss(a, scale=beta, num_classes=num_classes)

    def forward(self, input, target):
        """Evaluate both terms on the same inputs and return their sum."""
        nce_term = self.nce(input, target)
        vel_term = self.vel(input, target)
        return nce_term + vel_term
class NCEandVMSE(nn.Module):
    """Sum of an ``NCELoss`` term and a ``VMSELoss`` term.

    Args:
        alpha: Scale handed to the ``NCELoss`` term.
        beta: Scale handed to the ``VMSELoss`` term.
        a: Target value ``a`` handed to ``VMSELoss``.
        num_classes: Number of classes handed to both terms.
    """

    def __init__(self, alpha=1, beta=1, a=0, num_classes=10):
        super().__init__()
        self.nce = NCELoss(num_classes, scale=alpha)
        self.vmse = VMSELoss(a, num_classes=num_classes, scale=beta)

    def forward(self, input, target):
        """Evaluate both terms on the same inputs and return their sum."""
        nce_term = self.nce(input, target)
        vmse_term = self.vmse(input, target)
        return nce_term + vmse_term


class NCELoss(nn.Module):
    """Cross-entropy divided by the summed negative log-probabilities.

    For each sample, the negative log-softmax value of the labelled class is
    divided by the sum of the negative log-softmax values over all classes.
    The result is averaged over the batch and multiplied by ``scale``.

    Args:
        num_classes: Width of the one-hot encoding built from ``labels``.
        scale: Multiplier applied to the batch-mean loss.
    """

    def __init__(self, num_classes, scale=1.0):
        super().__init__()
        self.num_classes = num_classes
        self.scale = scale

    def forward(self, pred, labels):
        """Return the scaled batch-mean normalised loss.

        ``pred`` carries the classes along dim 1; ``labels`` holds integer
        class indices accepted by ``F.one_hot``.
        """
        log_probs = F.log_softmax(pred, dim=1)
        mask = F.one_hot(labels, self.num_classes).float().to(log_probs.device)
        # Numerator: negative log-probability of the labelled class.
        target_nll = -(mask * log_probs).sum(dim=1)
        # Denominator: negative log-probabilities summed over every class.
        total_nll = -log_probs.sum(dim=1)
        per_sample = target_nll / total_nll
        return self.scale * per_sample.mean()

