import argparse
import os

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io
import torch
from torch.utils.data import DataLoader

from nesim.eval.resnet import EvalSuite
from nesim.eval.resnet import load_resnet18_checkpoint
from nesim.experiments.resnet import create_model_and_scaler
from nesim.experiments.size_animacy import BigSmallDataset
from nesim.utils.folder import make_folder_if_does_not_exist
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.json_stuff import load_json_as_dict

def plot_and_save(x, model_name, layer_name, label_name, prefix, save_folder, identifier, center_cmap_to_zero = True):
    """Draw ``x`` as a heatmap with a colorbar and save it as a PNG.

    The image is rendered with the diverging ``RdBu_r`` colormap, the axes are
    hidden, and the result is written to
    ``<save_folder>/<prefix>_<identifier>.png``.

    Args:
        x: 2D array-like passed to ``Axes.imshow`` (e.g. a NumPy array).
        model_name: Unused by the current implementation; kept so existing
            callers keep working.
        layer_name: Unused by the current implementation (see ``model_name``).
        label_name: Unused by the current implementation (see ``model_name``).
        prefix: First part of the output file name.
        save_folder: Directory the PNG is written to; it must already exist.
        identifier: Second part of the output file name.
        center_cmap_to_zero: When true, the colour limits are set to
            ``[-max|x|, +max|x|]`` so that zero sits at the middle of the
            colormap. When false, matplotlib chooses the limits.
    """
    fig, ax = plt.subplots()
    try:
        imshow_kwargs = {"cmap": "RdBu_r"}
        if center_cmap_to_zero:
            # as_tensor avoids the copy (and the warning) that torch.tensor
            # triggers when x is already a tensor; float() hands matplotlib a
            # plain Python number instead of a 0-d tensor.
            max_val = float(torch.as_tensor(x).abs().max())
            imshow_kwargs["vmin"] = -max_val
            imshow_kwargs["vmax"] = max_val
        im = ax.imshow(x, **imshow_kwargs)

        # Link the colorbar to the image axes.
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        # Operate on this figure's axes rather than pyplot's "current" axes.
        ax.axis('off')

        filename = os.path.join(
            save_folder,
            f"{prefix}_{identifier}.png"
        )

        # bbox_inches='tight' trims the surrounding whitespace.
        fig.savefig(filename, bbox_inches='tight')
    finally:
        # Close this specific figure, even if drawing or saving failed, so
        # figures do not accumulate over the many calls made by the script.
        plt.close(fig)

# Dataset whose labels drive the per-label maps computed below.
dataset = BigSmallDataset(
    folder = './dataset/BigSmallObjects',
)
# Default DataLoader settings are used (no explicit batch size or shuffling).
val_dataloader = DataLoader(dataset=dataset)

# Helper object used below to collect layer outputs and compute
# selectivity / t-values / delta per channel.
eval_suite = EvalSuite(
    dataloader=val_dataloader,
)

# JSON mapping of layer-group name -> list of layer names; the keys
# "last_conv_layers_in_each_block" and "all_conv_layers" are read below.
layer_names = load_json_as_dict(
     "../../../../training/imagenet/resnet18/layer_names.json"
)

# Checkpoints to evaluate. The name prefix ("end_topo", "all_topo",
# "baseline") selects which layer group is analysed in the main loop.
model_names = [
    ## end_topo
    "end_topo_scale_1.0_shrink_factor_3.0",
    "end_topo_scale_5.0_shrink_factor_3.0",
    "end_topo_scale_10.0_shrink_factor_3.0",
    "end_topo_scale_50.0_shrink_factor_3.0",
    ## all_topo
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_5_shrink_factor_3.0",
    "all_topo_scale_10.0_shrink_factor_3.0",
    "all_topo_scale_50.0_shrink_factor_3.0",
    ## baseline
    "baseline_scale_None_shrink_factor_3.0",
]
# Checkpoint epoch identifier passed to load_resnet18_checkpoint.
epoch = "final"

# For every model, label and layer: compute per-channel selectivity, t-values
# and delta, lay them out on a 2D grid, save one PNG per metric and one .mat
# file with all three under results/<model_name>/<layer_name>/.
for model_name in model_names:
    model = load_resnet18_checkpoint(
        checkpoints_folder="/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints",
        model_name=model_name,
        epoch=epoch
    )
    model.eval()

    # "all_topo" models are analysed on every conv layer; "end_topo" and
    # baseline models on the last conv layer of each block.
    if model_name.startswith("all_topo"):
        topo_layer_names = layer_names["all_conv_layers"]
    else:
        topo_layer_names = layer_names["last_conv_layers_in_each_block"]

    # Collected once per model and reused for every label/layer below.
    hook_outputs, labels = eval_suite.get_hook_outputs(
        model=model,
        layer_names=topo_layer_names,
        max_num_batches=None,
    )

    for label_name in dataset.label_names:
        # These depend only on the label, so compute them once here instead
        # of three times per layer.
        target_index = dataset.label_names.index(label_name)
        target_classes = [target_index]
        # NOTE(review): the range is taken over `dataset.labels`, not
        # `dataset.label_names` -- confirm this is the intended class count.
        other_classes = [x for x in range(len(dataset.labels)) if x != target_index]
        identifier = f"{label_name}"

        for layer_name in topo_layer_names:
            layer = get_module_by_name(
                module=model,
                name= layer_name
            )
            try:
                num_out_channels = layer.weight.shape[0]
            except AttributeError:
                # Some layers expose their weights through a `.conv` attribute.
                num_out_channels = layer.conv.weight.shape[0]
            # One grid cell per output channel.
            cortical_sheet_size = find_rectangle_dimensions(area = num_out_channels)
            sheet_shape = (cortical_sheet_size.height, cortical_sheet_size.width)

            save_folder = os.path.join(
                "results",
                model_name,
                layer_name,
            )
            make_folder_if_does_not_exist(save_folder)

            # Arguments shared by the three per-channel metric calls.
            contrast_kwargs = dict(
                hook_outputs=hook_outputs,
                layer_name=layer_name,
                labels=labels,
                target_classes=target_classes,
                other_classes=other_classes,
                device = "cuda:0"
            )
            # Arguments shared by the three plot calls.
            plot_kwargs = dict(
                model_name=model_name,
                layer_name=layer_name,
                label_name=label_name,
                save_folder=save_folder,
                identifier=identifier
            )

            selectivity_values = eval_suite.compute_selectivity_all_channels(**contrast_kwargs)
            # p-values are returned but not used by this script.
            tvals, pvals = eval_suite.compute_tvals(**contrast_kwargs)

            selectivity_values = np.array(selectivity_values).reshape(sheet_shape)
            plot_and_save(x = selectivity_values, prefix = "selectivity", **plot_kwargs)

            tvals = tvals.reshape(sheet_shape)
            plot_and_save(x = tvals, prefix = "tval", **plot_kwargs)

            delta_values = eval_suite.compute_delta_all_channels(**contrast_kwargs)
            delta_values = np.array(delta_values).reshape(sheet_shape)
            plot_and_save(x = delta_values, prefix = "delta", **plot_kwargs)

            mat_filename = os.path.join(
                save_folder,
                f"{identifier}.mat"
            )
            # Needs scipy.io imported explicitly: a bare `import scipy` does
            # not load the `io` submodule on every SciPy version.
            scipy.io.savemat(
                mat_filename,
                {"selectivity": selectivity_values, "delta": delta_values, 'tvals': tvals}
            )
            print("saved:", mat_filename)
