import torch
import os
from tqdm import tqdm
from nesim.utils.tensor_size import get_tensor_size_string
from einops import rearrange
from nesim.utils.correlation import pearsonr
from torchtyping import TensorType
import matplotlib.pyplot as plt
import numpy as np
from scipy.ndimage.filters import gaussian_filter1d
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing

# Project-level matplotlib setup hook, called before any figure is created.
# NOTE(review): presumably configures global matplotlib styling; its behavior
# is defined in nesim.utils.figure.figure_1 — confirm there.
apply_ratan_matplotlib_thing()

# Font size applied to the x/y tick labels of every subplot.
fontsize = 6
# NOTE(review): `global_step`, `device` and `checkpoints_dir` are not referenced
# anywhere in the visible part of this script (all tensors are loaded onto and
# processed on the CPU).
global_step = 10500
device = "cuda:0"
checkpoints_dir = "/home/penfield/repos/nesim/training/gpt_neo_125m/checkpoints"
# Root folder of saved hook outputs; the script reads
# `<hook_outputs_folder>/<checkpoint_name>/<idx>.pth`.
hook_outputs_folder = "/research/XXXX-1/toponets_hook_outputs_gpt_neo_125m"

# Layer names used as keys into each loaded hook-output dict; one subplot row
# is drawn per layer.
topo_layer_names = [f"transformer.h.{i}.mlp.c_fc" for i in range(12)]

# Sub-folder names under `hook_outputs_folder`; one subplot column is drawn per
# checkpoint, titled with this name.
checkpoint_names = [
    "untrained",
    "baseline",
    "topo_1",
    "topo_5",
    "topo_10",
    "topo_50",
]

import torch

def euclidean_distance_tensor(height, width):
    """Return all pairwise Euclidean distances between cells of a 2-D grid.

    The grid has ``height`` rows and ``width`` columns with integer
    coordinates. Cells are enumerated row by row, so cell ``(r, c)`` has flat
    index ``r * width + c``.

    Returns:
        A ``(height * width, height * width)`` tensor whose ``[i, j]`` entry is
        the distance between flat cells ``i`` and ``j``.
    """
    row_coords, col_coords = torch.meshgrid(
        torch.arange(height), torch.arange(width), indexing='ij'
    )

    # One (row, col) coordinate pair per grid cell, in row-major order.
    coords = torch.stack((row_coords.reshape(-1), col_coords.reshape(-1)), dim=1)

    # Broadcast to every (cell_i, cell_j) coordinate difference.
    deltas = coords.unsqueeze(1) - coords.unsqueeze(0)

    return deltas.pow(2).sum(dim=-1).sqrt()


def correlation_matrix(tensor: TensorType["A", "B"]) -> TensorType["B", "B"]:
    """Pearson correlation between every pair of columns of ``tensor``.

    Rows are treated as observations and columns as variables. The covariance
    uses the ``rows - 1`` normalisation. A column with zero variance yields
    NaN in its row and column of the result (division by a zero standard
    deviation).
    """
    n_observations = tensor.size(0)

    # Centre each column on its own mean.
    centered = tensor - tensor.mean(dim=0, keepdim=True)

    # Column-by-column covariance.
    covariance = (centered.T @ centered) / (n_observations - 1)

    # Per-column standard deviation is the square root of the diagonal.
    sigma = covariance.diagonal().sqrt()

    # Normalise each covariance by the product of the two standard deviations.
    return covariance / (sigma.unsqueeze(1) * sigma.unsqueeze(0))

def get_random_subset(tensor, fraction: float):
    """Pick a random subset of ``tensor`` along its first dimension.

    ``int(fraction * tensor.size(0))`` entries are drawn without replacement,
    in random order.

    Raises:
        ValueError: if ``fraction`` is not within ``[0, 1]``.
    """
    fraction_is_valid = 0 <= fraction <= 1
    if not fraction_is_valid:
        raise ValueError("Fraction must be between 0 and 1.")

    total = tensor.size(0)
    keep = int(fraction * total)

    # A random permutation of all indices, truncated to the requested count.
    shuffled = torch.randperm(total)
    chosen = shuffled[:keep]

    return tensor[chosen]

def pouya_smoothness_metric(correlations: np.array):
    """Return the spread (largest minus smallest value) of ``correlations``.

    Accepts any array-like input (the script passes a plain list) and returns
    a built-in ``float``.
    """
    values = np.asarray(correlations)
    highest = values.max()
    lowest = values.min()
    return float(highest - lowest)

# ---------------------------------------------------------------------------
# Main script: for every (checkpoint, layer) pair, plot the correlation between
# pairs of hidden units against the distance between those units when the
# layer's units are laid out on a 2-D grid, and record a smoothness score.
#
# Outputs:
#   assets/supplementary.png / .pdf  - grid of scatter plots (rows = layers,
#                                      columns = checkpoints)
#   smoothness_data.json             - smoothness score per checkpoint/layer
#   result.json                      - unsmoothed mean correlation per distance
# ---------------------------------------------------------------------------

# Imported before the long computation so a broken import fails immediately
# instead of after every checkpoint has been processed.
from nesim.utils.json_stuff import dict_to_json

# At most this many `<idx>.pth` hook-output files are read per checkpoint.
max_hook_files = 10
# Number of unit pairs drawn in each scatter plot.
slice_size = 100_000

