from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
from nesim.experiments.resnet import create_model_and_scaler
from nesim.eval.resnet import EvalSuite
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.folder import make_folder_if_does_not_exist
from nesim.experiments.resnet import create_model_and_scaler, create_val_loader
from nesim.eval.resnet import load_resnet18_checkpoint
import torch
import os
import matplotlib.pyplot as plt
from utils import euclidean_distance_tensor, correlation_matrix
import argparse
from einops import rearrange, reduce
from scipy.ndimage.filters import gaussian_filter1d
import numpy as np
from tqdm import tqdm
from nesim.eval.nsd import NSDStimuli
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Everything below (model forward passes, correlation maths) runs on this GPU.
device = "cuda:0"

# ImageNet preprocessing applied to every NSD stimulus.
# Resize(224) with an int resizes the shorter side to 224 (torchvision convention).
_IMAGENET_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_CHANNEL_STD = [0.229, 0.224, 0.225]
imagenet_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(224),
    transforms.Normalize(mean=_IMAGENET_CHANNEL_MEAN, std=_IMAGENET_CHANNEL_STD),
])

# NSD stimulus images read from HDF5; shuffled batches of 32 feed the eval suite.
dataset = NSDStimuli(
    transform=imagenet_transforms,
    hdf5_filename="/research/datasets/nsd/nsd_stimuli.hdf5",
)
val_dataloader = DataLoader(batch_size=32, shuffle=True, dataset=dataset)

eval_suite = EvalSuite(dataloader=val_dataloader)

# Named groups of layer names (e.g. "all_conv_layers_except_first") for the ResNet-18 runs.
layer_names = load_json_as_dict("../../../../training/imagenet/resnet18/layer_names.json")

# Which family of topographic models to compare:
#   "end_topo" - topographic loss only on the last conv layer of each block
#   "all_topo" - topographic loss on every conv layer except the first
mode = "all_topo"

# Checkpoint names per mode. The scale formatting ("1" vs "1.0") is not uniform
# because these strings presumably must match checkpoint folder names exactly.
_MODEL_NAMES_BY_MODE = {
    "end_topo": [
        "pretrained",
        "baseline_scale_None_shrink_factor_3.0",
        "end_topo_scale_0.5_shrink_factor_3.0",
        "end_topo_scale_1.0_shrink_factor_3.0",
        "end_topo_scale_5.0_shrink_factor_3.0",
        "end_topo_scale_10.0_shrink_factor_3.0",
        "end_topo_scale_50.0_shrink_factor_3.0",
    ],
    "all_topo": [
        "pretrained",
        "baseline_scale_None_shrink_factor_3.0",
        "all_topo_scale_0.5_shrink_factor_3.0",
        "all_topo_scale_1_shrink_factor_3.0",
        "all_topo_scale_5_shrink_factor_3.0",
        "all_topo_scale_10.0_shrink_factor_3.0",
        "all_topo_scale_50.0_shrink_factor_3.0",
    ],
}

# Key into layer_names.json holding the layers that carry the topographic loss.
_LAYER_GROUP_BY_MODE = {
    "end_topo": "last_conv_layers_in_each_block",
    "all_topo": "all_conv_layers_except_first",
}

# Fail loudly instead of silently falling back to "all_topo" on a typo.
if mode not in _MODEL_NAMES_BY_MODE:
    raise ValueError(
        f"Unknown mode {mode!r}; expected one of {sorted(_MODEL_NAMES_BY_MODE)}"
    )

model_names = _MODEL_NAMES_BY_MODE[mode]
topo_layer_names = layer_names[_LAYER_GROUP_BY_MODE[mode]]

# Reseeded before every model's feature extraction: the dataloader shuffles and only
# the first 200 batches are used, so without this each model would be evaluated on a
# different random subset of stimuli and the columns would not be comparable.
STIMULUS_SEED = 0
# Set True to draw standard-error bars instead of a plain line.
SHOW_SEM = False


def _layer_names_for_model(model_name, names, mode):
    """Return the hook layer names to use for ``model_name`` as a new list.

    For the torchvision ``pretrained`` model in ``all_topo`` mode, a trailing
    ``".conv"`` is removed from each name (the shorter name presumably addresses the
    conv directly in torchvision's ResNet-18). Other models use ``names`` unchanged.
    ``names`` itself is never modified, so later models are unaffected.
    """
    suffix = ".conv"
    if model_name != "pretrained" or mode != "all_topo":
        return list(names)
    # Slice off the exact suffix (str.rstrip would strip a *set* of characters).
    return [n[: -len(suffix)] if n.endswith(suffix) else n for n in names]


