import torch
import os
from tqdm import tqdm
from nesim.utils.tensor_size import get_tensor_size_string
from einops import rearrange
from nesim.utils.correlation import pearsonr
from torchtyping import TensorType
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing

# Presumably applies the project's shared matplotlib styling before any figure
# is created -- see nesim.utils.figure.figure_1 to confirm.
apply_ratan_matplotlib_thing()

# Font size shared by titles, axis labels and tick labels.
fontsize = 13
# NOTE(review): global_step and checkpoints_dir are not referenced in the
# visible part of this script; kept so any external reader keeps working.
global_step = 10500
# Fall back to the CPU when CUDA is unavailable so the script still runs
# (slowly) on machines without a GPU instead of failing on `.to(device)`.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
checkpoints_dir = "/home/penfield/repos/nesim/training/gpt_neo_125m/checkpoints"
# Root folder with one sub-folder of saved hook outputs per checkpoint name.
hook_outputs_folder = "../../experiments/gpt_neo_125m/effective_dimensionality/hook_outputs"
# Transformer block whose MLP input projection is analysed.
layer_index = 11
topo_layer_name = f"transformer.h.{layer_index}.mlp.c_fc"

# Checkpoints to plot, in left-to-right panel order.
checkpoint_names = [
    "untrained",
    "baseline",
    "topo_1",
    "topo_5",
    "topo_10",
    "topo_50",
]

# Panel title for each checkpoint name (mathtext for the tau values).
labels = {
    "untrained": "Untrained",
    "baseline": "Baseline",
    "topo_1": "$\\tau = 1$",
    "topo_5": "$\\tau = 5$",
    "topo_10": "$\\tau = 10$",
    "topo_50": "$\\tau = 50$",
}

import torch

def euclidean_distance_tensor(height, width):
    """Return all pairwise Euclidean distances between cells of a grid.

    The grid has ``height`` rows and ``width`` columns; cells are enumerated
    in row-major order, so cell ``(r, c)`` has flat index ``r * width + c``.

    Args:
        height: Number of grid rows.
        width: Number of grid columns.

    Returns:
        A ``(height * width, height * width)`` floating-point tensor whose
        entry ``[i, j]`` is the distance between cell ``i`` and cell ``j``.
    """
    rows, cols = torch.meshgrid(
        torch.arange(height), torch.arange(width), indexing="ij"
    )
    # One (row, col) integer coordinate per grid cell, in row-major order.
    coords = torch.stack((rows.reshape(-1), cols.reshape(-1)), dim=1)

    # Broadcast to every ordered pair of cells: shape (N, N, 2).
    offsets = coords.unsqueeze(1) - coords.unsqueeze(0)
    squared_distance = (offsets * offsets).sum(dim=-1)

    # sqrt promotes the integer squared distances to floating point.
    return squared_distance.sqrt()


def correlation_matrix(tensor: TensorType["A", "B"]) -> TensorType["B", "B"]:
    """Compute the Pearson correlation between every pair of columns.

    Args:
        tensor: 2-D tensor with one observation per row and one variable
            per column.

    Returns:
        A square tensor whose entry ``[i, j]`` is the correlation between
        column ``i`` and column ``j``. Columns with zero variance produce
        non-finite entries because of the division by their standard
        deviation.
    """
    num_observations = tensor.size(0)

    # Centre each column on its own mean.
    centered = tensor - tensor.mean(dim=0, keepdim=True)

    # Sample covariance (normalised by N - 1).
    covariance = centered.T.mm(centered) / (num_observations - 1)

    # Per-column standard deviation is the square root of the diagonal.
    column_std = covariance.diagonal().sqrt()

    # Divide each covariance by the product of the two columns' std-devs.
    normaliser = column_std.unsqueeze(1) * column_std.unsqueeze(0)
    return covariance / normaliser

def get_random_subset(tensor, fraction: float):
    """Return a random selection of entries along the first dimension.

    Args:
        tensor: Tensor to sample from; sampling is over dim 0.
        fraction: Share of entries to keep, in ``[0, 1]``. The number kept
            is ``int(fraction * tensor.size(0))`` (rounded down).

    Returns:
        A tensor holding the selected entries in random order, drawn
        without replacement.

    Raises:
        ValueError: If ``fraction`` is not within ``[0, 1]``.
    """
    fraction_is_valid = 0 <= fraction <= 1
    if not fraction_is_valid:
        raise ValueError("Fraction must be between 0 and 1.")

    total = tensor.size(0)
    keep = int(fraction * total)

    # Shuffle all positions, then keep only the leading `keep` of them.
    shuffled_positions = torch.randperm(total)
    chosen = shuffled_positions[:keep]
    return tensor[chosen]

