import torch
from latentis.measure.functional.cka import cka as cka_fn
from latentis.measure.functional.cka import kernel_hsic, linear_hsic
from latentis.measure.functional.svcca import robust_svcca as svcca_fn

from presto import Presto
from sklearn.random_projection import GaussianRandomProjection as Gauss

# Run on the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def pairwise_embedding_cosine_similarity(embeddings, limit=100, eps=1e-12):
    """Cosine similarity between every pair of vectors inside each embedding.

    The embeddings are stacked along a new leading axis and truncated to the
    first ``limit`` entries of axis 1. For each stacked item, the cosine
    similarity is computed between every pair of vectors along that axis. The
    vectors are the last-axis slices.

    NOTE(review): each embedding is presumably a (seq_len, dim) tensor. That
    gives a (num_embeddings, L, L) result with L = min(seq_len, limit).
    Confirm against callers.

    Args:
        embeddings: Sequence of equally shaped tensors accepted by
            ``torch.stack``.
        limit: Keep only the first ``limit`` entries along axis 1. ``None``
            keeps all of them.
        eps: Lower bound applied to every vector norm. With it, a zero vector
            gets similarity 0 instead of NaN (0 / 0).

    Returns:
        Tensor of pairwise cosine similarities. Its last two axes index the
        vector pair.
    """
    stacked = torch.stack(embeddings)[:, :limit]

    # Normalise each vector once. A single batched matmul then yields every
    # dot product, without materialising the (..., L, L, dim) broadcast
    # product of the naive formulation.
    norms = stacked.pow(2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps)
    unit = stacked / norms

    return torch.matmul(unit, unit.transpose(-1, -2))


def pairwise_layer_cosine_similarity(layers, eps=1e-12):
    """Average cosine similarity between every pair of layer representations.

    The layers are stacked along a new leading axis. For each pair of layers,
    the cosine similarity is computed between last-axis vectors at matching
    positions, and the result is averaged over the last remaining axis.

    NOTE(review): each layer is presumably a (num_samples, dim) tensor. That
    gives a (num_layers, num_layers) matrix of per-sample cosine similarities
    averaged over samples. Confirm against callers.

    Args:
        layers: Sequence of equally shaped tensors accepted by ``torch.stack``.
        eps: Lower bound applied to every vector norm. With it, a zero vector
            contributes similarity 0 instead of NaN (0 / 0).

    Returns:
        Tensor whose first two axes index the layer pair.
    """
    stacked = torch.stack(layers)

    # Normalise once. The einsum then contracts only the feature axis, so the
    # (L, L, ..., dim) broadcast product is never materialised. The '...' keeps
    # every intermediate axis, exactly as the broadcast form did.
    norms = stacked.pow(2).sum(dim=-1, keepdim=True).sqrt().clamp_min(eps)
    unit = stacked / norms
    per_position = torch.einsum("a...d,b...d->ab...", unit, unit)

    return per_position.mean(dim=-1)


def pairwise_layer_MSE(layers):
    """Mean squared error between every pair of layer representations.

    The layers are stacked along a new leading axis and broadcast into an
    all-pairs grid. Squared differences are averaged over the last axis and
    then over the next remaining axis.

    NOTE(review): each layer is presumably a (num_samples, dim) tensor. That
    gives a (num_layers, num_layers) matrix. Confirm against callers.

    Args:
        layers: Sequence of equally shaped tensors accepted by ``torch.stack``.

    Returns:
        Tensor whose first two axes index the layer pair.
    """
    stacked = torch.stack(layers)

    # Indexing with None inserts the broadcast axes, the same as unsqueeze(0/1).
    pair_diff = stacked[None] - stacked[:, None]
    per_row = pair_diff.pow(2).mean(dim=-1)

    return per_row.mean(dim=-1)


def pairwise_layer_CKA_similarity(layers, *, hsic=linear_hsic):
    """Pairwise CKA similarity between layer representations.

    Only pairs with i <= j are evaluated. Each score is written to both [i, j]
    and [j, i], so the score is treated as symmetric.

    Args:
        layers: Sequence of representation tensors. Each one is moved to the
            module-level ``device`` before being passed to ``cka_fn``.
        hsic: HSIC estimator forwarded to ``cka_fn``. The default is
            ``linear_hsic``; ``kernel_hsic`` is the other estimator this module
            imports.

    Returns:
        (len(layers), len(layers)) tensor allocated with ``torch.zeros`` and
        not moved to ``device``.
    """
    n_layers = len(layers)
    cka = torch.zeros([n_layers, n_layers])

    for i in range(n_layers):
        # Transfer layer i once per row rather than once per pair.
        x = layers[i].to(device)
        for j in range(i, n_layers):
            similarity = cka_fn(x, layers[j].to(device), hsic=hsic)
            cka[i, j] = similarity
            cka[j, i] = similarity

    return cka


def pairwise_layer_SVCCA_similarity(layers, *, tolerance=1e-3):
    """Pairwise SVCCA similarity between layer representations.

    Only pairs with i <= j are evaluated. Each score is written to both [i, j]
    and [j, i], so the score is treated as symmetric.

    Args:
        layers: Sequence of representation tensors. Each one is moved to the
            module-level ``device`` before being passed to ``svcca_fn``.
        tolerance: Forwarded unchanged to ``svcca_fn``. The default is 1e-3.

    Returns:
        (len(layers), len(layers)) tensor placed on ``device``.
    """
    n_layers = len(layers)
    svcca = torch.zeros([n_layers, n_layers]).to(device)

    for i in range(n_layers):
        # Transfer layer i once per row rather than once per pair.
        x = layers[i].to(device)
        for j in range(i, n_layers):
            similarity = svcca_fn(x, layers[j].to(device), tolerance=tolerance)
            svcca[i, j] = similarity
            svcca[j, i] = similarity

    return svcca


def pairwise_PRESTO_score(layers, *, n_projections=20, n_components=2, normalize=True):
    """Pairwise PRESTO score between layer representations.

    A single ``Presto`` instance, using Gaussian random projections, is reused
    for every pair. Only pairs with i <= j are evaluated. Each value returned
    by ``Presto.fit_transform`` is written to both [i, j] and [j, i].

    NOTE(review): the projector is passed without a seed, so scores may vary
    between runs. Confirm whether reproducibility is needed.

    Args:
        layers: Sequence of representations passed unchanged to
            ``Presto.fit_transform``.
        n_projections: Forwarded to ``fit_transform``. The default is 20.
        n_components: Forwarded to ``fit_transform``. The default is 2.
        normalize: Forwarded to ``fit_transform``. The default is True.

    Returns:
        (len(layers), len(layers)) tensor allocated with ``torch.zeros``.
    """
    presto = Presto(projector=Gauss)
    n_layers = len(layers)
    presto_matrix = torch.zeros([n_layers, n_layers])

    for i in range(n_layers):
        for j in range(i, n_layers):
            dist = presto.fit_transform(
                layers[i],
                layers[j],
                n_projections=n_projections,
                n_components=n_components,
                normalize=normalize,
            )
            presto_matrix[i, j] = dist
            presto_matrix[j, i] = dist

    return presto_matrix
