import torch

# TODO: implement LSS; it is currently a stub that returns None.
def LSS(X):
    """
    Placeholder for the LSS computation; no implementation exists yet.

    :param X: representations (currently ignored)
    :return: always ``None`` until an implementation is provided
    """
    return None

def get_norms(X, dim=-1, keepdim=False):
    """
    Compute Euclidean (L2) norms of ``X`` along one axis.

    :param X: floating-point or complex tensor
    :param dim: axis to reduce over; defaults to the last axis
    :param keepdim: if True, keep the reduced axis with size 1 so the result
        broadcasts back against ``X``
    :return: tensor of L2 norms
    """
    # torch.norm is deprecated; linalg.vector_norm computes the same 2-norm.
    return torch.linalg.vector_norm(X, dim=dim, keepdim=keepdim)

def get_normalized(X):
    """
    Scale every vector along the last axis of ``X`` to unit L2 norm.

    Works for tensors of any rank: the norms are computed with the reduced
    axis kept, so they broadcast back against ``X`` without rank-specific
    indexing.

    :param X: floating-point tensor whose last axis holds the vectors
    :return: tensor of the same shape as ``X`` with unit-norm last-axis vectors

    Zero-norm vectors are not guarded against and yield NaN.
    """
    # keepdim=True replaces the old `norms[:, :, None]`, which only worked for
    # 3-D input (2-D raised IndexError, 4-D+ broadcast on the wrong axis).
    return X / torch.linalg.vector_norm(X, dim=-1, keepdim=True)

def get_difference(X, dim=0):
    """
    Return first-order differences of ``X`` along one axis.

    Consecutive slices along ``dim`` are subtracted (``X[i+1] - X[i]``), so
    the result is one element shorter than ``X`` along that axis.

    :param X: representations tensor
    :param dim: axis to difference over; defaults to 0, which the original
        author describes as the layer axis
    :return: tensor of consecutive differences
    """
    return X.diff(dim=dim)

def cosine_similarity(X1, X2, dim=-1):
    """
    Compute the cosine similarity between ``X1`` and ``X2`` along ``dim``.

    Both inputs are L2-normalized along ``dim`` and the elementwise product is
    summed over that same axis, so ``dim`` must be the feature axis of both
    inputs. Inputs may have any rank as long as they broadcast together.

    :param X1: first tensor of vectors
    :param X2: second tensor of vectors, broadcastable against ``X1``
    :param dim: feature axis to normalize and reduce over (default: last)
    :return: cosine similarities with ``dim`` reduced away

    Zero-norm vectors are not guarded against and yield NaN.
    """
    # Normalize along the same axis we reduce over (previously normalization
    # was always on the last axis regardless of `dim`). keepdim=True lets the
    # norms broadcast back for inputs of any rank.
    X1 = X1 / torch.linalg.vector_norm(X1, dim=dim, keepdim=True)
    X2 = X2 / torch.linalg.vector_norm(X2, dim=dim, keepdim=True)
    return torch.sum(X1 * X2, dim=dim)

def cosine_similarity_kernel(X1, X2, dim=-1):
    """
    Compute the pairwise cosine-similarity matrix between two sets of vectors.

    ``dim`` names the feature axis of both inputs and is moved to the last
    position. The inputs are then treated as ``(..., n, d)`` and
    ``(..., m, d)``, and the result has shape ``(..., n, m)``. Entry
    ``[..., i, j]`` is the cosine similarity of ``X1[..., i, :]`` and
    ``X2[..., j, :]``. Leading batch dimensions broadcast as in
    ``torch.matmul``.

    :param X1: first set of vectors, at least 2-D
    :param X2: second set of vectors, at least 2-D
    :param dim: feature axis of both inputs (default: last)
    :return: tensor of pairwise cosine similarities

    Zero-norm vectors are not guarded against and yield NaN entries.
    """
    # Honor `dim` (previously ignored) by moving the feature axis last.
    X1 = torch.movedim(X1, dim, -1)
    X2 = torch.movedim(X2, dim, -1)
    # keepdim=True broadcasts for any rank; the old helper only handled 3-D
    # input and raised on the common 2-D (n, d) case.
    X1 = X1 / torch.linalg.vector_norm(X1, dim=-1, keepdim=True)
    X2 = X2 / torch.linalg.vector_norm(X2, dim=-1, keepdim=True)
    return torch.matmul(X1, X2.transpose(-1, -2))
