import torch

import xentropy_cuda


class SoftmaxCrossEntropyLoss(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, half_to_float=False):
        losses, max_log_sum_exp = xentropy_cuda.forward(
            logits, labels, smoothing, half_to_float)
        losses.masked_fill_(labels==padding_idx, 0)

        ctx.save_for_backward(logits, max_log_sum_exp, labels,
            torch.FloatTensor([smoothing]),
            torch.LongTensor([padding_idx]))

        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels, smoothing, padding_idx = ctx.saved_tensors

        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        grad_loss.masked_fill_(labels==padding_idx.item(), 0)
        grad_logits = xentropy_cuda.backward(
            grad_loss.contiguous(), logits, max_log_sum_exp,
            labels, smoothing.item())

        return grad_logits, None, None, None, None