def _mean_and_sem_by_distance(distances, values):
    """Group ``values`` by exact value of ``distances`` and summarise each group.

    Args:
        distances: 1-D tensor of pairwise distances.
        values: 1-D tensor with the same length and device as ``distances``.

    Returns:
        ``(unique_distances, means, sems)``: the sorted unique distances as a NumPy
        array, plus the per-distance mean and standard error of the mean
        (population std / sqrt(count)) as lists of Python floats.
    """
    unique_distances, inverse = torch.unique(
        distances, sorted=True, return_inverse=True
    )
    num_groups = unique_distances.numel()
    # Accumulate in float64 so large groups do not lose precision.
    values = values.to(torch.float64)
    counts = torch.bincount(inverse, minlength=num_groups).to(torch.float64)
    sums = torch.zeros(
        num_groups, dtype=torch.float64, device=values.device
    ).index_add_(0, inverse, values)
    squared_sums = torch.zeros_like(sums).index_add_(0, inverse, values * values)
    means = sums / counts
    variances = (squared_sums / counts - means * means).clamp_min(0.0)
    sems = variances.sqrt() / counts.sqrt()
    return unique_distances.cpu().numpy(), means.cpu().tolist(), sems.cpu().tolist()


fig, ax = plt.subplots(
    nrows=len(topo_layer_names),
    ncols=len(model_names),
    figsize=(2 * len(model_names), 2 * len(topo_layer_names)),
    squeeze=False,  # keep `ax` 2-D even with a single row or column
)

# plot_data[model_name][layer_name] -> mean correlation per unique distance.
plot_data = {}
for column, model_name in enumerate(model_names):
    plot_data[model_name] = {}

    if model_name == "pretrained":
        import torchvision.models as models

        model = models.resnet18(weights="DEFAULT")
        model = model.eval().to(device=device).half()
    else:
        model = load_resnet18_checkpoint(
            checkpoints_folder="/home/mdeb6/repos/nesim/training/imagenet/resnet18/checkpoints",
            model_name=model_name,
            epoch="final",
        )

    # Per-model copy: pretrained-specific renaming must not leak into later models.
    model_layer_names = _layer_names_for_model(model_name, topo_layer_names, mode)

    # Seed immediately before iterating the dataloader (model construction above may
    # itself consume RNG state) so every model sees the same stimulus subset.
    torch.manual_seed(STIMULUS_SEED)
    hook_outputs, _ = eval_suite.get_hook_outputs(
        model=model,
        layer_names=model_layer_names,
        max_num_batches=200,
        progress=True,
        spatial_pooling="max",
    )

    for row, layer_name in enumerate(tqdm(model_layer_names, desc=model_name)):
        features = hook_outputs[layer_name]

        # Unpooled conv maps: treat every spatial position as a sample.
        if features.ndim == 4:
            features = rearrange(features, "b c h w -> (b h w) c")
        # Always cast: the pretrained model runs in half precision, and all models
        # should be correlated with the same dtype on the same device.
        features = features.to(device).float()

        corr_matrix = correlation_matrix(tensor=features)
        # Lay the channels out on a rectangular grid to get pairwise distances.
        size = find_rectangle_dimensions(area=corr_matrix.shape[0])
        distance_matrix = euclidean_distance_tensor(
            height=size.height, width=size.width
        ).to(corr_matrix.device)

        distances = distance_matrix.reshape(-1)
        unique_distances, mean_correlations, sem_correlations = (
            _mean_and_sem_by_distance(distances, corr_matrix.reshape(-1))
        )
        plot_data[model_name][layer_name] = mean_correlations

        # Drop the first (smallest) distance, i.e. distance 0 / self-correlation.
        x = unique_distances[1:]
        y = mean_correlations[1:]

        axis = ax[row, column]
        if SHOW_SEM:
            axis.errorbar(x, y, yerr=sem_correlations[1:], capsize=3, alpha=0.9)
        else:
            axis.plot(x, y)

        axis.set_ylim(0, 1)
        axis.spines["top"].set_visible(False)
        axis.spines["right"].set_visible(False)

        # X-ticks at 0%, 50%, 100% of the maximum distance on the grid.
        fractions = [0, 0.5, 1]
        max_distance = distances.max().item()
        axis.set_xticks([max_distance * f for f in fractions])
        axis.set_xticklabels([f"{int(f * 100)}%" for f in fractions])

        axis.set_yticks([0, 1])
        axis.set_yticklabels([0, 1])

        if column == 0:
            axis.set_ylabel(layer_name)

        if row == 0:
            if model_name.startswith("baseline"):
                title = "baseline"
            elif model_name.startswith("pretrained"):
                title = "pretrained"
            else:
                # Names look like "<mode>_topo_scale_<tau>_shrink_factor_<s>",
                # so field 3 is the topographic loss scale tau.
                tau = float(model_name.split("_")[3])
                tau_label = int(tau) if tau.is_integer() else tau
                title = f"$\\tau = {tau_label}$"
            axis.set_title(title)

    # Release this model's weights and activations before loading the next one.
    del model, hook_outputs
    torch.cuda.empty_cache()

plt.tight_layout()
os.makedirs("assets", exist_ok=True)
for extension in ("png", "pdf"):
    filename = f"assets/{mode}.{extension}"
    fig.savefig(filename)
    print(f"Saved: {filename}")
plt.close(fig)


# Persist the per-model, per-layer mean-correlation curves next to the figures.
dict_to_json(filename="results.json", dictionary=plot_data)