
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.utils.json_stuff import load_json_as_dict
from nesim.losses.laplacian_pyramid.loss import LaplacianPyramidLoss
from nesim.experiments.resnet import create_model_and_scaler, create_val_loader
from nesim.eval.resnet import EvalSuite
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.folder import make_folder_if_does_not_exist
from nesim.eval.resnet import load_resnet18_checkpoint

import torch
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

import argparse
from torch.utils.data import DataLoader
from floc_dataset import FlocDataset
import scipy 
def plot_and_save(x, model_name, layer_name, label_name, prefix, save_folder, identifier, center_cmap_to_zero = True):
    """Render a 2-D map with the "RdBu_r" colormap and save it as a PNG.

    The image is written to ``<save_folder>/<prefix>_<identifier>.png``.

    Args:
        x: 2-D array-like of values to display.
        model_name: Model name shown in the title.
        layer_name: Layer name shown in the title.
        label_name: Class label shown in the title.
        prefix: Filename prefix (callers in this file use "selectivity"
            and "delta").
        save_folder: Directory to write the PNG into; it must already exist.
        identifier: Filename suffix distinguishing this plot.
        center_cmap_to_zero: If True, use symmetric colour limits
            ``[-m, m]`` (``m`` = largest finite ``|x|``) so that 0 sits at
            the middle of the diverging colormap.
    """
    data = np.asarray(x, dtype=float)
    fig, ax = plt.subplots()
    cmap = "RdBu_r"
    try:
        if center_cmap_to_zero:
            # Ignore NaN/inf when picking the limit; otherwise a single NaN
            # would turn vmin/vmax into NaN and break the colour scale.
            finite_abs = np.abs(data[np.isfinite(data)])
            max_val = float(finite_abs.max()) if finite_abs.size else 0.0
            im = ax.imshow(data, cmap=cmap, vmin=-max_val, vmax=max_val)
        else:
            im = ax.imshow(data, cmap=cmap)

        ax.set_title(f"{model_name}:{layer_name}\nred: {label_name}")
        ax.axis('off')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        filename = os.path.join(
            save_folder,
            f"{prefix}_{identifier}.png"
        )
        fig.savefig(filename)
    finally:
        # Always release this specific figure, even if saving fails, so that
        # repeated calls in a long loop do not accumulate open figures.
        plt.close(fig)

# Command-line options. --target-class-idx optionally selects a single entry
# of dataset.labels; when it is omitted the value stays None.
parser = argparse.ArgumentParser(description="Process an integer argument.")
parser.add_argument(
    "--target-class-idx",
    type=int,
    default=None,
    required=False,
    help="An integer number",
)
args = parser.parse_args()

# Evaluation data: a FlocDataset read from dataset/images, served through a
# DataLoader with default settings, and the EvalSuite that consumes it.
dataset = FlocDataset(folder="dataset/images")
val_dataloader = DataLoader(dataset=dataset)
eval_suite = EvalSuite(dataloader=val_dataloader)

# Checkpoint epoch tag passed to load_resnet18_checkpoint for every model.
epoch = "final"

# Checkpoints to analyse. The name prefix ("end_topo", "all_topo", or the
# baseline) is what later decides which group of layers is inspected.
model_names = [
    "end_topo_scale_1.0_shrink_factor_3.0",
    "end_topo_scale_5.0_shrink_factor_3.0",
    "end_topo_scale_10.0_shrink_factor_3.0",
    "end_topo_scale_50.0_shrink_factor_3.0",
    "baseline_scale_None_shrink_factor_3.0",
    "all_topo_scale_10.0_shrink_factor_3.0",
]

# Layer-group name -> list of module names. The path is relative, so it is
# presumably resolved against the current working directory.
layer_names = load_json_as_dict("../../../../training/imagenet/resnet18/layer_names.json")

# `import scipy` alone does not load the `io` submodule on SciPy releases
# without lazy submodule loading, so import it explicitly for savemat below.
import scipy.io

# Device string handed to the EvalSuite metric computations.
DEVICE = "cuda:0"
# Folder holding the trained ResNet-18 checkpoints listed in `model_names`.
CHECKPOINTS_FOLDER = "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints"

