from nesim.utils.correlation import pearsonr
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM

# def z_score_normalize(tensor: torch.Tensor) -> torch.Tensor:
#     mean = tensor.mean()
#     std = tensor.std()
    
#     # Ensure standard deviation is not zero to avoid division by zero
#     if std == 0:
#         raise ValueError("Standard deviation is zero. Z-score normalization cannot be applied.")
    
#     normalized_tensor = (tensor - mean) / std
#     return normalized_tensor

# Checkpoint folder names (under "assets/") whose maps are analysed below.
# "baseline", "topo_1" and "topo_5" are currently disabled; add them back to
# this list to include those checkpoints in the comparison.
model_names = ["topo_10", "topo_50"]


def _min_max_normalize(tensor):
    """Linearly rescale ``tensor`` so its minimum maps to 0 and its maximum to 1.

    A constant tensor has zero range, and the plain formula would then divide
    0 by 0 and produce NaN everywhere. In that case an all-zero tensor of the
    same shape (floating-point dtype) is returned instead.
    """
    shifted = tensor - tensor.min()
    value_range = shifted.max()
    if value_range == 0:
        # Avoid 0/0 -> NaN; result_type keeps float inputs' dtype and promotes
        # integer inputs to the default float dtype, matching the division path.
        return torch.zeros_like(shifted, dtype=torch.result_type(shifted, 1.0))
    return shifted / value_range


def compute_ssim(tensor1, tensor2):
    """Return the SSIM between two tensors after min-max scaling each to [0, 1].

    Each input is rescaled independently using its own min and max, so the
    comparison ignores differences in absolute scale and offset. The scaled
    tensors are passed to ``pytorch_msssim.ssim`` with ``data_range=1``, which
    expects batched image tensors (presumably (N, C, H, W); the callers in this
    file add two leading singleton dimensions to each map).

    Args:
        tensor1: First tensor.
        tensor2: Second tensor; must have the same shape as ``tensor1``.

    Returns:
        The tensor returned by ``ssim(..., size_average=False)``, i.e.
        un-averaged (per-image) SSIM values.

    Raises:
        ValueError: If the two tensors do not have the same shape.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if tensor1.shape != tensor2.shape:
        raise ValueError("Tensors must have the same shape")

    tensor1_normalized = _min_max_normalize(tensor1)
    tensor2_normalized = _min_max_normalize(tensor2)

    return ssim(
        tensor1_normalized, tensor2_normalized, data_range=1, size_average=False
    )


def load_map(checkpoint_name, layer_index, category):
    """Load one saved ``.npy`` map and return it as a torch tensor.

    The file is read from
    ``assets/<checkpoint_name>/transformer.h.<layer_index>.mlp.c_fc/dprime/<category>.npy``,
    resolved relative to the current working directory.

    Args:
        checkpoint_name: Name of the checkpoint folder under ``assets``.
        layer_index: Layer number inserted into the ``transformer.h.<i>`` folder name.
        category: File stem of the map to load.

    Returns:
        A ``torch.Tensor`` holding a copy of the stored array.
    """
    layer_dir = os.path.join(
        "assets", checkpoint_name, f"transformer.h.{layer_index}.mlp.c_fc/dprime"
    )
    array = np.load(os.path.join(layer_dir, f"{category}.npy"))
    # torch.tensor copies the data, so the result does not alias `array`.
    return torch.tensor(array)


def get_ssim(a, b, checkpoint_name, layer_index=10):
    """Return the SSIM between the maps of categories ``a`` and ``b``, rounded to 4 places.

    Both maps are loaded with ``load_map`` from the same checkpoint and layer.
    Two leading singleton dimensions are prepended to each one before
    ``compute_ssim`` is called; this presumably turns a 2-D (H, W) map into a
    (1, 1, H, W) image batch — the stored map rank is not checked here.
    """

    def as_image(category):
        # Indexing with None twice == .unsqueeze(0).unsqueeze(0)
        dprime_map = load_map(
            checkpoint_name=checkpoint_name, layer_index=layer_index, category=category
        )
        return dprime_map[None, None]

    first = as_image(a)
    second = as_image(b)
    score = compute_ssim(first, second)
    return round(score.item(), 4)


def get_pearsonr(a, b, checkpoint_name, layer_index=10):
    """Return the Pearson correlation between the maps of ``a`` and ``b``, rounded to 4 places.

    Each map is loaded with ``load_map`` from the same checkpoint and layer,
    flattened, and given a leading singleton dimension, i.e. shape (1, N),
    before being handed to ``pearsonr``. The value is read with ``.item()``,
    so ``pearsonr`` is presumably returning a single-element tensor here.
    """
    rows = []
    for category in (a, b):
        dprime_map = load_map(
            checkpoint_name=checkpoint_name, layer_index=layer_index, category=category
        )
        # reshape(1, -1) == reshape(-1).unsqueeze(0): one row of all elements
        rows.append(dprime_map.reshape(1, -1))

    first_row, second_row = rows
    correlation = pearsonr(first_row, second_row)
    return round(correlation.item(), 4)


# Category pairs compared at every checkpoint in `model_names` and every layer below.
comparisons = [
    ("science", "politics"),
    ("science", "history"),
    ("politics", "history"),
    ("science", "technology"),
]

# Run the analysis only when executed as a script, so importing this module
# (e.g. to reuse get_ssim / get_pearsonr) does not read every map from disk
# and print the full report as a side effect.
if __name__ == "__main__":
    layer_indices = [8, 9, 10, 11]
    for checkpoint_name in model_names:
        for layer_index in layer_indices:
            for a, b in comparisons:
                p = get_pearsonr(
                    a, b, layer_index=layer_index, checkpoint_name=checkpoint_name
                )
                s = get_ssim(
                    a, b, layer_index=layer_index, checkpoint_name=checkpoint_name
                )
                print(
                    f"[model: {checkpoint_name} layer: {layer_index}] {a} | {b} | ssim: {s} pearson: {p}"
                )
