import torch
from torch.nn import functional as F
import math

def squared_euclidian_distance(a, b):
    """Return the squared pairwise Euclidean distances between `a` and `b`.

    The distances come from ``torch.cdist`` (default 2-norm) and are squared
    element-wise, so entry ``[i, j]`` relates row ``i`` of `a` to row ``j``
    of `b`.
    """
    euclidean = torch.cdist(a, b)
    return euclidean.pow(2)


def cosine_similarity(a, b):
    """Return the matrix of cosine similarities between rows of `a` and `b`.

    Both inputs are L2-normalised along their last dimension and then
    multiplied with ``torch.mm``, so both must be 2-D; entry ``[i, j]`` is the
    cosine of the angle between ``a[i]`` and ``b[j]``.
    """
    a_unit = F.normalize(a, p=2, dim=-1)
    b_unit = F.normalize(b, p=2, dim=-1)
    return torch.mm(a_unit, b_unit.T)


def stable_cosine_distance(a, b, squared=True):
    """Compute the pairwise distance matrix between rows of `a` and `b`.

    The value computed is the (squared) Euclidean distance, obtained from the
    expansion ``||x||^2 + ||y||^2 - 2 * x.y``. For L2-normalised inputs the
    squared form equals ``2 - 2 * cos(x, y)``, which is presumably where the
    "cosine" in the name comes from; no normalisation is applied here.

    Only the ``a``-vs-``b`` block is evaluated; the ``a``-vs-``a`` and
    ``b``-vs-``b`` blocks are never needed for the result.

    Args:
        a: 2-D tensor of shape (n, d).
        b: 2-D tensor of shape (m, d).
        squared: if True return squared distances, otherwise their square
            root.

    Returns:
        Tensor of shape (n, m) with non-negative entries.
    """
    # torch.mm requires both operands to share a dtype, so promote mixed
    # inputs to a common one first.
    common_dtype = torch.promote_types(a.dtype, b.dtype)
    a = a.to(common_dtype)
    b = b.to(common_dtype)

    # Squared norms as a column (n, 1) and a row (1, m); broadcasting the sum
    # yields the (n, m) matrix of ||a_i||^2 + ||b_j||^2.
    a_sq = a.pow(2).sum(dim=1, keepdim=True)
    b_sq = b.pow(2).sum(dim=1).unsqueeze(0)
    dist_sq = a_sq + b_sq - 2 * torch.mm(a, b.t())

    # Rounding can push true zeros slightly negative; clamp them back.
    dist_sq = torch.clamp(dist_sq, min=0.0)

    # Entries that are exactly zero after clamping.
    zero_mask = torch.le(dist_sq, 0.0)

    if squared:
        dist = dist_sq
    else:
        # sqrt has an infinite gradient at 0, so nudge the zero entries by a
        # tiny constant before taking the root.
        dist = torch.sqrt(dist_sq + zero_mask.float() * 1e-16)

    # Put exact zeros back where the nudge was applied (this also blocks the
    # gradient through those entries).
    return torch.mul(dist, (~zero_mask).float())


def mahalanobis_distance(features, mu, sigma, eps=1e-7):
    """Return per-class angular dissimilarity scores, concatenated over tasks.

    For every task ``i`` the angle between each L2-normalised feature row and
    each L2-normalised class mean is computed. The angle is scored with a
    half-normal density of scale ``sigma[i]`` and the result is
    ``1 - density``.

    NOTE(review): despite the name, no covariance matrix is used here;
    ``sigma[i]`` acts as a per-class scale of the angle. Because the density
    can exceed 1 for small scales (below roughly 0.8), the returned values
    may be negative.

    Args:
        features: sequence of length T, each entry a BxD tensor.
        mu: sequence of length T, each entry a CxD tensor, where C is the
            number of classes learnt in that task.
        sigma: sequence of length T; each entry is unsqueezed to a leading
            dimension of 1 and broadcast against the BxC angles, so it is
            presumably of shape (C,) -- TODO confirm against callers.
        eps: margin keeping the cosine strictly inside (-1, 1) before
            ``acos``.

    Returns:
        Tensor of shape Bx(sum of C over all tasks).

    Raises:
        AssertionError: if the three sequences differ in length.
    """
    assert len(features) == len(mu) and len(mu) == len(sigma)

    # Log of the half-normal normalising constant 2 / sqrt(2 * pi); it does
    # not depend on the task, so compute it once.
    log_norm = math.log(2) - math.log(math.sqrt(2 * math.pi))

    distance = []
    for feat, mean, scale in zip(features, mu, sigma):
        # Plain 2-D product gives BxC directly. A bare squeeze() on a
        # Bx1xC result would also drop B or C when either is 1 and corrupt
        # the output shape.
        cos = torch.matmul(
            F.normalize(feat, p=2, dim=-1),
            F.normalize(mean, p=2, dim=-1).T,
        )
        angle = torch.acos(torch.clamp(cos, -1 + eps, 1 - eps))  # BxC
        scale = scale.unsqueeze(0)  # 1xC

        log_prob = -(angle ** 2) / (2 * scale ** 2) - scale.log() + log_norm

        distance.append(1 - log_prob.exp())
    return torch.cat(distance, dim=1)
