import torch


# NOTE(review): module-level constant; it is not referenced by any code visible
# in this file. Presumably a regularization weight consumed by importers of
# this module — confirm its users before changing or removing it.
REG = 0.1


def compute_covariance(covariance_mat, eps=1e-6):
    """Return the pairwise cosine-similarity matrix for each batch element.

    Despite the name, this does not compute a statistical covariance: for every
    batch index ``b`` it returns the ``n x n`` matrix whose ``[i, j]`` entry is
    the cosine similarity between rows ``i`` and ``j`` of ``covariance_mat[b]``.

    Args:
        covariance_mat: Floating-point tensor of shape ``(batch, n, d)``; each
            of the ``n`` rows is a ``d``-dimensional vector. Any number of
            leading batch dimensions (including none) is also accepted.
        eps: Lower bound applied to each row's L2 norm to avoid division by
            zero; plays the same role as ``torch.nn.CosineSimilarity(eps=...)``.
            Defaults to ``1e-6``, the value previously hard-coded.

    Returns:
        Tensor of shape ``(batch, n, n)`` on the same device and with the same
        dtype as ``covariance_mat``. All-zero rows yield zero similarity.
    """
    # Normalize every row once (x / max(||x||, eps)). A batched matmul of the
    # unit rows then yields all pairwise cosines in one kernel, instead of a
    # Python loop over batch * n rows, and without materializing a
    # (batch, n, n, d) broadcast intermediate. The result inherits the input's
    # device and dtype rather than being forced onto a CPU float32 buffer.
    unit_rows = torch.nn.functional.normalize(covariance_mat, p=2, dim=-1, eps=eps)
    return torch.matmul(unit_rows, unit_rows.transpose(-1, -2))
