import torchvision.models as models
import torch
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.hook import ForwardHook
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np

## local imports from utils.py
from utils import (
    CuratedDataset,
    CategorySelectivityMapExperiment,
    find_most_activating_label_avg,
)


# checkpoints_root =  "/research/XXXX-1/nesim/training/imagenet/resnet18/checkpoints/imagenet/" ## barlow
# checkpoints_root = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/imagenet/resnet18/checkpoints/imagenet"  ## openmind
checkpoints_root = (
    "/research/XXXX-3/repos/nesim/training/imagenet/resnet18/checkpoints/imagenet"
)

# Single device setting for the whole script.
device = "cuda:0"
# layer_names_filename = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/imagenet/resnet18/possible_nesim_layers.json" ##openmind
# NOTE(review): this layer-name file lives under cifar100_phasewise while the
# checkpoints are ImageNet ResNet-18 runs — confirm the layer names match.
layer_names_filename = "/research/XXXX-3/repos/nesim/training/cifar100_phasewise/possible_nesim_layers.json"

# model = models.resnet18(weights=None)

# run_name = "bs_1024_original_resnet_paper_replication_shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_30_steps_apply_sorted_weights_init_filename_None"
# our_checkpoint_filename = os.path.join(
#     checkpoints_root,
#     run_name,
#     "best/best_model.ckpt",
# )


# Untrained ResNet-18 skeleton; weights come from the checkpoint loaded by the
# experiment below.
model = models.resnet18(weights=None)
# layer_names_filename = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/imagenet/resnet50/possible_nesim_layers.json"
# checkpoints_root = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/imagenet/resnet50/checkpoints/imagenet"

# Checkpoint analysed by this script. Per its run-directory name, the nesim
# loss was applied to layer3 only.
our_checkpoint_filename = os.path.join(
    checkpoints_root,
    "shrink_factor_[5.0]_loss_scale_150_layers_layer3__bimt_scale_None_from_pretrained_False_apply_every_20_steps",
    "best/best_model.ckpt",
)

# Alternative checkpoint. Per its run-directory name, the nesim loss was
# applied to all conv layers. It is not used below; assign it to
# `our_checkpoint_filename` to analyse that run instead.
all_conv_layers_checkpoint_filename = os.path.join(
    checkpoints_root,
    "shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_20_steps",
    "best/best_model.ckpt",
)

# Names of the modules whose outputs are hooked and turned into selectivity maps.
target_layer_names = load_json_as_dict(layer_names_filename)[
    "all_conv_layers_except_conv1"
]

###############################################################
# dataset_root = "./datasets/clip_retrieval_dataset"
# folders = [
#     "bird",
#     "car",
#     "fish",
#     "human_body",
#     "human_face",
#     "scene_indoor",
#     "scene_outdoor",
#     "text"
# ]


# folder_colors = {
#     "bird": "#FF5733",
#     "car": "#3366FF",
#     "fish": "#33FF33",
#     "human_body": "#FF33FF",
#     "human_face": "#FF9933",
#     "scene_indoor": "#9933FF",
#     "scene_outdoor": "#FF6633",
#     "text": "#33FFFF"
# }
###############################################################


dataset_root = "./datasets/curated"

# Each category sub-folder of `dataset_root`, paired with its plot colour.
# The order of this list defines the integer label of each category.
# The combined "scenes" category (colour "#F1C40F") is currently disabled in
# favour of the separate indoor/outdoor scene folders.
_category_palette = [
    ("bodies", "#FF5733"),
    ("faces", "#3498DB"),
    ("objects", "#27AE60"),
    ("scene_indoor", "#9933FF"),
    ("scene_outdoor", "#FF6633"),
    ("text", "#9B59B6"),
]

folders = [category for category, _colour in _category_palette]
folder_colors = dict(_category_palette)

# One dataset entry per category folder, in the order given by `folders`.
dataset = CuratedDataset(folders=[os.path.join(dataset_root, f) for f in folders])

# Create and run our experiment. It uses the `device` setting defined at the
# top of the script rather than a second hard-coded device string, so the two
# cannot drift apart.
experiment = CategorySelectivityMapExperiment(
    model=model,
    dataset=dataset,
    target_layer_names=target_layer_names,
    checkpoint_filename=our_checkpoint_filename,
    device=device,
)
# Presumably fills `experiment.all_hook_outputs`, which is read below;
# TODO confirm in utils.py.
experiment.run()


# How each channel's (height, width) activation map is collapsed to a scalar.
# "mean" averages over space; "norm" takes the L1 norm over space.
possible_modes = ["mean", "norm"]
mode = "norm"
# Raise explicitly: an `assert` would be stripped under `python -O`.
if mode not in possible_modes:
    raise ValueError(f"mode must be one of {possible_modes}, got {mode!r}")


output_dir = "./maps/individual/resnet18"
# Create the output directory if needed and delete .jpg maps left over from
# earlier runs. This does not shell out, and does nothing when no .jpg
# files exist.
os.makedirs(output_dir, exist_ok=True)
for stale_name in os.listdir(output_dir):
    stale_path = os.path.join(output_dir, stale_name)
    if stale_name.endswith(".jpg") and os.path.isfile(stale_path):
        os.remove(stale_path)

# Collapse every layer's activations to one scalar per (image, channel) once,
# up front. This reduction does not depend on the target category, so doing
# it inside the category loop would only repeat identical work.
reduced_activations = {}
for layer_name in experiment.all_hook_outputs:
    ## activations.shape: (len(dataset), channels, height, width)
    # squeeze(1) drops a singleton dim per recorded hook output; presumably
    # each forward pass used batch size 1 — TODO confirm in utils.py.
    activations = torch.stack(
        experiment.all_hook_outputs[layer_name], dim=0
    ).squeeze(1)

    ## activations.shape: (len(dataset), channels)
    if mode == "mean":
        activations = activations.mean(-1).mean(-1)
    elif mode == "norm":
        activations = torch.norm(activations, p=1, dim=(-1, -2))

    reduced_activations[layer_name] = activations

# Per-sample category labels, compared against each category's index below.
dataset_labels = np.asarray(dataset.labels)

for target_label_index, target_category in enumerate(folders):
    # Boolean mask selecting the samples that belong to the target category.
    target_mask = dataset_labels == target_label_index

    for layer_name, activations in reduced_activations.items():
        filename = os.path.join(output_dir, f"{layer_name}_{target_category}.jpg")

        # Build the mask as a tensor on the activations' device, so indexing
        # works whether the hook outputs live on CPU or GPU.
        mask = torch.as_tensor(target_mask, device=activations.device)
        activations_for_target = activations[mask]

        # Share of each channel's summed activation that comes from the
        # target category. The score is then squared.
        total_per_neuron = activations.sum(0)
        per_neuron_selectivity_score = activations_for_target.sum(0) / total_per_neuron
        # A channel whose total is zero would yield NaN/inf; score it 0 instead.
        per_neuron_selectivity_score = torch.where(
            total_per_neuron != 0,
            per_neuron_selectivity_score,
            torch.zeros_like(per_neuron_selectivity_score),
        )
        per_neuron_selectivity_score = per_neuron_selectivity_score**2

        # matplotlib needs host memory; this is a no-op for CPU tensors.
        selectivity_map = per_neuron_selectivity_score.detach().cpu().numpy()

        # Lay the channels out on a 2-D grid with height * width == channels.
        size = find_rectangle_dimensions(area=selectivity_map.shape[0])
        fig = plt.figure(figsize=(10, 10))
        plt.imshow(selectivity_map.reshape(size.height, size.width), cmap="Reds")
        plt.colorbar()
        plt.title(
            f"Layer: {layer_name}\nCategory: {target_category}\n H W Reduction: {mode}",
            fontsize=24,
        )
        # Write the file before plt.show(), which blocks on interactive
        # backends.
        fig.savefig(filename)
        print(f"Saved: {filename}")
        plt.show()
        # Release the figure; otherwise one figure per (layer, category) stays
        # open for the rest of the run.
        plt.close(fig)
