import torch
from latentis.measure.functional.cka import cka as cka_fn
from latentis.measure.functional.cka import kernel_hsic, linear_hsic
from latentis.measure.functional.svcca import robust_svcca as svcca_fn
import torch.nn.functional as F

# Target device for the CKA computation and the SVCCA result matrix: the GPU when
# CUDA is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def pairwise_embedding_cosine_similarity(embeddings, limit=100, eps=1e-8):
    """Compute, for each stacked tensor, the cosine similarity between every pair of rows.

    Args:
        embeddings: sequence of equally-shaped tensors with at least two axes
            (presumably one per layer, each shaped ``(samples, dim)`` — confirm
            against callers). They are stacked on a new leading axis. Rows along
            the first axis of each tensor are compared pairwise, and the cosine
            is taken over the last axis.
        limit: keep only the first ``limit`` rows of each tensor, which bounds the
            ``m x m`` size of the result. ``None`` keeps every row.
        eps: lower bound applied to every vector norm, so all-zero vectors get a
            similarity of 0 instead of NaN.

    Returns:
        Tensor of shape ``(len(embeddings), m, m, *rest)``, where
        ``m = min(rows, limit)`` and ``rest`` is any axes between the row axis
        and the last axis. Entry ``[l, i, j]`` is the cosine between rows ``i``
        and ``j`` of ``embeddings[l]``.
    """
    embeddings = torch.stack(embeddings)[:, :limit]

    # Normalise each vector once; clamping the norm avoids 0/0 for zero vectors.
    norms = embeddings.norm(dim=-1, keepdim=True).clamp(min=eps)
    unit = embeddings / norms

    # Contract over the feature axis directly. Broadcasting the two operands
    # against each other would materialise a (layers, m, m, ..., dim) tensor.
    return torch.einsum("li...d,lj...d->lij...", unit, unit)


def pairwise_layer_cosine_similarity(layers, eps=1e-8):
    """Compute the mean row-aligned cosine similarity between every pair of layers.

    For layers ``a`` and ``b``, the cosine is taken over the last axis between
    vectors at the same position in both tensors. It is then averaged over the
    second-to-last axis.

    Args:
        layers: sequence of equally-shaped tensors, one per layer.
        eps: lower bound applied to every vector norm, so all-zero vectors get a
            similarity of 0 instead of NaN.

    Returns:
        For per-layer tensors shaped ``(n, dim)``: a ``(len(layers), len(layers))``
        matrix whose entry ``[a, b]`` is the mean over ``k`` of
        ``cos(layers[a][k], layers[b][k])``. Any extra leading axes are kept,
        e.g. ``(x, n, dim)`` inputs give ``(len(layers), len(layers), x)``.
    """
    stacked = torch.stack(layers)

    # Normalise each vector once; clamping the norm avoids 0/0 for zero vectors.
    unit = stacked / stacked.norm(dim=-1, keepdim=True).clamp(min=eps)

    # Dot products of aligned unit vectors for every layer pair. Contracting
    # over the feature axis avoids a (layers, layers, ..., dim) intermediate.
    cosine = torch.einsum("a...d,b...d->ab...", unit, unit)

    return cosine.mean(dim=-1)


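# NOTE: unused alternative variant of pairwise_layer_cosine_similarity, kept for reference.
# It averages F.cosine_similarity over the last two axes (-1, -2) of the per-pair
# result. Per-layer tensors with both a sample and a token axis (e.g. (samples, tokens, dim))
# are therefore reduced to a (layers, layers) matrix. The active version averages
# only over the last axis.
# def pairwise_layer_cosine_similarity(layers):
#     layers_tensor = torch.stack(layers, dim=0)

#     layer_a = layers_tensor.unsqueeze(0)
#     layer_b = layers_tensor.unsqueeze(1)

#     cosine_per_token_sample = F.cosine_similarity(layer_a, layer_b, dim=-1)
#     mean_cosine = cosine_per_token_sample.mean(dim=(-1, -2))

#     return mean_cosine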


def pairwise_layer_MSE(layers):
    """Compute the mean squared error between every pair of layers.

    Args:
        layers: sequence of equally-shaped tensors, one per layer.

    Returns:
        Tensor whose entry ``[i, j]`` is ``(layers[i] - layers[j]) ** 2``
        averaged over its last axis and then over the next remaining last axis.
        For per-layer tensors shaped ``(n, dim)`` this is a symmetric
        ``(len(layers), len(layers))`` matrix with zeros on the diagonal.
    """
    stacked = torch.stack(layers)

    # Compare one reference layer against the whole stack at a time, so peak
    # memory is one stack-sized tensor. Broadcasting every pair at once would
    # build a (layers, layers, ..., dim) difference tensor.
    rows = [((stacked - reference) ** 2).mean(dim=-1).mean(dim=-1) for reference in stacked]

    return torch.stack(rows)


def pairwise_layer_CKA_similarity(layers):
    """Compute a symmetric matrix of CKA similarities between every pair of layers.

    Each unordered pair (including a layer with itself) is scored once with
    ``cka_fn`` using ``linear_hsic``. The score is written into both triangles.

    Args:
        layers: sequence of tensors accepted by ``cka_fn``. Each is moved to the
            module-level ``device`` before being compared.

    Returns:
        Tensor of shape ``(len(layers), len(layers))``, allocated with
        ``torch.zeros`` defaults (CPU unless the default device was changed).
    """
    n_layers = len(layers)
    cka = torch.zeros([n_layers, n_layers])

    for i, layer in enumerate(layers):
        # Move layer i once per row instead of once per (i, j) pair.
        layer_i = layer.to(device)
        for j in range(i, n_layers):
            # kernel_hsic is the alternative HSIC estimator imported alongside linear_hsic.
            similarity = cka_fn(layer_i, layers[j].to(device), hsic=linear_hsic)
            cka[i, j] = similarity
            cka[j, i] = similarity

    return cka


def pairwise_layer_SVCCA_similarity(layers):
    """Compute a symmetric matrix of SVCCA similarities between every pair of layers.

    Each unordered pair (including a layer with itself) is scored once with
    ``svcca_fn`` at ``tolerance=1e-3``. The score is written into both triangles.

    Args:
        layers: sequence of tensors accepted by ``svcca_fn``. They are passed
            as-is and are NOT moved to ``device`` first.

    Returns:
        Tensor of shape ``(len(layers), len(layers))`` located on ``device``.
    """
    count = len(layers)
    scores = torch.zeros([count, count], device=device)

    for row, first in enumerate(layers):
        for col in range(row, count):
            # Alternative: svcca_fn(first.to(device), layers[col].to(device), tolerance=1e-3)
            value = svcca_fn(first, layers[col], tolerance=1e-3)
            scores[row, col] = value
            scores[col, row] = value

    return scores
