import torch
from torch import nn
import numpy as np

# Small constant added to the denominators so the log-ratios stay finite.
eps = 1e-7


class NCECriterion(nn.Module):
    """Noise-contrastive estimation loss, Eq. (12): L_{NCE}.

    Column 0 of the input holds the score of the positive (data) sample;
    columns 1..m hold the scores of the m noise samples. The noise
    distribution is uniform with probability ``1 / n_data`` per item.

    Args:
        n_data: number of items the noise distribution is uniform over.
    """

    def __init__(self, n_data):
        super().__init__()
        self.n_data = n_data

    def forward(self, x):
        """Return the NCE loss averaged over the batch.

        Args:
            x: scores of shape (bs, m+1, 1) or (bs, m+1); index 0 along
                dim 1 is the positive, the remaining m entries are noise.

        Returns:
            A one-element tensor (shape ``(1,)``) holding
            ``-(sum log P(D=1|pos) + sum log P(D=0|neg)) / bs``.
        """
        batch_size = x.shape[0]
        num_noise = x.size(1) - 1

        # m * P_n: probability mass assigned to the m noise draws.
        noise_mass = num_noise * (1 / float(self.n_data))
        shift = noise_mass + eps

        # log P(D=1 | positive) = log(P_pos / (P_pos + m*P_n)).
        positive = x[:, 0]
        log_d1 = torch.log(torch.div(positive, positive + shift))

        # log P(D=0 | noise) = log(m*P_n / (P_neg + m*P_n)).
        negatives = x[:, 1:]
        numerator = torch.full_like(negatives, noise_mass)
        log_d0 = torch.log(torch.div(numerator, negatives + shift))

        total = log_d1.sum(0) + log_d0.reshape(-1, 1).sum(0)
        return -total / batch_size

class NCESoftmaxLoss(nn.Module):
    """Softmax cross-entropy loss (a.k.a. InfoNCE loss in the CPC paper).

    Each row of the input holds one prediction's scores over its candidates,
    and the ground-truth class is 0 for every row; the loss is the mean
    cross-entropy against that target.
    """

    def __init__(self):
        super(NCESoftmaxLoss, self).__init__()
        # CrossEntropyLoss combines log_softmax and nll_loss.
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the mean InfoNCE loss over the batch.

        Args:
            x: scores of shape (bsz, K+1) or (bsz, K+1, 1); index 0 along
                dim 1 is the ground-truth class for every row.

        Returns:
            Scalar tensor with the mean cross-entropy loss.
        """
        bsz = x.shape[0]
        # Collapse everything after the batch dim into the class dim. A bare
        # squeeze() would also drop the batch dim when bsz == 1 (and the class
        # dim when there is a single candidate), breaking CrossEntropyLoss.
        x = x.flatten(start_dim=1)
        # Target class 0 for every prediction, allocated on the same device as
        # the scores rather than hard-coding CUDA (which fails on CPU and on
        # any non-default GPU).
        label = torch.zeros(bsz, dtype=torch.long, device=x.device)
        return self.criterion(x, label)


