import torch
import torch.nn as nn


class CrossEntropyLabelSmooth(nn.Module):
    """Cross entropy loss with label smoothing regularizer.

    Reference:
    Szegedy et al. Rethinking the Inception Architecture for Computer Vision. CVPR 2016.
    Equation: y = (1 - epsilon) * y + epsilon / K.

    Args:
        num_classes (int): number of classes K used in the smoothing term.
        epsilon (float): smoothing weight; 0 reduces to plain one-hot cross entropy.
        use_gpu (bool): kept for backward compatibility only. The smoothed
            target tensor is always built on the same device and with the same
            dtype as ``inputs``, so this flag no longer moves anything.
        reduction (bool): if truthy, ``forward`` returns the batch mean;
            otherwise it returns per-sample losses of shape (batch_size,).
    """

    def __init__(self, num_classes, epsilon=0.1, use_gpu=True, reduction=True):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.use_gpu = use_gpu  # retained attribute; device now follows `inputs`
        self.reduction = reduction
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Compute the label-smoothed cross entropy.

        Args:
            inputs: prediction matrix (before softmax) with shape (batch_size, num_classes)
            targets: ground-truth class indices (integer/int64) with shape (batch_size,)

        Returns:
            The mean loss (0-d tensor) if ``self.reduction`` is truthy, else a
            tensor of shape (batch_size,) holding each sample's loss.
        """
        log_probs = self.logsoftmax(inputs)
        # Allocate the one-hot matrix on the inputs' own device/dtype instead of
        # a fixed float32 CPU tensor + optional .cuda(): relying on `use_gpu`
        # breaks whenever it disagrees with where `inputs` actually live.
        index = targets.unsqueeze(1).to(log_probs.device)
        one_hot = torch.zeros_like(log_probs).scatter_(1, index, 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        loss = (-smoothed * log_probs).sum(dim=1)
        if self.reduction:
            return loss.mean()
        return loss

def get_ce_or_bce_loss(discr, dim_y: int, reduction: str = "mean"):
    """Build a classification loss module and a closure that applies it to ``discr``.

    When ``dim_y == 1`` the module is ``BCEWithLogitsLoss`` and labels are cast
    with ``.float()``; otherwise it is ``CrossEntropyLoss`` and labels are
    passed through unchanged.

    Args:
        discr: callable invoked as ``discr(x, t)`` to produce logits.
        dim_y: label dimensionality; 1 selects the binary (BCE) variant.
        reduction: forwarded to the torch loss constructor.

    Returns:
        ``(loss_module, loss_fn)`` where ``loss_fn(x, y, t)`` returns
        ``loss_module(discr(x, t), target)``.
    """
    is_binary = dim_y == 1
    loss_cls = torch.nn.BCEWithLogitsLoss if is_binary else torch.nn.CrossEntropyLoss
    loss_module = loss_cls(reduction=reduction)

    def loss_fn(x, y, t):
        logits = discr(x, t)
        # BCE-with-logits needs floating-point targets; CE takes class indices as-is.
        target = y.float() if is_binary else y
        return loss_module(logits, target)

    return loss_module, loss_fn

def add_ce_loss(lossobj, celossfn):
    """Combine ``lossobj`` with an additional classification loss.

    The returned ``lossfn(*args)`` calls ``lossobj`` with every positional
    argument except the last, and ``celossfn`` with the first three. It
    returns ``(total, (main_value, ce_value))``, where ``total`` is the summed
    tensor and the pair holds each term's ``.item()``.
    """
    def lossfn(*x_y_maybext_niter):
        # The trailing argument (presumably an iteration counter) goes only to
        # `lossobj`'s exclusion slice, i.e. it is dropped for `lossobj`.
        primary_args = x_y_maybext_niter[:-1]
        # NOTE(review): `[:3]` assumes the call carries (x, y, t, ...); if `t`
        # is ever absent, the trailing argument would land in celossfn's third
        # slot — confirm with callers.
        ce_args = x_y_maybext_niter[:3]
        main_term = 1. * lossobj(*primary_args)  # 1.* kept: promotes int tensors to float
        ce_term = 1. * celossfn(*ce_args)
        combined = main_term + ce_term
        return combined, (main_term.item(), ce_term.item())

    return lossfn

def get_lossfn(discr, frame, dim_y):
    """Return ``frame``'s loss augmented with a mean-reduced CE/BCE term on ``discr``.

    The frame loss is obtained via ``frame.get_lossfn(0, 'mean', "defl",
    wlogpi=1.0)``; the meaning of those positional arguments is defined by
    ``frame`` and is not visible here.
    """
    ce_fn = get_ce_or_bce_loss(discr, dim_y, 'mean')[1]
    frame_loss = frame.get_lossfn(0, 'mean', "defl", wlogpi=1.0)
    return add_ce_loss(frame_loss, ce_fn)