import math

import torch
from torch.nn import functional as F


def dv_bound_loss(representation_1, representation_2, temperature=0.07):
    """Return the negative Donsker-Varadhan-style bound from cosine-similarity scores.

    Rows of both inputs are L2-normalised and scored pairwise as
    ``cos(r1_i, r2_j) / temperature``. Matching pairs (the diagonal) give the
    first term, ``mean(s_ii)``. All mismatched pairs (off-diagonal) give the
    second term, ``log(mean(exp(s_ij)))``. The loss is
    ``-(first_term - second_term)``.

    Args:
        representation_1: 2-D tensor; row i is paired with row i of
            ``representation_2``.
        representation_2: 2-D tensor with the same number of rows.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if there are fewer than two rows. There are then no
            off-diagonal (negative) pairs and the second term is undefined.
    """
    representation_1 = F.normalize(representation_1, dim=-1)
    representation_2 = F.normalize(representation_2, dim=-1)
    similarity_matrix = torch.mm(representation_1, representation_2.t()) / temperature

    num_rows = similarity_matrix.size(0)
    if num_rows < 2:
        raise ValueError(
            f"dv_bound_loss needs at least 2 rows to form negative pairs, got {num_rows}"
        )

    first_term = similarity_matrix.diag().mean()

    off_diagonal_mask = ~torch.eye(num_rows, dtype=torch.bool, device=similarity_matrix.device)
    negatives = similarity_matrix[off_diagonal_mask]
    # log(mean(exp(s))) == logsumexp(s) - log(count) in exact arithmetic, but the
    # naive form overflows: exp(s) is inf once s > ~11 in float16 (the default
    # temperature already allows s up to 1/0.07 ~= 14.3) or s > ~88 in float32.
    second_term = torch.logsumexp(negatives, dim=0) - math.log(negatives.numel())

    return -(first_term - second_term)


def cosine_similarity_loss(representation_1, representation_2):
    """Return the negated batch-mean cosine similarity between paired rows.

    Cosine similarity is taken along dim 1. Identical directions give -1 and
    orthogonal ones give 0.
    """
    per_pair_similarity = F.cosine_similarity(representation_1, representation_2, dim=1)
    return per_pair_similarity.mean().neg()


# Adapted from https://github.com/google-research/simclr
def infoNCE_loss(representations_1, representations_2, temperature=0.07):
    """SimCLR-style NT-Xent / InfoNCE loss over two views of the same batch.

    Both views are stacked into 2N rows and L2-normalised. Each row is scored
    against every other row by cosine similarity divided by ``temperature``.
    Cross-entropy is then applied with the row's counterpart in the other view
    as the positive class. Self-similarity is excluded from the softmax.

    Args:
        representations_1: first view, shape (N, D).
        representations_2: second view, shape (N, D); row i is the positive
            partner of row i of ``representations_1``.
        temperature: divisor applied to the cosine similarities.

    Returns:
        Scalar mean cross-entropy over the 2N rows.
    """
    batch_size = representations_1.shape[0]

    representations = torch.cat([representations_1, representations_2], dim=0)  # [2N, D]
    representations = F.normalize(representations, dim=1)

    logits = torch.matmul(representations, representations.T) / temperature  # [2N, 2N]
    device = logits.device

    # Remove each row's similarity with itself. The fill value is the most
    # negative number representable in the logits' dtype: a hard-coded -1e9
    # does not fit in float16, so masked_fill fails under half precision / AMP.
    self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=device)
    logits = logits.masked_fill(self_mask, torch.finfo(logits.dtype).min)

    # Positive column for row i: i + N for the first N rows, i - N for the last N.
    labels = torch.cat([
        torch.arange(batch_size, 2 * batch_size, device=device),
        torch.arange(0, batch_size, device=device),
    ])

    return F.cross_entropy(logits, labels)


def mse_loss(representation_1, representation_2):
    """Return the batch mean of per-row squared Euclidean distances.

    Squared differences are summed over dim 1, not averaged, before taking the
    mean over rows. This differs from ``F.mse_loss``, which averages over every
    element.
    """
    per_row_distance = torch.sum((representation_1 - representation_2).pow(2), dim=1)
    return torch.mean(per_row_distance)


# Adapted from https://github.com/facebookresearch/moco-v3
def contrastive_loss(query, key, temperature=0.07):
    """MoCo v3-style contrastive loss between matched query/key batches.

    Rows of both inputs are L2-normalised along dim 1. Every query is scored
    against every key by dot product divided by ``temperature``. Key i is the
    positive class for query i. The mean cross-entropy is scaled by
    ``2 * temperature``, as in the referenced MoCo v3 code.
    """
    normalized_query = F.normalize(query, dim=1)
    normalized_key = F.normalize(key, dim=1)

    logits = normalized_query @ normalized_key.t() / temperature
    targets = torch.arange(normalized_query.shape[0], dtype=torch.long, device=logits.device)

    scale = 2 * temperature
    return F.cross_entropy(logits, targets) * scale


def cross_entropy_loss(logits, labels):
    """Return ``F.cross_entropy`` of ``logits`` against ``labels`` with default settings."""
    return F.cross_entropy(logits, labels)