# squeeze=False keeps `ax` two-dimensional even if there is only a single
# layer or a single checkpoint, so `ax[row, column]` is always valid.
fig, ax = plt.subplots(
    nrows=len(topo_layer_names),
    ncols=len(checkpoint_names),
    figsize=(12.27, 17),
    squeeze=False,
)
# Only the bottom row of subplots gets an x-axis label.
last_row = len(topo_layer_names) - 1

smoothness_data = {}
plot_data = {}

for column, checkpoint_name in enumerate(tqdm(checkpoint_names)):

    hook_outputs_folder_single_checkpoint = os.path.join(
        hook_outputs_folder,
        checkpoint_name
    )

    # Files are expected to be named 0.pth, 1.pth, ...; read the first few.
    num_files = min(
        len(os.listdir(hook_outputs_folder_single_checkpoint)),
        max_hook_files,
    )
    filenames = [
        os.path.join(hook_outputs_folder_single_checkpoint, f"{dataset_idx}.pth")
        for dataset_idx in range(num_files)
    ]
    # Each loaded object is indexed by layer name below.
    loaded_hook_output_files = [
        torch.load(f, map_location="cpu")
        for f in tqdm(filenames, desc = "loading hook outputs")
    ]

    smoothness_data[checkpoint_name] = {}
    plot_data[checkpoint_name] = {}

    for row, layer_name in enumerate(topo_layer_names):
        panel = ax[row, column]

        if row == 0:
            panel.set_title(checkpoint_name, fontsize=16)

        # Collect this layer's activations from every file, flattening the
        # batch and sequence axes into one "sample" axis.
        all_outputs_for_single_layer = []

        for loaded_hook_output in tqdm(loaded_hook_output_files, desc=f"Computing stuff for layer: {layer_name}"):
            tensor = loaded_hook_output[layer_name].cpu()
            tensor = rearrange(
                tensor=tensor,
                pattern = "batch sequence hidden -> (batch sequence) hidden"
            )
            all_outputs_for_single_layer.append(tensor)

        print(f"Concatenating all outputs...")
        all_outputs_for_single_layer = torch.cat(all_outputs_for_single_layer, dim=0)
        print(f"Tensor size: {get_tensor_size_string(all_outputs_for_single_layer)}")

        # hidden x hidden correlation between units, and the matching
        # hidden x hidden matrix of grid distances between the same units.
        corr_matrix = correlation_matrix(tensor = all_outputs_for_single_layer)
        # NOTE(review): presumably height * width == number of hidden units;
        # confirm against find_rectangle_dimensions.
        size = find_rectangle_dimensions(area = corr_matrix.shape[0])
        distance_matrix = euclidean_distance_tensor(
            height=size.height,
            width=size.width
        ).to(corr_matrix.device)

        # Strict lower triangle: every unordered unit pair exactly once, with
        # the diagonal (a unit paired with itself) excluded. Built once and
        # reused for both matrices.
        lower_triangle = torch.tril(torch.ones_like(corr_matrix), diagonal=-1).bool()
        distances = distance_matrix[lower_triangle]
        all_correlation_values = corr_matrix[lower_triangle]

        # NOTE(review): boolean-mask indexing returns elements in row-major
        # order, so this prefix keeps the pairs with the lowest row indices
        # rather than a random sample of pairs. The set of distances used for
        # the per-distance mean below is taken from this same prefix.
        x = distances[:slice_size].cpu()
        y = all_correlation_values[:slice_size].cpu()

        panel.scatter(
            x = x,
            y = y,
            alpha = 0.01/4
        )
        panel.set_ylim(-1,1)

        # For each distinct distance among the plotted pairs, average the
        # correlation over every entry of the full matrix at exactly that
        # distance. This is one full-matrix comparison per distance.
        unique_distances = np.unique(np.array(x))
        moving_average = [
            corr_matrix[distance_matrix == distance].mean().item()
            for distance in unique_distances.tolist()
        ]
        smoothness_data[checkpoint_name][layer_name] = pouya_smoothness_metric(
            correlations=moving_average
        )
        # The unsmoothed per-distance means are what gets written to JSON.
        plot_data[checkpoint_name][layer_name] = moving_average

        # Smooth only the curve that is drawn on top of the scatter.
        smoothed_average = gaussian_filter1d(
            moving_average, sigma=2
        )
        panel.plot(unique_distances, smoothed_average, c = "red", linestyle = "--")
        panel.set_xticks(
            ticks = [0, 60],
            labels = [0,60],
            fontsize=fontsize
        )
        panel.set_yticks(
            ticks = [-1, 0 ,1],
            labels = [-1,0,1],
            fontsize=fontsize
        )
        if column == 0:
            panel.set_ylabel("Correlation")

        if row == last_row:
            panel.set_xlabel("Distance")

        # Remove the top and right spines
        panel.spines['top'].set_visible(False)
        panel.spines['right'].set_visible(False)


figure_filenames = ["assets/supplementary.png", "assets/supplementary.pdf"]
for filename in figure_filenames:
    # Create the output folder if needed so the save cannot fail on a missing
    # directory after all the work above.
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    fig.savefig(filename)
    print(f"Saved: {filename}")
plt.close(fig)

dict_to_json(smoothness_data, "smoothness_data.json")

dict_to_json(
    plot_data,
    "result.json"
)