import torch
import torch.nn as nn
import torch.nn.functional as F


def norm(w) -> float:
    r"""
    Return the 2-norm of all tensors in ``w`` treated as one flat vector.
    Arguments:
        w (dict): mapping whose values are tensors; the keys are ignored.
    returns:
        the norm over every element of every tensor, or ``0.0`` when ``w``
        has no entries.
    :rtype:
        float
    """
    flat = [v.flatten() for v in w.values()]
    if not flat:
        # torch.cat rejects an empty list; an empty set of tensors has norm 0.
        return 0.0
    return torch.norm(torch.cat(flat)).item()


class global_NT_xentloss(nn.Module):
    r"""
    NT_xentloss (temperature-scaled cross entropy over cosine similarities)
    with optional extra negatives, adapted from
    https://github.com/PatrickHua/SimSiam
    Arguments:
        temperature (float): divisor applied to the similarity logits.
        device (torch.device): device the inputs are moved to and on which
            the loss is computed.
    """
    def __init__(self, temperature=0.1, device=torch.device("cpu")):
        super().__init__()
        self.temperature = temperature
        self.device = device

    def forward(self, z1, z2, others_z2=None):
        r"""
        Compute the loss for one batch of paired embeddings.
        Arguments:
            z1 (torch.tensor): 2-D embedding of the model, shape (N, Z).
            z2 (torch.tensor): embedding of the model using another
                augmentation, same shape as ``z1``; ``z2[i]`` is the
                positive of ``z1[i]``.
            others_z2 (iterable of torch.tensor, optional): further 2-D
                embeddings with the same width as ``z1``. Their cosine
                similarities to ``z1`` are appended as extra negatives.
                They are detached, so no gradient flows into them.
                ``None`` (default) or an empty sequence adds nothing.
        returns:
            loss: the NT_xentloss loss for this batch data, averaged over
            the 2N anchors
        :rtype:
            torch.FloatTensor
        """
        if others_z2 is None:
            others_z2 = ()

        N, _ = z1.shape
        z1, z2 = z1.to(self.device), z2.to(self.device)
        representations = torch.cat([z1, z2], dim=0)
        # Pairwise cosine similarity between all 2N embeddings -> (2N, 2N).
        similarity_matrix = F.cosine_similarity(representations.unsqueeze(1),
                                                representations.unsqueeze(0),
                                                dim=-1)

        # sim(z1[i], z2[i]) lies on the +N and -N off-diagonals.
        l_pos = torch.diag(similarity_matrix, N)
        r_pos = torch.diag(similarity_matrix, -N)
        positives = torch.cat([l_pos, r_pos]).view(2 * N, 1)

        # Mask self-similarities and the positive pairs; the remaining
        # 2N - 2 entries per row are the in-batch negatives.
        diag = torch.eye(2 * N, dtype=torch.bool, device=self.device)
        diag[N:, :N] = diag[:N, N:] = diag[:N, :N]
        negatives = similarity_matrix[~diag].view(2 * N, -1)

        for z2_ in others_z2:
            z2_ = z2_.detach().to(self.device)
            representations = torch.cat([z1, z2_], dim=0)
            similarity_matrix = F.cosine_similarity(
                representations.unsqueeze(1),
                representations.unsqueeze(0),
                dim=-1)
            # Keep only the two cross blocks: z1 vs. z2_ and z2_ vs. z1.
            mask = torch.zeros_like(similarity_matrix,
                                    dtype=torch.bool,
                                    device=self.device)
            mask[N:, :N] = True
            mask[:N, N:] = True
            # NOTE(review): reshaping to 2N rows yields one row per anchor
            # only when z2_ also has N rows -- confirm callers guarantee it.
            negatives_other = similarity_matrix[mask].view(2 * N, -1)
            negatives = torch.cat([negatives, negatives_other], dim=1)

        # Column 0 holds the positive, so the target class is 0 everywhere.
        logits = torch.cat([positives, negatives], dim=1) / self.temperature
        labels = torch.zeros(2 * N, dtype=torch.int64,
                             device=self.device)  # scalar label per sample
        loss = F.cross_entropy(logits, labels, reduction='sum') / (2 * N)

        return loss