# Adapted from https://github.com/facebookresearch/barlowtwins
def off_diagonal(x):
    """Return the off-diagonal entries of square matrix ``x`` as a 1-D tensor, row-major.

    No boolean mask is allocated. The flattened matrix minus its last element
    has n*n - 1 == (n - 1) * (n + 1) entries. Viewed with rows of length n + 1,
    each row starts at flat index i * (n + 1), which is diagonal element (i, i).
    That is followed by the n off-diagonal entries that precede diagonal
    element (i + 1, i + 1). Dropping column 0 therefore removes exactly the
    diagonal.
    """
    rows, cols = x.shape
    assert rows == cols
    trimmed = x.flatten()[:-1]
    wrapped = trimmed.view(rows - 1, rows + 1)
    return wrapped[:, 1:].flatten()

def diversity_loss(representation_1, representation_2, lambd=0.0051, normalize=False, eps=1e-5):
    """Barlow Twins redundancy-reduction loss between two batches of embeddings.

    Builds ``c = representation_1.T @ representation_2 / batch_size``. It then
    penalises ``(c_ii - 1)^2`` on the diagonal plus ``lambd * c_ij^2`` off it.

    The reference Barlow Twins implementation batch-normalises each embedding
    before forming ``c``, so that ``c`` is a cross-correlation matrix and the
    diagonal target of 1 is meaningful. The two modes are:

    - ``normalize=False`` (the default, preserving earlier behaviour): the
      inputs are used as given, so the caller must standardise them.
    - ``normalize=True``: each feature is standardised over the batch here,
      using the biased variance plus ``eps``, as ``nn.BatchNorm1d(affine=False)``
      does in training mode.

    Args:
        representation_1: 2-D tensor (batch, dim).
        representation_2: 2-D tensor (batch, dim). ``dim`` must match
            ``representation_1`` because the off-diagonal term needs a square ``c``.
        lambd: weight of the off-diagonal (redundancy) term.
        normalize: standardise each feature over the batch before correlating.
        eps: added to the per-feature variance when ``normalize`` is true.

    Returns:
        Scalar loss tensor.
    """
    batch_size = representation_1.shape[0]

    if normalize:
        def standardize(z):
            # Per-feature z-score over the batch dimension (BatchNorm without affine).
            return (z - z.mean(dim=0)) / torch.sqrt(z.var(dim=0, unbiased=False) + eps)

        representation_1 = standardize(representation_1)
        representation_2 = standardize(representation_2)

    # Out-of-place ops are numerically identical to the original in-place chain,
    # but they avoid mutating tensors that autograd may still reference.
    c = representation_1.T @ representation_2 / batch_size

    on_diag = (torch.diagonal(c) - 1).pow(2).sum()
    off_diag = off_diagonal(c).pow(2).sum()

    return on_diag + lambd * off_diag


# Adapted from https://github.com/facebookresearch/vicreg
def VICReg_loss(representation_1, representation_2, similarity_coeff=25.0, variance_coeff=25.0, covariance_coeff=1.0, eps=1e-4):
    """VICReg loss: weighted sum of invariance, variance and covariance terms.

    - Invariance: element-wise MSE between the two batches.
    - Variance: hinge ``relu(1 - std)`` per feature, averaged over features and
      then over the two branches. ``eps`` is added to the variance before the
      square root.
    - Covariance: sum of squared off-diagonal entries of each branch's
      (unbiased) feature covariance matrix, divided by the feature count. The
      two branches are summed.

    Args:
        representation_1: 2-D tensor (batch, features).
        representation_2: 2-D tensor with the same shape.
        similarity_coeff: weight of the invariance term.
        variance_coeff: weight of the variance term.
        covariance_coeff: weight of the covariance term.
        eps: added to the per-feature variance before ``sqrt``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the batch holds fewer than two samples. The variance and
            covariance estimates divide by ``batch_size - 1``, so a single
            sample would otherwise silently produce NaN/inf.
    """
    x = representation_1
    y = representation_2

    batch_size, feature_dim = x.shape
    if batch_size < 2:
        raise ValueError(
            f"VICReg_loss needs at least 2 samples per batch to estimate variance and covariance, got {batch_size}"
        )

    # Invariance term.
    similarity_loss = F.mse_loss(x, y)

    x_centered = x - x.mean(dim=0)
    y_centered = y - y.mean(dim=0)

    # Variance term: keep each feature's std at or above 1.
    std_x = torch.sqrt(x_centered.var(dim=0) + eps)
    std_y = torch.sqrt(y_centered.var(dim=0) + eps)

    variance_loss_x = torch.mean(F.relu(1.0 - std_x))
    variance_loss_y = torch.mean(F.relu(1.0 - std_y))
    variance_loss = (variance_loss_x + variance_loss_y) / 2

    # Covariance term: decorrelate features within each branch.
    covariance_matrix_x = (x_centered.T @ x_centered) / (batch_size - 1)
    covariance_matrix_y = (y_centered.T @ y_centered) / (batch_size - 1)

    off_diagonal_x = off_diagonal(covariance_matrix_x)
    off_diagonal_y = off_diagonal(covariance_matrix_y)

    covariance_loss_x = (off_diagonal_x ** 2).sum() / feature_dim
    covariance_loss_y = (off_diagonal_y ** 2).sum() / feature_dim
    covariance_loss = covariance_loss_x + covariance_loss_y

    loss = (
        similarity_coeff * similarity_loss
        + variance_coeff * variance_loss
        + covariance_coeff * covariance_loss
    )

    return loss