def _load_layer_activations(checkpoint_name, max_files=5):
    """Load the hooked outputs of ``topo_layer_name`` for one checkpoint.

    Reads at most ``max_files`` files named ``0.pth``, ``1.pth``, ... from
    ``<hook_outputs_folder>/<checkpoint_name>``, concatenates the layer's
    tensors along dim 0 and merges the batch and sequence axes.

    Args:
        checkpoint_name: Sub-folder of ``hook_outputs_folder`` to read from.
        max_files: Upper bound on the number of hook-output files loaded.

    Returns:
        A ``(batch * sequence, hidden)`` tensor placed on ``device``.
    """
    checkpoint_folder = os.path.join(hook_outputs_folder, checkpoint_name)
    # Files are addressed by index, so the directory size only bounds the count.
    num_files = min(max_files, len(os.listdir(checkpoint_folder)))
    filenames = [
        os.path.join(checkpoint_folder, f"{dataset_idx}.pth")
        for dataset_idx in range(num_files)
    ]
    loaded_hook_output_files = [
        torch.load(f, map_location="cpu", weights_only=True)
        for f in tqdm(filenames, desc="loading hook outputs")
    ]

    # Keep only the layer of interest from every loaded file.
    layer_outputs = [
        loaded_hook_output[topo_layer_name].cpu()
        for loaded_hook_output in tqdm(
            loaded_hook_output_files,
            desc=f"Computing stuff for layer: {topo_layer_name}",
        )
    ]

    print("Concatenating all outputs...")
    activations = torch.cat(layer_outputs, dim=0)
    print(f"Tensor size: {get_tensor_size_string(activations)}")

    return rearrange(
        activations,
        "batch sequence hidden -> (batch sequence) hidden"
    ).to(device=device)


def _lower_triangle_pairs(distance_matrix, corr_matrix, max_pairs):
    """Return distances and correlations of unit pairs below the diagonal.

    Args:
        distance_matrix: Square tensor of pairwise grid distances.
        corr_matrix: Square tensor of pairwise correlations, same shape.
        max_pairs: Maximum number of pairs to return.

    Returns:
        A ``(distances, correlations)`` tuple of 1-D CPU tensors with
        matching entries.
    """
    # Build the strict-lower-triangle mask once: each pair is counted a
    # single time and the diagonal (self-correlation) is excluded.
    below_diagonal = torch.tril(torch.ones_like(corr_matrix), diagonal=-1).bool()

    # NOTE(review): this keeps the FIRST max_pairs pairs in row-major order,
    # not a random sample, so only pairs among low-index units are returned.
    # Confirm this is intended before swapping in a paired random subset.
    pair_distances = distance_matrix[below_diagonal][:max_pairs].cpu()
    pair_correlations = corr_matrix[below_diagonal][:max_pairs].cpu()
    return pair_distances, pair_correlations


# Maximum number of unit pairs drawn in each scatter panel.
slice_size = 100_000

# squeeze=False always yields a 2-D array of axes, so indexing by column
# also works when only a single checkpoint is plotted.
fig, ax = plt.subplots(
    nrows=1, ncols=len(checkpoint_names), figsize=(13, 3), squeeze=False
)
ax = ax[0]

for column, checkpoint_name in enumerate(tqdm(checkpoint_names)):
    panel = ax[column]
    panel.set_title(labels[checkpoint_name], fontsize=fontsize)

    all_outputs_for_single_layer = _load_layer_activations(checkpoint_name)
    corr_matrix = correlation_matrix(tensor=all_outputs_for_single_layer)

    # Lay the layer's units out on a 2-D grid and measure grid distances.
    size = find_rectangle_dimensions(area=corr_matrix.shape[0])
    distance_matrix = euclidean_distance_tensor(
        height=size.height,
        width=size.width
    ).to(corr_matrix.device)

    x, y = _lower_triangle_pairs(
        distance_matrix=distance_matrix,
        corr_matrix=corr_matrix,
        max_pairs=slice_size,
    )

    panel.scatter(
        x=x,
        y=y,
        alpha=0.01
    )
    panel.set_ylim(-1, 1)

    # Mean correlation at every distance present in the scatter, averaged
    # over the full matrices and then smoothed along the distance axis.
    unique_distances = np.unique(x.numpy())
    moving_average = [
        corr_matrix[distance_matrix == distance].mean().item()
        for distance in unique_distances.tolist()
    ]
    moving_average = gaussian_filter1d(
        moving_average, sigma=2
    )
    panel.plot(unique_distances, moving_average, c="red", linestyle="--")

    panel.set_xticks(
        ticks=[0, 60],
        labels=[0, 60],
        fontsize=fontsize
    )
    panel.set_yticks(
        ticks=[-1, 0, 1],
        labels=[-1, 0, 1],
        fontsize=fontsize
    )
    # Only the left-most panel carries the y-axis label.
    if column == 0:
        panel.set_ylabel("Correlation", fontsize=fontsize)

    # Same font size as the y label and the tick labels.
    panel.set_xlabel("Distance", fontsize=fontsize)

    # Remove the top and right spines
    panel.spines['top'].set_visible(False)
    panel.spines['right'].set_visible(False)


filename = "single_layer.pdf"
fig.savefig(filename)
print(f"Saved: {filename}")