import torch
import seaborn as sns
from layskip.utils import similarities


def plot_similarity_matrix(layers, ax, title, metric, vmax=None):
    """Draw a pairwise layer-similarity heatmap on ``ax``.

    Args:
        layers: Sequence of per-layer representations, passed unchanged to the
            selected ``similarities`` helper. One heatmap row/column (labelled
            0..len(layers)-1) is drawn per element.
        ax: Matplotlib Axes to draw on. If ``None``, seaborn draws on the
            current axes.
        title: Title set on the axes.
        metric: One of ``"cosine"``, ``"MSE"``, ``"CKA"``, ``"SVCCA"``,
            ``"PRESTO"``.
        vmax: Upper bound of the colour scale; ``None`` lets seaborn infer it.

    Returns:
        The Axes the heatmap was drawn on.

    Raises:
        ValueError: If ``metric`` is not one of the supported names.
    """
    # Metric name -> name of the helper in ``similarities`` that computes it.
    # Helpers are resolved only after the metric is validated, so an unknown
    # metric is rejected before any similarity computation starts.
    metric_functions = {
        "cosine": "pairwise_layer_cosine_similarity",
        "MSE": "pairwise_layer_MSE",
        "CKA": "pairwise_layer_CKA_similarity",
        "SVCCA": "pairwise_layer_SVCCA_similarity",
        "PRESTO": "pairwise_PRESTO_score",
    }
    try:
        func_name = metric_functions[metric]
    except KeyError:
        raise ValueError(
            f"Unknown similarity metric {metric!r}; expected one of: "
            f"{', '.join(metric_functions)}"
        ) from None
    similarity_matrix = getattr(similarities, func_name)(layers)

    # MSE is an error/distance rather than a similarity, so the colormap is
    # reversed for it.
    cmap = "viridis_r" if metric == "MSE" else "viridis"

    ticks = range(len(layers))
    # Pass ``ax`` explicitly: without it seaborn draws on the current axes,
    # while the title/tick settings below would be applied to ``ax``.
    ax = sns.heatmap(
        similarity_matrix,
        annot=True,
        cmap=cmap,
        fmt=".2f",
        xticklabels=ticks,
        yticklabels=ticks,
        ax=ax,
        vmax=vmax,
    )
    ax.set_title(title)
    ax.xaxis.set_ticks_position("top")
    return ax


def plot_similarity_embeddings(embeddings, ax, title, e1, e2, color):
    """Plot the cosine similarity between two embeddings across layers.

    Args:
        embeddings: Passed unchanged to
            ``similarities.pairwise_embedding_cosine_similarity``.
        ax: Matplotlib Axes to draw the line plot on.
        title: Title set on the axes.
        e1: Index of the first embedding of the pair.
        e2: Index of the second embedding of the pair.
        color: Matplotlib colour for the line and its markers.
    """
    sim_trajectories = similarities.pairwise_embedding_cosine_similarity(embeddings)

    # Indexed as [layer, e1, e2]; presumably the helper returns a tensor of
    # shape (num_layers, n_embeddings, n_embeddings). TODO: confirm.
    y = sim_trajectories[:, e1, e2].tolist()
    x = list(range(len(y)))

    ax.plot(x, y, marker="o", color=color)
    ax.set_title(title)
    ax.set_xlabel("Layer")
    # The plotted values are cosine similarities, not distances.
    ax.set_ylabel("Cosine similarity")
