
import torch
from torch import nn
import torch.nn.functional as F


class Xent(nn.Module):
    """Softmax cross-entropy over linear logits ``z @ w.T`` plus an optional learnable bias.

    Args:
        n_classes: Length of the bias vector (one entry per class / row of ``w``).
        add_bias: When True, register a zero-initialised bias parameter; when
            False, the bias is the plain scalar 0 and adds nothing.
    """

    def __init__(self, n_classes, add_bias=True):
        super(Xent, self).__init__()
        # A scalar 0 keeps the forward addition valid without registering a parameter.
        self.b = nn.Parameter(torch.zeros(n_classes)) if add_bias else 0

    def forward(self, z, w, labels):
        """Return the mean cross-entropy of ``z @ w.T + b`` against integer ``labels``."""
        logits = z @ w.T + self.b
        # Labels are cast to int64 and moved next to z, as cross_entropy requires.
        targets = labels.to(torch.long).to(z.device)
        return F.cross_entropy(logits, targets)


class L2SoftmaxXent(nn.Module):
    """Cross-entropy over temperature-scaled cosine similarities.

    Both the embeddings and the class weight rows are L2-normalised, so the
    logits are cosines divided by ``temperature``.
    """

    def __init__(self, temperature=1.0):
        super(L2SoftmaxXent, self).__init__()
        self.temperature = temperature

    def forward(self, z, w, labels):
        """Return mean cross-entropy of cos(z_i, w_k) / temperature against ``labels``.

        Args:
            z: Embeddings, shape (batch_size, embed_dim).
            w: Class weights, shape (n_classes, embed_dim).
            labels: Class index per sample; cast to int64 before use.
        """
        unit_z, unit_w = (F.normalize(t, p=2, dim=1) for t in (z, w))
        logits = (unit_z @ unit_w.T) / self.temperature
        return F.cross_entropy(logits, labels.to(torch.long).to(unit_z.device))


class CosineLoss(nn.Module):
    """Negative *summed* cosine similarity between each embedding and its class weight."""

    def forward(self, z, w, labels):
        """Return ``-sum_i cos(z_i, w[labels_i])``.

        Args:
            z: Embeddings, shape (batch_size, embed_dim).
            w: Class weights, shape (n_classes, embed_dim); rows picked by ``labels``.
            labels: Class index per sample, used to index ``w``.

        Returns:
            A scalar tensor (a sum over the batch, not a mean).
        """
        z = F.normalize(z, p=2, dim=1)
        w = F.normalize(w[labels, ...], p=2, dim=1)

        # Row-wise dot product == diag(z @ w.T), without materialising the
        # (batch_size x batch_size) matrix the diagonal would be read from.
        cosine = (z * w).sum(dim=1)
        return -cosine.sum()


class NTXent(nn.Module):
    """NT-Xent-style loss contrasting each embedding with the batch and its class weight.

    For anchor ``z_i`` the candidates are every ``z_j`` and every ``w[labels_j]``.
    The positive is ``w[labels_i]``. Self-similarities ``z_i . z_i`` are
    multiplied by 0 rather than removed. Each one therefore still contributes
    ``exp(0) = 1`` to the softmax denominator. Rows of ``w[labels]`` sharing a
    label with the anchor (other than its own) are treated as negatives.
    """

    def __init__(self, temperature):
        super(NTXent, self).__init__()
        self.temperature = temperature

    def forward(self, z, w, labels):
        """Return the batch-mean NT-Xent loss.

        Args:
            z: Embeddings, shape (batch_size, embed_dim).
            w: Class weights, shape (n_classes, embed_dim); rows picked by ``labels``.
            labels: Class index per sample.
        """
        n = z.shape[0]
        anchors = F.normalize(z, p=2, dim=1)
        targets = F.normalize(w[labels, ...], p=2, dim=1)

        identity = torch.eye(n, dtype=torch.float32).to(anchors.device)
        all_ones = torch.ones_like(identity)
        # Left half zeroes the z_i . z_i entries; right half keeps every z . w entry.
        keep = torch.cat((all_ones - identity, all_ones), dim=1)

        # scores.shape = (n, 2n)
        scores = keep * torch.cat((anchors @ anchors.T, anchors @ targets.T), dim=1)
        # Offset-n diagonal holds z_i . w[labels_i], the positive pairs.
        positives = torch.diagonal(scores, offset=n)

        # Cosines are <= 1, so subtracting 1/T keeps every exponent <= 0.
        shift = 1 / self.temperature
        log_norm = shift + torch.log(torch.exp(scores / self.temperature - shift).sum(dim=1))
        return -(positives / self.temperature - log_norm).mean()


