from nesim.utils.correlation import pearsonr
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM


def _min_max_normalize(tensor):
    """Linearly rescale *tensor* so its minimum maps to 0 and its maximum to 1.

    A constant tensor (max == min) has no range to scale by. Dividing by zero
    would turn every element into NaN, so the divisor is replaced by 1 and the
    result is all zeros instead.
    """
    lo = tensor.min()
    span = tensor.max() - lo
    if span == 0:
        # Avoid 0/0 -> NaN; (tensor - lo) is already zero everywhere here.
        span = 1
    return (tensor - lo) / span


def compute_ssim(tensor1, tensor2):
    """Compute SSIM between two tensors after min-max normalising each one.

    Each input is rescaled independently to [0, 1]. The score therefore ignores
    differences in offset/scale between the two tensors and compares only
    their spatial pattern. The comparison then uses ``data_range=1``.

    Args:
        tensor1: First input, in a batched image layout accepted by
            ``pytorch_msssim.ssim``. Per that library this is presumably
            (N, C, H, W); confirm against callers.
        tensor2: Second input; must have exactly the same shape as *tensor1*.

    Returns:
        The tensor returned by ``pytorch_msssim.ssim`` with
        ``size_average=False``, i.e. one SSIM value per batch element.

    Raises:
        ValueError: If the two tensors do not have the same shape.
    """
    # An explicit exception (not ``assert``) so the check survives ``python -O``.
    if tensor1.shape != tensor2.shape:
        raise ValueError("Tensors must have the same shape")

    return ssim(
        _min_max_normalize(tensor1),
        _min_max_normalize(tensor2),
        data_range=1,
        size_average=False,
    )

def load_map(
    checkpoint_name,
    layer_index,
    category,
    *,
    root="assets",
):
    """Load one saved per-category ``.npy`` map as a torch tensor.

    The file is read from
    ``<root>/<checkpoint_name>/transformer.h.<layer_index>.mlp.c_fc/dprime/<category>.npy``.

    Args:
        checkpoint_name: Sub-directory of *root* holding the checkpoint's maps.
        layer_index: Layer number interpolated into the module directory name.
        category: File stem of the ``.npy`` file to load.
        root: Base directory. The default ``"assets"`` is resolved relative to
            the current working directory.

    Returns:
        A ``torch.Tensor`` holding a copy of the array stored in the file.
    """
    # Each path component is passed separately so os.path.join uses the
    # platform separator, rather than embedding a literal "/" in one part.
    filename = os.path.join(
        root,
        checkpoint_name,
        f"transformer.h.{layer_index}.mlp.c_fc",
        "dprime",
        f"{category}.npy",
    )
    return torch.tensor(np.load(filename))

def get_ssim(a, b, layer_index=10, *, checkpoint_name="baseline"):
    """Return the SSIM between the saved maps of two categories.

    Args:
        a: First category name; selects ``<a>.npy`` via ``load_map``.
        b: Second category name; selects ``<b>.npy`` via ``load_map``.
        layer_index: Layer whose maps are compared (default 10).
        checkpoint_name: Checkpoint directory to read from. The default is
            ``"baseline"``, the value this function previously hard-coded.

    Returns:
        The SSIM score as a Python number (via ``Tensor.item()``).
    """
    def _load_as_image(category):
        # Prepend batch and channel axes -> (1, 1, *map_shape) for
        # compute_ssim. Presumably each stored map is 2-D (H, W); confirm.
        return load_map(
            checkpoint_name=checkpoint_name,
            layer_index=layer_index,
            category=category,
        ).unsqueeze(0).unsqueeze(0)

    # Batch size is 1, so the per-image result holds a single value.
    return compute_ssim(_load_as_image(a), _load_as_image(b)).item()

# Category pairs to compare; each is scored with get_ssim's default layer
# and printed as "<first> | <second> | <score>".
comparisons = [
    ("science", "politics"),
    ("science", "history"),
    ("science", "technology"),
]

for first, second in comparisons:
    score = get_ssim(first, second)
    print(f"{first} | {second} | {score}")