import torchvision.models as models
import torch
from nesim.utils.json_stuff import load_json_as_dict
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.hook import ForwardHook
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.pyplot as plt
import os
from tqdm import tqdm

## local imports from utils.py
from utils import (
    CuratedDataset,
    CategorySelectivityMapExperiment,
    find_most_activating_label_avg,
)


# Root directory holding the trained ImageNet checkpoints. The commented-out
# line is the equivalent location on another machine, kept for reference.
# checkpoints_root =  "/research/XXXX-1/nesim/training/imagenet/resnet18/checkpoints/imagenet/" ## barlow
checkpoints_root = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/imagenet/resnet50/checkpoints/imagenet"  ## openmind

# Each category lives in its own sub-folder below this directory.
dataset_root = "./datasets/clip_retrieval_dataset"
device = "cuda:0"
# JSON file mapping a layer-group name (e.g. "layer3") to a list of layer names.
layer_names_filename = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/imagenet/resnet50/possible_nesim_layers.json"

# ResNet-50 built without pretrained weights; it is handed to the experiments
# below together with a checkpoint filename.
model = models.resnet50(weights=None)


baseline_checkpoint_filename = os.path.join(
    checkpoints_root,
    "baseline_shrink_factor_[5.0]_loss_scale_None_layers_all_conv_layers__bimt_scale_None_from_pretrained_False_apply_every_20_steps",
    "best/best_model.ckpt",
)

our_checkpoint_filename = os.path.join(
    checkpoints_root,
    "shrink_factor_[5.0]_loss_scale_150_layers_layer3__bimt_scale_None_from_pretrained_False_apply_every_20_steps",
    "best/best_model.ckpt",
)

# Experiment name -> checkpoint file to evaluate.
checkpoint_map = {
    "baseline": baseline_checkpoint_filename,
    "ours": our_checkpoint_filename,
}

# Fail early, and with a message that names the missing file. An explicit raise
# is used instead of `assert` because assertions are removed under `python -O`.
for experiment_name, path in checkpoint_map.items():
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"Checkpoint for '{experiment_name}' not found: {path}"
        )

# Only the layers listed under the "layer3" key are analysed.
target_layer_names = load_json_as_dict(layer_names_filename)["layer3"]

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt

# Threshold forwarded to find_most_activating_label_avg; a value above 0 also
# adds an extra "none" entry to the colorbar further down.
threshold = 0.0

# Hex color used for each category in the selectivity maps. Insertion order
# defines the order of the categories.
# NOTE(review): "bird" (#FF5733) and "scene_outdoor" (#FF6633) are very close
# in hue and may be hard to tell apart in the figures.
folder_colors = {
    "bird": "#FF5733",
    "car": "#3366FF",
    "fish": "#33FF33",
    "human_body": "#FF33FF",
    "human_face": "#FF9933",
    "scene_indoor": "#9933FF",
    "scene_outdoor": "#FF6633",
    "text": "#33FFFF",
}

# Category sub-folder names, in the same order as the color table above.
folders = list(folder_colors)

# One directory per category, all located under dataset_root.
category_dirs = [os.path.join(dataset_root, category) for category in folders]
dataset = CuratedDataset(folders=category_dirs)

def _run_selectivity_experiment(checkpoint_key):
    """Build a CategorySelectivityMapExperiment for one checkpoint and run it.

    Args:
        checkpoint_key: Key into ``checkpoint_map`` ("baseline" or "ours").

    Returns:
        The experiment object, after its ``run()`` method has been called.
    """
    experiment = CategorySelectivityMapExperiment(
        model=model,
        dataset=dataset,
        target_layer_names=target_layer_names,
        checkpoint_filename=checkpoint_map[checkpoint_key],
        # Use the module-level device setting instead of repeating the literal.
        device=device,
    )
    experiment.run()
    return experiment


# NOTE(review): both experiments receive the same `model` object. Each one is
# fully run before the next is constructed, but if the experiment class keeps
# hooks on (or lazily reads from) the shared model, the second run could affect
# the first experiment's results -- confirm against
# CategorySelectivityMapExperiment.