class NTXent2(nn.Module):
    """NT-Xent loss that *excludes* each anchor's self-similarity from the denominator.

    For anchor ``z_i`` the candidates are every ``z_j`` (j != i) and every
    ``w[labels_j]``. The positive is ``w[labels_i]``. The ``z_i . z_i`` term is
    masked out after exponentiation, so it contributes nothing to the softmax
    normaliser.
    """

    def __init__(self, temperature):
        # Must name this class: super(NTXent, self) raises TypeError because
        # NTXent2 is not a subclass of NTXent.
        super(NTXent2, self).__init__()
        self.temperature = temperature

    def forward(self, z, w, labels):
        """Return the batch-mean NT-Xent loss.

        Args:
            z: Embeddings, shape (batch_size, embed_dim).
            w: Class weights, shape (n_classes, embed_dim); rows picked by ``labels``.
            labels: Class index per sample.
        """
        batch_size = z.shape[0]
        z = F.normalize(z, p=2, dim=1)
        w = F.normalize(w[labels, ...], p=2, dim=1)

        # Allocate directly on z's device instead of creating on CPU then copying.
        eye = torch.eye(batch_size, dtype=torch.float32, device=z.device)
        ones = torch.ones_like(eye)
        mask = torch.cat((ones - eye, ones), dim=1)

        # cosine.shape = (batch_size, 2*batch_size)
        cosine = torch.matmul(z, torch.cat((z, w), dim=0).T)
        # Offset-batch_size diagonal holds z_i . w[labels_i], the positive pairs.
        diag = torch.diag(cosine, batch_size)

        # Log-sum-exp shift: cosines are <= 1, so every exponent is <= 0.
        c = 1 / self.temperature
        numerator = diag / self.temperature
        denominator = c + torch.log(
            torch.sum(torch.exp(cosine / self.temperature - c) * mask, dim=1)
        )
        return -(numerator - denominator).mean()



# adapted from: https://github.com/HobbitLong/SupContrast/blob/master/losses.py
class SupConLoss(nn.Module):
    """Supervised Contrastive Learning loss: https://arxiv.org/pdf/2004.11362.pdf.

    Adapted so that each sample contributes two L2-normalised views: its
    embedding ``z[i]`` and the weight row ``w[labels[i]]``. All views of samples
    sharing a label are positives for one another. Labels are required here, so
    the label-free SimCLR mode of the upstream implementation is not exposed.

    Args:
        temperature: Divisor applied to the anchor/contrast dot products.
        contrast_mode: ``'one'`` uses only the ``z`` views as anchors;
            ``'all'`` uses both the ``z`` and ``w`` views as anchors.
    """

    def __init__(self, temperature, contrast_mode):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode

    def forward(self, z, w, labels):
        """Compute the supervised contrastive loss.

        Args:
            z: Embeddings, one row per sample; normalised along dim 1.
            w: Weight matrix whose rows are selected by ``labels``.
            labels: Class index per sample, shape [bsz].

        Returns:
            A loss scalar (mean over all anchors).

        Raises:
            ValueError: if the stacked features have fewer than 3 dims, the
                number of labels differs from the batch size, or
                ``contrast_mode`` is neither ``'one'`` nor ``'all'``.
        """
        z = F.normalize(z, p=2, dim=1)
        w = F.normalize(w[labels, ...], p=2, dim=1)

        # features.shape = (bsz, n_views=2, ...): view 0 is z, view 1 is w[labels].
        features = torch.cat((torch.unsqueeze(z, dim=1), torch.unsqueeze(w, dim=1)), dim=1)

        # Use the features' own device; torch.device('cuda') would mean the
        # *current* CUDA device, which breaks when features live on e.g. cuda:1.
        device = features.device

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        # mask[i, j] = 1 when samples i and j share a label.
        labels = labels.contiguous().view(-1, 1)
        if labels.shape[0] != batch_size:
            raise ValueError('Num of labels does not match num of features')
        mask = torch.eq(labels, labels.T).float().to(device)

        contrast_count = features.shape[1]
        # Stack views view-major: rows [0, bsz) are z, rows [bsz, 2*bsz) are w.
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)

        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # Subtract the per-row max for numerical stability (does not change softmax).
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Tile the label mask to (anchor_count*bsz, contrast_count*bsz).
        mask = mask.repeat(anchor_count, contrast_count)
        # Zero each anchor's similarity with itself (row k, column k).
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count, device=device).view(-1, 1),
            0
        )
        mask = mask * logits_mask

        # log-softmax over all non-self candidates
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # Mean log-likelihood over positives. Anchors with no positive pair
        # (e.g. a label seen once with a single view) would divide by zero, so
        # their count is replaced by 1 and they contribute 0.
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(
            mask_pos_pairs < 1e-6, torch.ones_like(mask_pos_pairs), mask_pos_pairs
        )
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss

