import os

import matplotlib.pyplot as plt
import numpy as np
import scipy
import scipy.io
import torch
from PIL import Image

from nesim.utils.figure.figure_1 import apply_ratan_matplotlib_thing
from nesim.utils.json_stuff import load_json_as_dict
apply_ratan_matplotlib_thing()

# Models to plot. Each name is also the sub-folder of "results" that is read
# for that model, and its prefix selects which layer group gets plotted.
# NOTE(review): "scale_1" / "scale_5" below lack the ".0" suffix used by the
# other entries — confirm they match the actual results folder names.
model_names = [
    ## end_topo
    "end_topo_scale_1.0_shrink_factor_3.0",
    "end_topo_scale_5.0_shrink_factor_3.0",
    "end_topo_scale_10.0_shrink_factor_3.0",
    "end_topo_scale_50.0_shrink_factor_3.0",
    ## all_topo
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_5_shrink_factor_3.0",
    "all_topo_scale_10.0_shrink_factor_3.0",
    "all_topo_scale_50.0_shrink_factor_3.0",
    ## baseline
    "baseline_scale_None_shrink_factor_3.0",
]


# When True, map values are clamped to [-1, 1] before plotting.
clip_values = False
# When True, every subplot gets its own colorbar.
colorbar = True
# Key read from each .mat file; also used in the figure title and output filename.
metric = "selectivity"
# When True, color limits are symmetric about zero (+/- max |value| of the map).
center_colorbar_to_zero = True
# Colormap used for each contrast mode.
cmaps = {
    "small-big": "PiYG_r",
    "animate-inanimate": "RdBu_r",
}

# Render one PDF per (mode, model): a single row of per-layer maps read from
# results/<model_name>/<mode>/<layer_name>/<Small|Animate>.mat.

# The layer-name groups do not depend on the model, so load them once.
layer_names = load_json_as_dict(
    "../../../../training/imagenet/resnet18/layer_names.json"
)

# File holding the map for each contrast mode.
map_basenames = {
    "small-big": "Small.mat",
    "animate-inanimate": "Animate.mat",
}

# savefig does not create missing directories, so make sure the target exists.
output_folder = "send_to_ratan"
os.makedirs(output_folder, exist_ok=True)

for mode in ["small-big", "animate-inanimate"]:
    for model_name in model_names:

        if model_name.startswith("end_topo"):
            topo_layer_names = layer_names["last_conv_layers_in_each_block"]
        else:
            ## all_topo models and the baseline are both plotted over every conv layer
            topo_layer_names = layer_names["all_conv_layers"]

        single_model_results_folder = os.path.join("results", model_name, mode)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(single_model_results_folder):
            raise FileNotFoundError(
                f"results folder not found: {single_model_results_folder}"
            )

        num_layers = len(topo_layer_names)
        # squeeze=False always returns a 2-D array of axes, so indexing also
        # works when there is only a single layer to plot.
        fig, axes = plt.subplots(
            nrows=1,
            ncols=num_layers,
            figsize=(num_layers * 5, 5),
            squeeze=False,
        )
        axes = axes[0]

        fig.suptitle(f"model:{model_name}\nmode: {mode}\nmetric: {metric}")
        for idx, layer_name in enumerate(topo_layer_names):
            map_filename = os.path.join(
                single_model_results_folder,
                layer_name,
                map_basenames[mode],
            )
            tvals = scipy.io.loadmat(map_filename)[metric]

            if clip_values:
                tvals = np.clip(tvals, -1.0, 1.0)

            imshow_kwargs = {"cmap": cmaps[mode]}
            if center_colorbar_to_zero:
                # Symmetric limits so that zero sits at the middle of the colormap.
                abs_max = float(np.abs(tvals).max())
                imshow_kwargs["vmin"] = -abs_max
                imshow_kwargs["vmax"] = abs_max
            im = axes[idx].imshow(tvals, **imshow_kwargs)

            axes[idx].axis("off")
            axes[idx].set_title(f"layer: {layer_name}\nUp: {mode.split('-')[0]}")
            if colorbar:
                # One colorbar per subplot, since each map has its own value range.
                fig.colorbar(im, ax=axes[idx], fraction=0.046, pad=0.04)

        filename = os.path.join(
            output_folder,
            f"{metric}_{mode}_{model_name}.pdf"
        )
        fig.tight_layout()
        fig.savefig(filename)
        # Close this specific figure so figures do not accumulate across iterations.
        plt.close(fig)
        print(filename)