
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.setting_attr import setattr_pytorch_model
from nesim.utils.json_stuff import load_json_as_dict
from nesim.losses.laplacian_pyramid.loss import LaplacianPyramidLoss
from nesim.experiments.resnet import create_model_and_scaler, create_val_loader
from nesim.eval.resnet import EvalSuite
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.folder import make_folder_if_does_not_exist

import torch
import os
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from wordnet_hierarchy import WordNetHierarchy
import argparse

# Loaded once at import time from a local JSON file; used to look up the
# ImageNet class names that sit under a given WordNet parent name.
wordnet_hierarchy = WordNetHierarchy(
    filename="wordnet_hierarchy_data.json",
)
 
def plot_and_save(x, model_name, layer_name, label_name, prefix, save_folder, identifier, no_pad=False, symmetric_colormap=False):
    """Render a 2-D array with the "RdBu_r" colormap and save it as a PNG.

    Args:
        x: 2-D array-like passed to ``plt.imshow``.
        model_name: Model name shown in the title.
        layer_name: Layer name shown in the title.
        label_name: Category name shown in the title (as "red: <label_name>").
        prefix: Filename prefix; the file is ``{prefix}_{identifier}.png``.
        save_folder: Directory the PNG is written into (not created here).
        identifier: Filename suffix.
        no_pad: If truthy, omit colorbar and title and save a tight,
            zero-padding, 300-dpi image. Otherwise add colorbar and title and
            save with matplotlib's default layout.
        symmetric_colormap: If truthy, use ``[-max|x|, max|x|]`` as colour
            limits so that zero maps to the middle of the colormap.
    """
    # Normalize once so the colorbar/title decision and the save layout can
    # never disagree for falsy non-bool values (e.g. None or 0).
    no_pad = bool(no_pad)

    fig = plt.figure()
    try:
        if symmetric_colormap:
            max_val = float(np.abs(np.asarray(x)).max())
            plt.imshow(x, cmap="RdBu_r", vmin=-max_val, vmax=max_val)
        else:
            plt.imshow(x, cmap="RdBu_r")

        if not no_pad:
            plt.colorbar()
            plt.title(f"{model_name}:{layer_name}\nred: {label_name}")

        plt.axis('off')
        filename = os.path.join(save_folder, f"{prefix}_{identifier}.png")

        if no_pad:
            fig.savefig(filename, bbox_inches='tight', dpi=300, pad_inches=0)
        else:
            fig.savefig(filename)
    finally:
        # Close exactly the figure created above, even if drawing/saving
        # raised, so repeated calls do not accumulate open figures.
        plt.close(fig)

# Command-line interface. The single required argument is the model name; it
# selects the checkpoint folder, the set of layers analysed, and the output
# directory under "results/".
parser = argparse.ArgumentParser(
    description=(
        "Plot per-channel selectivity and delta maps of a trained checkpoint "
        "for several WordNet parent categories."
    )
)
parser.add_argument('--model-name', type=str, help='name of model', required=True)

args = parser.parse_args()
model_name = args.model_name

# Class names; class ids are recovered later with .index(), so this is
# presumably a list ordered by ImageNet class id (TODO confirm).
imagenet_class_names = load_json_as_dict("imagenet_class_names.json")

# Validation loader over the FFCV-packed ImageNet val split, wrapped by the
# evaluation helper that collects activations and computes selectivity.
val_dataloader = create_val_loader(
    val_dataset="/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
    num_workers=32,
    batch_size=512,
    resolution=224,
    distributed=False,
    gpu=0,
)
eval_suite = EvalSuite(dataloader=val_dataloader)

# Named groups of layer names (e.g. "last_conv_layers_in_each_block").
layer_names = load_json_as_dict("../../../../training/imagenet/resnet18/layer_names.json")


epoch = 159  # which saved epoch of the checkpoint run to analyse

model, scaler = create_model_and_scaler(
    "resnet18",
    pretrained=False,
    distributed=False,
    use_blurpool=True,
    gpu=0,
)