# Create and run the baseline experiment
experiment_baseline = _run_selectivity_experiment("baseline")

# Create and run our experiment
experiment_ours = _run_selectivity_experiment("ours")

# With a positive threshold an additional black "none" entry is appended to
# the color table and to the colorbar labels.
needs_none_bucket = threshold > 0
if needs_none_bucket:
    print("Threshold is not 0")
    folder_colors["none"] = "#000000"

ytick_labels = folders + ["none"] if needs_none_bucket else folders

# One color per entry of the color table, in insertion order, so that integer
# i in a label map corresponds to the i-th entry.
colors = list(folder_colors.values())

# Discrete colormap built from exactly those colors.
cmap = mcolors.ListedColormap(colors)


# Label maps collected per layer, in the order of target_layer_names.
images_ours = []
images_baseline = []

# savefig does not create missing directories, so make sure the output
# directory exists before the loop starts.
os.makedirs("temp", exist_ok=True)


def _plot_selectivity_panel(axis, experiment, layer_name, title):
    """Draw the category selectivity map of one experiment for one layer.

    Args:
        axis: Matplotlib axes to draw on.
        experiment: Experiment object providing
            ``get_all_outputs_for_single_layer``.
        layer_name: Name of the layer whose outputs are visualised.
        title: Title shown above the panel.

    Returns:
        Tuple ``(label_image, pureness_heatmap)``: the two results of
        ``find_most_activating_label_avg``, each reshaped to the grid given by
        ``find_rectangle_dimensions`` and converted with ``.numpy()``.
    """
    layer_outputs = experiment.get_all_outputs_for_single_layer(
        layer_name=layer_name
    )
    # The grid is sized from dimension 1 of the outputs
    # (presumably the number of units/channels -- TODO confirm).
    grid_size = find_rectangle_dimensions(layer_outputs.shape[1])
    most_activating_labels, pureness = find_most_activating_label_avg(
        layer_outputs, dataset.labels, threshold=threshold
    )
    pureness_heatmap = pureness.reshape(grid_size.height, grid_size.width).numpy()
    label_image = most_activating_labels.reshape(
        grid_size.height, grid_size.width
    ).numpy()

    num_categories = len(folder_colors)
    # Pin the color scale so that integer label i always maps to the i-th
    # color. Without vmin/vmax, imshow rescales to the labels present in this
    # particular image, and the colors drift away from the colorbar labels
    # whenever a category is absent from a layer.
    # NOTE(review): assumes labels are integers in [0, num_categories - 1],
    # with "none" (threshold > 0) encoded as the last index -- confirm against
    # find_most_activating_label_avg.
    im = axis.imshow(
        label_image,
        cmap=cmap,
        vmin=-0.5,
        vmax=num_categories - 0.5,
    )
    axis.axis("off")
    axis.set_title(title)
    cbar = plt.colorbar(im, ticks=list(range(num_categories)), ax=axis)
    cbar.ax.set_yticklabels(ytick_labels, rotation=0)
    cbar.set_label("Categories")
    return label_image, pureness_heatmap


# Visualize the results for each layer
# target_layer_name_good = ['layer3.0.conv3', 'layer3.1.conv3']
for layer_index, name in enumerate(tqdm(target_layer_names)):

    # Side-by-side panels: baseline on the left, ours on the right.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
    fig.suptitle(f"Category selectivity map for layer: {name}\nThreshold: {threshold}")

    # Baseline experiment
    image_baseline, pureness_heatmap_baseline = _plot_selectivity_panel(
        ax[0], experiment_baseline, name, "Baseline"
    )
    images_baseline.append(image_baseline)

    # Our experiment
    image_ours, pureness_heatmap_ours = _plot_selectivity_panel(
        ax[1], experiment_ours, name, "Ours"
    )
    images_ours.append(image_ours)

    fig.savefig(f"temp/{layer_index}.jpg")
    # Release the figure; otherwise one open figure accumulates per layer.
    plt.close(fig)
