import torch
import torch.nn.functional as F


class CrossEntropyLoss(torch.nn.Module):
    """Cross entropy with uniform label smoothing, averaged over the batch.

    Per sample the loss is::

        (1 - label_smoothing) * -log p[label] + label_smoothing * mean_c(-log p[c])

    Args:
        label_smoothing: Weight of the uniform (mean over classes) term.
    """

    def __init__(self, label_smoothing=0.0):
        super().__init__()
        self.label_smoothing = label_smoothing

    def forward(self, logits, labels):
        """Return the label-smoothed cross entropy averaged over the batch.

        Args:
            logits: Unnormalized class scores. Because ``labels.unsqueeze(1)``
                is used as a ``gather`` index, a 1-D ``labels`` implies 2-D
                logits, presumably (batch, num_classes).
            labels: Integer class index per row of ``logits``.

        Returns:
            Scalar tensor: sum of per-sample losses divided by the batch size.
        """
        confidence = 1.0 - self.label_smoothing
        logprobs = F.log_softmax(logits, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=labels.unsqueeze(1)).squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = confidence * nll_loss + self.label_smoothing * smooth_loss
        # Read the batch size from the tensor's shape: no device->host copy,
        # no GPU sync, and no .numpy() call (which fails for bfloat16).
        return loss.sum() / loss.shape[0]


class CrossEntropyLoss_Weighted(torch.nn.Module):
    """Cross entropy whose smoothing mass goes to the top-k rival classes.

    The soft target keeps ``1 - label_smoothing`` on the true class. The
    remaining ``label_smoothing`` is spread over the ``k`` highest-scoring
    *other* classes, weighted by a softmax over their logits.

    Args:
        label_smoothing: Probability mass moved off the true class.
        k: Number of rival classes that receive the smoothing mass. At call
            time it is clamped to ``num_classes - 1``.
    """

    def __init__(self, label_smoothing=0.0, k=5):
        super().__init__()
        self.label_smoothing = label_smoothing
        self.k = k

    def forward(self, logits, labels):
        """Return the batch-mean soft-target cross entropy.

        Args:
            logits: Unnormalized class scores. Labels are scattered along
                dim 1 and the softmax is taken over the last dim, so this is
                presumably (batch, num_classes); TODO confirm for other ranks.
            labels: Integer class index per row of ``logits``.

        Returns:
            Scalar tensor: mean over rows of ``sum_c target[c] * -log p[c]``.
        """
        nll = -F.log_softmax(logits, dim=-1)
        num_classes = logits.size(-1)

        # Boolean one-hot of the true class, built from the labels themselves
        # so it stays correct even when label_smoothing == 1 (target value 0).
        class_ids = torch.arange(num_classes, device=logits.device)
        true_mask = class_ids == labels.unsqueeze(1)

        target = torch.zeros_like(logits)
        target.scatter_(1, labels.unsqueeze(1), 1 - self.label_smoothing)

        # Only num_classes - 1 rivals exist. A larger k would either make
        # topk raise, or select the true class and overwrite its confidence.
        k = min(self.k, num_classes - 1)
        if k > 0:
            # Push the true class to the bottom so topk only picks rivals.
            rivals = logits.masked_fill(true_mask, torch.finfo(logits.dtype).min)
            top_logits, top_idx = torch.topk(rivals, k=k, dim=-1, largest=True)
            # NOTE(review): top_logits is not detached, so gradients also flow
            # through the soft-target weights (same as before); confirm intent.
            weights = top_logits.softmax(dim=1)
            target.scatter_(1, top_idx, weights * self.label_smoothing)

        return (target * nll).sum(dim=-1).mean()