from typing import Optional, Sequence

import torch
from torch.nn import functional as F


__all__ = ['contrastive_loss', 'contrastive_loss_single']


def contrastive_loss_single(z1, z2, temp: float = 1.):
    """Return the symmetric contrastive loss between two sets of vectors.

    Row ``k`` of ``z1`` and row ``k`` of ``z2`` form a positive pair; every
    other row of the concatenated ``[z1; z2]`` matrix acts as a negative.

    Args:
        z1: 2-D tensor; rows are embeddings of the first view.
        z2: 2-D tensor; rows are embeddings of the second view
            (presumably the same shape as ``z1`` -- the indexing below
            relies on both having ``z1.size(0)`` rows).
        temp: Temperature the dot-product similarities are divided by.

    Returns:
        Scalar tensor: the mean of the z1->z2 and z2->z1 negative
        log-softmax terms.
    """
    batch = z1.size(0)

    stacked = torch.cat((z1, z2), dim=0)
    similarity = (stacked @ stacked.T) / temp

    # Remove the diagonal (self-similarity): entries left of the diagonal
    # keep their column, entries right of it shift one column to the left,
    # giving a (2 * batch, 2 * batch - 1) matrix.
    below = torch.tril(similarity, diagonal=-1)[..., :-1]
    above = torch.triu(similarity, diagonal=1)[..., 1:]
    neg_log_prob = -F.log_softmax(below + above, dim=-1)

    rows = torch.arange(batch, device=z1.device)
    # After the diagonal is dropped, the partner of row k (from z1) sits at
    # column k + batch - 1, and the partner of row k + batch (from z2) sits
    # at column k.
    first_to_second = neg_log_prob[rows, rows + batch - 1].mean()
    second_to_first = neg_log_prob[rows + batch, rows].mean()
    return (first_to_second + second_to_first) / 2


# z1, z2 \in R^{B, T, C}
def contrastive_loss(z1, z2, temp: float = 1., mode: str = 'both',
                     weight: Optional[Sequence[float]] = None):
    """Combine the instance-wise and temporal contrastive losses.

    Args:
        z1: Representation of the first view.
        z2: Representation of the second view.
        temp: Temperature forwarded to the underlying losses.
        mode: ``'instance'``, ``'temporal'`` or ``'both'``
            (case-insensitive).
        weight: Two coefficients ``(instance_weight, temporal_weight)``
            used only when ``mode`` is ``'both'``. Defaults to
            ``(0.5, 0.5)``.

    Returns:
        The selected loss, or the weighted sum of both losses.

    Raises:
        AssertionError: If ``mode`` is not one of the supported values.
    """
    # Normalise once instead of lower-casing on every comparison.
    mode = mode.lower()
    assert mode in ('instance', 'temporal', 'both'), \
        "mode must be 'instance', 'temporal' or 'both'"

    if mode == 'instance':
        return instance_contrastive_loss(z1, z2, temp)
    if mode == 'temporal':
        return temporal_contrastive_loss(z1, z2, temp)

    # mode == 'both': the weights are only needed on this path.
    if weight is None:
        weight = (0.5, 0.5)

    return weight[0] * instance_contrastive_loss(z1, z2, temp) \
        + weight[1] * temporal_contrastive_loss(z1, z2, temp)


def contrastive(z1, z2, temp):
    """Return the symmetric contrastive loss between two views along dim 1.

    For every index along dim 0, position ``k`` of ``z1`` and position ``k``
    of ``z2`` (along dim 1) form a positive pair; all other positions of the
    concatenated ``[z1, z2]`` sequence act as negatives.

    Args:
        z1: 3-D tensor; dim 1 is the axis that is contrasted, dim 2 holds
            the features used for the dot-product similarity.
        z2: 3-D tensor (presumably the same shape as ``z1`` -- the indexing
            below relies on both having ``z1.size(1)`` positions).
        temp: Temperature the similarities are divided by.

    Returns:
        Scalar tensor with the averaged loss. When dim 1 is empty, a
        shape-``(1,)`` zero tensor on ``z1``'s device is returned instead.
    """
    length = z1.size(1)
    if length == 0:
        # Nothing to contrast.
        return torch.tensor([0.], device=z1.device)

    z = torch.cat([z1, z2], dim=1)
    sim = torch.matmul(z, z.transpose(1, 2)) / temp

    # Remove the diagonal (self-similarity): entries left of the diagonal
    # keep their column, entries right of it shift one column to the left.
    logits = torch.tril(sim, diagonal=-1)[..., :-1]
    logits += torch.triu(sim, diagonal=1)[..., 1:]

    # Drop the local references to the large intermediates before the
    # softmax. Copying them to the host first is unnecessary work: the
    # copies were never used again.
    del z, sim
    if length > 1500 and logits.is_cuda:
        # For long sequences, hand cached blocks back to the CUDA driver
        # before the softmax allocates its output.
        torch.cuda.empty_cache()

    logits = -F.log_softmax(logits, dim=-1)

    n = torch.arange(length, device=z1.device)
    # After the diagonal is dropped, the partner of position k (from z1)
    # sits at column k + length - 1, and the partner of position k + length
    # (from z2) sits at column k.
    loss = (logits[:, n, length + n - 1].mean() + logits[:, length + n, n].mean()) / 2
    return loss


def instance_contrastive_loss(z1, z2, temp):
    """Contrast across dim 0 of the inputs.

    Dims 0 and 1 of both views are swapped before delegating to
    ``contrastive``, which contrasts along dim 1 of what it receives.

    Args:
        z1: Representation of the first view.
        z2: Representation of the second view.
        temp: Temperature forwarded to ``contrastive``.

    Returns:
        The loss tensor produced by ``contrastive``.
    """
    first_swapped = z1.transpose(0, 1)
    second_swapped = z2.transpose(0, 1)
    return contrastive(first_swapped, second_swapped, temp)


def temporal_contrastive_loss(z1, z2, temp):
    """Contrast across dim 1 of the inputs.

    The views are passed to ``contrastive`` unchanged.

    Args:
        z1: Representation of the first view.
        z2: Representation of the second view.
        temp: Temperature forwarded to ``contrastive``.

    Returns:
        The loss tensor produced by ``contrastive``.
    """
    loss = contrastive(z1, z2, temp)
    return loss
