import torch
import torch.nn.functional as F

def instance_contrastive_loss(z1, z2):
    """Contrast the two views sample-against-sample at every timestep.

    Inputs are expected to be shaped (B, T, C), as the original inline shape
    notes indicate (assumed batch, time, channels). At each timestep, every
    one of the 2B representations acts as an anchor. Its positive is the same
    sample in the other view, and its negatives are the remaining 2B - 2
    representations. The self-similarity on the diagonal is excluded from
    the softmax.

    Args:
        z1: representations from the first view.
        z2: representations from the second view, same shape as ``z1``.

    Returns:
        Scalar tensor: the mean negative log-probability assigned to the
        positive pair, averaged over both directions (view 1 -> view 2 and
        view 2 -> view 1). When the batch holds a single sample, there are
        no negatives, so a zero tensor is returned instead.
    """
    n = z1.size(0)
    if n == 1:
        return z1.new_tensor(0.)

    # T x 2B x C: time becomes the leading (batched) dimension of the matmul.
    stacked = torch.cat([z1, z2], dim=0).transpose(0, 1)
    scores = torch.matmul(stacked, stacked.transpose(1, 2))  # T x 2B x 2B

    # Drop the diagonal. The strict lower triangle keeps columns < row. The
    # strict upper triangle, shifted left by one column, keeps columns > row.
    # Together they give T x 2B x (2B - 1) without self-similarities.
    below = torch.tril(scores, diagonal=-1)[:, :, :-1]
    above = torch.triu(scores, diagonal=1)[:, :, 1:]
    nll = -F.log_softmax(below + above, dim=-1)

    idx = torch.arange(n, device=z1.device)
    # Anchor i (view 1) pairs with column B + i. That column moves to
    # B + i - 1 because the diagonal entry i (< B + i) was removed.
    # Anchor B + i (view 2) pairs with column i, which lies below the
    # diagonal and therefore keeps its position.
    forward = nll[:, idx, n + idx - 1].mean()
    backward = nll[:, n + idx, idx].mean()
    return (forward + backward) / 2

def temporal_contrastive_loss(z1, z2):
    """Contrast the two views timestep-against-timestep within each sample.

    Inputs are expected to be shaped (B, T, C), as the original inline shape
    notes indicate (assumed batch, time, channels). Within each sample, every
    one of the 2T representations acts as an anchor. Its positive is the same
    timestep in the other view, and its negatives are the remaining 2T - 2
    timesteps of that sample. Self-similarity is excluded from the softmax.

    Args:
        z1: representations from the first view.
        z2: representations from the second view, same shape as ``z1``.

    Returns:
        Scalar tensor: the mean negative log-probability of the positive
        pair, averaged over both directions. When there is only one
        timestep, there are no negatives, so a zero tensor is returned.
    """
    steps = z1.size(1)
    if steps == 1:
        return z1.new_tensor(0.)

    joint = torch.cat([z1, z2], dim=1)  # B x 2T x C
    scores = joint @ joint.transpose(1, 2)  # B x 2T x 2T

    # Remove the diagonal by combining the strict lower triangle with the
    # strict upper triangle shifted one column left: B x 2T x (2T - 1).
    below = torch.tril(scores, diagonal=-1)[:, :, :-1]
    above = torch.triu(scores, diagonal=1)[:, :, 1:]
    nll = -F.log_softmax(below + above, dim=-1)

    ts = torch.arange(steps, device=z1.device)
    # Row t's positive (column T + t) shifts left by one once the diagonal is
    # dropped. Row T + t's positive (column t) sits below the diagonal and
    # does not move.
    forward = nll[:, ts, steps + ts - 1].mean()
    backward = nll[:, steps + ts, ts].mean()
    return (forward + backward) / 2
