import torch
import torch.nn as nn
import torch.nn.functional as F

from federatedscope.register import register_criterion


class NT_xentloss(nn.Module):
    r"""
    NT_xentloss definition adapted from https://github.com/PatrickHua/SimSiam
    Arguments:
        z1 (torch.tensor): the embedding of model .
        z2 (torch.tensor): the embedding of model using another augmentation.
    returns:
        loss: the NT_xentloss loss for this batch data
    :rtype:
        torch.FloatTensor
    """
    def __init__(self, temperature=0.1):
        super(NT_xentloss, self).__init__()
        self.temperature = temperature

    def forward(self, z1, z2):
        N, Z = z1.shape
        device = z1.device
        representations = torch.cat([z1, z2], dim=0)
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1),
                                                representations.unsqueeze(0),
                                                dim=-1)

        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

        diag = torch.eye(2 * N, dtype=torch.bool, device=device)
        diag[N:, :N] = diag[:N, N:] = diag[:N, :N]
        negatives = similarity_matrix[~diag].view(2 * N, -1)

        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        labels = torch.zeros(2 * N, device=device,
                             dtype=torch.int64)  # scalar label per sample
        loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N)

        return loss


def create_NT_xentloss(type, device):

    if type == 'NT_xentloss':
        criterion = NT_xentloss().to(device)

        return criterion


register_criterion('NT_xentloss', create_NT_xentloss)