# Class names to analyse: every dataset label, or only the one selected with
# --target-class-idx. This must not be called `labels`: that name is rebound
# inside the loop to the labels returned by `get_hook_outputs`, which would
# silently discard the command-line selection.
if args.target_class_idx is None:
    target_label_names = dataset.labels
else:
    target_label_names = [dataset.labels[args.target_class_idx]]


def _layer_group_for(model_name):
    """Return the `layer_names` key of the layers analysed for `model_name`.

    "all_topo" models are inspected on every conv layer; "end_topo" models
    and the baseline are inspected on the last conv layer of each block.
    """
    if model_name.startswith("all_topo"):
        return "all_conv_layers"
    return "last_conv_layers_in_each_block"


def _get_num_out_channels(model, layer_name):
    """Return the output-channel count of the module named `layer_name`.

    Reads ``weight.shape[0]`` from the module, falling back to
    ``conv.weight.shape[0]`` when the module has no usable ``weight``.
    """
    module = get_module_by_name(module=model, name=layer_name)
    try:
        return module.weight.shape[0]
    except AttributeError:
        return module.conv.weight.shape[0]


for model_name in model_names:
    model = load_resnet18_checkpoint(
        checkpoints_folder=CHECKPOINTS_FOLDER,
        model_name=model_name,
        epoch=epoch,
    )
    model.eval()

    # NOTE(review): an earlier revision raised NotImplementedError for
    # "all_topo" models, citing a known bug. It also had a commented-out
    # filter that dropped layers whose names end in ".0". Double-check
    # all_topo results before relying on them.
    topo_layer_names = layer_names[_layer_group_for(model_name)]

    # `labels` is aligned with `hook_outputs` (presumably one class index
    # per sample; confirm against EvalSuite.get_hook_outputs).
    hook_outputs, labels = eval_suite.get_hook_outputs(
        model=model,
        layer_names=topo_layer_names,
        max_num_batches=None,
    )

    # 2-D "cortical sheet" grid used to lay out each layer's per-channel
    # values. It depends only on the layer, so compute it once per model.
    sheet_shapes = {}
    for layer_name in topo_layer_names:
        sheet = find_rectangle_dimensions(area=_get_num_out_channels(model, layer_name))
        sheet_shapes[layer_name] = (sheet.height, sheet.width)

    for label_name in target_label_names:
        # One-vs-rest comparison: the chosen class against all other classes.
        target_idx = dataset.labels.index(label_name)
        target_classes = [target_idx]
        other_classes = [i for i in range(len(dataset.labels)) if i != target_idx]
        identifier = f"{label_name}"

        for layer_name in topo_layer_names:
            metric_kwargs = dict(
                hook_outputs=hook_outputs,
                layer_name=layer_name,
                labels=labels,
                target_classes=target_classes,
                other_classes=other_classes,
                device=DEVICE,
            )
            selectivity_values = np.array(
                eval_suite.compute_selectivity_all_channels(**metric_kwargs)
            ).reshape(sheet_shapes[layer_name])
            delta_values = np.array(
                eval_suite.compute_delta_all_channels(**metric_kwargs)
            ).reshape(sheet_shapes[layer_name])

            save_folder = os.path.join("results", model_name, layer_name)
            make_folder_if_does_not_exist(save_folder)

            plot_and_save(
                x=selectivity_values,
                model_name=model_name,
                layer_name=layer_name,
                label_name=label_name,
                prefix="selectivity",
                save_folder=save_folder,
                identifier=identifier,
            )
            plot_and_save(
                x=delta_values,
                model_name=model_name,
                layer_name=layer_name,
                label_name=label_name,
                prefix="delta",
                save_folder=save_folder,
                identifier=identifier,
            )

            # Raw (unplotted) maps, stored together in one .mat file.
            scipy.io.savemat(
                os.path.join(save_folder, f"{identifier}.mat"),
                {"selectivity": selectivity_values, "delta": delta_values},
            )
            # Report only after the file has actually been written.
            print("saved:", identifier)