from nesim.utils.correlation import pearsonr
import torch
import numpy as np
import os
import matplotlib.pyplot as plt

# Word triplets probed by the selectivity analysis below. For each triplet the
# script compares corr(target, similar) against corr(target, different) at
# every layer; the key order ("target", "similar", "different") is kept fixed.
triplets = [
    dict(target=target, similar=similar, different=different)
    for target, similar, different in (
        ("apple", "steve", "tech"),
        ("einstein", "physics", "science"),
        ("microsoft", "apple", "science"),
    )
]

# Checkpoint directory names under ``assets/`` to analyse.
# Other candidates (currently disabled): "topo_1", "topo_5", "topo_10".
checkpoint_names = ["topo_50"]

# Number of transformer layers (``transformer.h.<i>``) analysed per checkpoint.
num_layers = 12


def _dprime_path(checkpoint_name, layer_index, word):
    """Return the path of the saved d' array for *word* at one layer.

    Layout: ``assets/<checkpoint>/transformer.h.<layer>.mlp.c_fc/dprime/<word>.npy``.
    Each directory level is its own ``os.path.join`` component so the
    platform's native separator is used throughout.
    """
    return os.path.join(
        "assets",
        checkpoint_name,
        f"transformer.h.{layer_index}.mlp.c_fc",
        "dprime",
        f"{word}.npy",
    )


def _load_dprime(checkpoint_name, layer_index, word):
    """Load the d' array for *word* as a flattened ``(1, N)`` torch tensor."""
    array = np.load(_dprime_path(checkpoint_name, layer_index, word))
    # pearsonr is fed 2-D single-row tensors.
    return torch.tensor(array.reshape(1, -1))


fig = plt.figure(figsize=(8, 6))
layers = range(num_layers)

for triplet in triplets:
    target_word = triplet["target"]
    similar_word = triplet["similar"]
    different_word = triplet["different"]

    for checkpoint_name in checkpoint_names:
        # Reset per checkpoint: a list shared across checkpoints would grow to
        # len(checkpoint_names) * num_layers entries and no longer line up
        # with the x-axis positions in ``layers``.
        deltas = []

        for layer_index in layers:
            target = _load_dprime(checkpoint_name, layer_index, target_word)
            similar = _load_dprime(checkpoint_name, layer_index, similar_word)
            different = _load_dprime(checkpoint_name, layer_index, different_word)

            # Selectivity delta = corr(target, similar) - corr(target, different),
            # computed with nesim's pearsonr.
            a = pearsonr(target, similar).item()
            b = pearsonr(target, different).item()
            delta = a - b
            print(f"Layer: {layer_index} Delta = {delta}")
            deltas.append(delta)

        label = f"corr({target_word}, {similar_word}) -  corr({target_word}, {different_word})"
        if len(checkpoint_names) > 1:
            # Disambiguate curves when several checkpoints share one figure.
            label = f"{label} [{checkpoint_name}]"
        line, = plt.plot(layers, deltas, label=label)
        # Pin the markers to the line's colour so each curve is a single colour.
        plt.scatter(layers, deltas, color=line.get_color())

# Axis styling. Raw string: "\D" is not a valid Python escape sequence and
# would emit a SyntaxWarning on Python 3.12+; mathtext needs the backslash.
plt.ylabel(r"$\Delta$ Selectivity", fontsize=18)
plt.yticks(fontsize=15)
plt.xticks(fontsize=15)
# Pin only the upper y-limit; the lower limit keeps its current value.
plt.ylim(top=1.45)
plt.legend()

plt.grid()
plt.xlabel("layer ->", fontsize=18)

# Save the figure to the current working directory.
fig.savefig("deltas.png")