# import math

import torch
import torch.nn.functional as F


@torch.compile
def sampled_info_nce_loss(features, indices, num_negatives=128, temperature=0.1):
    """
    InfoNCE loss computed against a random subset of in-batch negatives.

    The positive for row ``i`` is row ``(i - batch_size // 2) % batch_size``
    (for an even batch size: the row half a batch away). A row ``j`` may be
    drawn as a negative for row ``i`` only if ``indices[j]`` differs from both
    ``indices[i]`` and the index of row ``i``'s positive.

    Every row draws the same number of negatives, without replacement. That
    number is ``num_negatives`` capped at the smallest per-row count of
    eligible negatives. If it resolves to zero (some row has no eligible
    negative, or ``num_negatives <= 0``), no sampling is performed and the
    loss reduces to the positive-only InfoNCE term, which is 0.

    features: torch.Tensor of shape (batch_size, num_dims)
    indices: 1-D torch.Tensor of length batch_size; rows with equal entries
        are never used as negatives for one another
    num_negatives: upper bound on the number of negatives drawn per row
    temperature: divisor applied to the cosine similarities
    returns: scalar torch.Tensor, the mean negative log-likelihood
    """
    batch_size = features.shape[0]
    half_shift = batch_size // 2

    # different_source[i, j] is True when rows i and j carry different indices.
    different_source = (indices[:, None] != indices[None, :]).to(features.device)
    # Rolling the rows by the positive offset additionally removes every row
    # that shares an index with row i's positive.
    eligibility_mask = different_source & different_source.roll(
        shifts=half_shift, dims=0
    )
    eligible_counts = eligibility_mask.sum(1)
    # Sampling is without replacement, so no row can draw more negatives than
    # the least-populated row has available.
    num_negatives = min(num_negatives, int(eligible_counts.min().item()))

    # Positive example -> batch_size // 2 away from the original example
    pos_mask = torch.arange(batch_size, device=features.device).roll(
        shifts=half_shift, dims=0
    )
    positive_cos_sim = (
        F.cosine_similarity(features, features[pos_mask], dim=-1) / temperature
    )

    if num_negatives > 0:
        # Uniform distribution over each row's eligible negatives. Safe to
        # divide: num_negatives > 0 implies every row has a non-zero count.
        prob_mask = eligibility_mask / eligible_counts[:, None]
        negatives = torch.multinomial(prob_mask, num_negatives)
        negative_cos_sim = (
            F.cosine_similarity(features[:, None, :], features[negatives], dim=-1)
            / temperature
        )
        cos_sim = torch.cat([positive_cos_sim.unsqueeze(1), negative_cos_sim], dim=1)
    else:
        # Nothing to contrast against: skip torch.multinomial, which rejects
        # both a zero sample count and all-zero / NaN probability rows.
        cos_sim = positive_cos_sim.unsqueeze(1)

    # InfoNCE loss
    nll = -positive_cos_sim + torch.logsumexp(cos_sim, dim=-1)
    return nll.mean()


@torch.compile
def old_sampled_info_nce_loss(
    features, num_negatives=128, temperature=0.1, num_views=2
):
    """
    Legacy sampled InfoNCE loss.

    Negatives for each row are drawn uniformly, without replacement, from
    every other row of the batch. Only the row itself is excluded, so the
    row's own positive can also appear among its negatives. The positive for
    row ``i`` is row ``(i - batch_size // 2) % batch_size``.

    features: torch.Tensor of shape (batch_size, num_dims)
    num_negatives: upper bound on negatives per row (capped at batch_size - 1)
    temperature: divisor applied to the cosine similarities
    num_views: accepted for interface compatibility; not used by this body
    returns: scalar torch.Tensor, the mean negative log-likelihood
    """
    n = features.shape[0]

    # Uniform sampling weights over all rows except the anchor itself.
    not_self = ~torch.eye(n, dtype=torch.bool, device=features.device)
    sampling_weights = not_self / (not_self.size(0) - 1)
    sampled = torch.multinomial(sampling_weights, min(num_negatives, n - 1))

    # Similarity of every anchor to each of its sampled negatives.
    neg_logits = F.cosine_similarity(
        features.unsqueeze(1), features[sampled], dim=-1
    )

    # Positive partner sits batch_size // 2 rows away (cyclically).
    partner = torch.arange(n, device=neg_logits.device).roll(shifts=n // 2, dims=0)
    pos_logits = F.cosine_similarity(features, features[partner], dim=-1)
    pos_logits = pos_logits.unsqueeze(1)

    # Temperature scaling, then place the positive in column 0.
    pos_logits = pos_logits / temperature
    neg_logits = neg_logits / temperature
    logits = torch.cat([pos_logits, neg_logits], dim=1)

    # InfoNCE: log-partition over all candidates minus the positive logit.
    per_row = torch.logsumexp(logits, dim=-1) - pos_logits.sum(-1)
    return per_row.mean()


@torch.compile
def info_nce_loss(features, temperature=0.1):
    """
    InfoNCE loss over the full batch.

    Each row is contrasted against every other row of the batch; the row
    itself is masked out. The positive for row ``i`` is row
    ``(i - batch_size // 2) % batch_size``.

    features: torch.Tensor of shape (batch_size, num_dims)
    temperature: divisor applied to the cosine similarities
    returns: scalar torch.Tensor, the mean negative log-likelihood
    """
    # Pairwise cosine similarity between all rows -> (batch_size, batch_size)
    sim = F.cosine_similarity(
        features.unsqueeze(1), features.unsqueeze(0), dim=-1
    )
    n = sim.shape[0]

    # A row must not count as its own candidate: push the diagonal to a very
    # large negative value so it vanishes inside the logsumexp.
    diagonal = torch.eye(n, dtype=torch.bool, device=sim.device)
    sim.masked_fill_(diagonal, -9e15)

    # Shifting the diagonal by half a batch marks each row's positive.
    positives = diagonal.roll(shifts=n // 2, dims=0)

    # InfoNCE: log-partition over all candidates minus the positive logit.
    logits = sim / temperature
    per_row = torch.logsumexp(logits, dim=-1) - logits[positives]
    return per_row.mean()


# @torch.compile
# def cosine_similarity(x1, dim=-1, eps=1e-8):
#     # get normalization value
#     x1_div = torch.linalg.vector_norm(x1, dim=dim, keepdims=True)

#     x1_div = x1_div.clone()
#     with torch.no_grad():
#         x1_div.clamp_(math.sqrt(eps))

#     # normalize, avoiding division by 0
#     t1_norm = x1 / x1_div

#     return (t1_norm * t1_norm).sum(dim=dim)