# Checkpoints live at <checkpoints>/<model_name>/epoch_<epoch>.pt.
checkpoint_path = os.path.join(
    "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints",
    model_name,
    f"epoch_{epoch}.pt",
)
# Passed inline so the loaded state dict is not kept alive after copying.
model.load_state_dict(torch.load(checkpoint_path, weights_only=True))
model.eval()



# Choose which layers to analyse from the model-name prefix.
if model_name.startswith("all_topo"):
    # NOTE: these models would use layer_names["all_conv_layers_except_first"],
    # but that path is disabled until the referenced bug is fixed.
    raise NotImplementedError("There is a known bug for this. Check: XXXX")

# "end_topo" models and baselines are both analysed at the last conv layer of
# each block.
topo_layer_names = layer_names["last_conv_layers_in_each_block"]

# Collect the chosen layers' outputs (and the matching labels) over the
# validation loader; max_num_batches=None presumably means "all batches".
hook_outputs, labels = eval_suite.get_hook_outputs(
    model=model,
    layer_names=topo_layer_names,
    max_num_batches=None,
)

# WordNet parent names whose descendant ImageNet classes form each target
# category. The text before the first comma also names the output folder
# and files.
all_parent_names = [
    "dog, domestic dog, Canis familiaris",
    "bird",
    "reptile, reptilian",
    "aquatic vertebrate",
    "primate",
    "electrical device",
    "clothing, article of clothing, vesture, wear, wearable, habiliment",
    "building, edifice",
]

# For every analysed layer and every parent category: compute per-channel
# selectivity and delta maps, lay them out on a 2-D "cortical sheet", save
# PNG renderings, and store the selectivity map as .npy.
for layer_name in topo_layer_names:
    # The channel count (and hence the sheet layout) depends only on the
    # layer, so compute it once per layer rather than once per category.
    num_out_channels = get_module_by_name(
        module=model,
        name=layer_name,
    ).weight.shape[0]
    cortical_sheet_size = find_rectangle_dimensions(area=num_out_channels)
    sheet_shape = (cortical_sheet_size.height, cortical_sheet_size.width)

    for parent_name in all_parent_names:
        target_class_names = wordnet_hierarchy.get_class_names_from_parent_name(
            parent_name=parent_name
        )
        target_classes = [
            imagenet_class_names.index(target_class_name)
            for target_class_name in target_class_names
        ]
        # Set lookup keeps the complement linear in the 1000 classes.
        target_class_set = set(target_classes)
        other_classes = [x for x in range(1000) if x not in target_class_set]

        selectivity_values = np.array(
            eval_suite.compute_selectivity_all_channels(
                hook_outputs=hook_outputs,
                layer_name=layer_name,
                labels=labels,
                target_classes=target_classes,
                other_classes=other_classes,
                device="cuda:0",
            )
        ).reshape(sheet_shape)

        delta_values = np.array(
            eval_suite.compute_delta_all_channels(
                hook_outputs=hook_outputs,
                layer_name=layer_name,
                labels=labels,
                target_classes=target_classes,
                other_classes=other_classes,
                device="cuda:0",
            )
        ).reshape(sheet_shape)

        # e.g. "dog, domestic dog, Canis familiaris" -> "dog";
        # "aquatic vertebrate" -> "aquatic_vertebrate".
        identifier = parent_name.split(',')[0].replace(' ', '_')

        save_folder = os.path.join("results", model_name, layer_name, identifier)
        make_folder_if_does_not_exist(save_folder)

        # Each map is rendered twice: with colorbar/title, and as a
        # borderless image ("no_pad_" prefix). plot_and_save creates and
        # closes its own figure, so no figure is opened here.
        for prefix, values, no_pad in (
            ("selectivity", selectivity_values, False),
            ("no_pad_selectivity", selectivity_values, True),
            ("delta", delta_values, False),
            ("no_pad_delta", delta_values, True),
        ):
            plot_and_save(
                x=values,
                model_name=model_name,
                layer_name=layer_name,
                label_name=parent_name,
                prefix=prefix,
                save_folder=save_folder,
                identifier=identifier,
                no_pad=no_pad,
            )

        # Only the selectivity map is persisted as raw data.
        np.save(os.path.join(save_folder, f"{identifier}.npy"), selectivity_values)
        print("saved:", identifier)