import pandas as pd
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.experiments.gpt_neo_125m import get_checkpoint
from nesim.utils.figure.figure_1 import obtain_hook_outputs
from nesim.utils.figure.figure_1 import CategorySelectivity, obtain_hook_outputs
from tqdm import tqdm
import torch
import argparse
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.colors as mcolors
import numpy as np
import matplotlib.pyplot as plt
import os
from nesim.utils.getting_modules import get_module_by_name
from nesim.utils.folder import make_folder_if_does_not_exist

parser = argparse.ArgumentParser(
    description="Generate a map of top activating category for each neuron in cortical sheet"
)

parser.add_argument(
    "--output-folder", type=str, help="folder where images should be saved", required=True
)
# Optional overrides; each default reproduces the value this script previously hard-coded.
parser.add_argument(
    "--checkpoint-dir",
    type=str,
    default="/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints",
    help="folder containing the GPT-Neo 125M training checkpoints",
)
parser.add_argument(
    "--global-step", type=int, default=10500, help="training step of the checkpoints to load"
)
parser.add_argument(
    "--device", type=str, default="cuda:0", help="torch device used for inference"
)
parser.add_argument(
    "--stimuli-csv",
    type=str,
    default="Dataset1_SWJN_Stimuli.csv",
    help="CSV file with `condition` and `stimulus_string` columns",
)

# Parse arguments before touching the filesystem so that `--help` and usage
# errors are reported even when the stimuli CSV is not in the working directory.
args = parser.parse_args()

global_step = args.global_step
checkpoint_dir = args.checkpoint_dir
device = args.device
df = pd.read_csv(args.stimuli_csv)

# The `mlp.c_fc` submodule of transformer blocks 0..11.
layer_names = [f"transformer.h.{idx}.mlp.c_fc" for idx in range(12)]
categories = [
    "JABBERWOCKY",
    "NONWORDS",
    "SENTENCES",
    "WORDS",
]


# label_names = df.condition.unique()
# labels = {}
# for idx, name in enumerate(label_names):
#     labels[name] = idx


def transform_dataset(df):
    """Group the stimulus strings of ``df`` by their experimental condition.

    Args:
        df: DataFrame with ``condition`` and ``stimulus_string`` columns.

    Returns:
        dict keyed by each distinct ``condition`` value, in the order returned by
        ``df.condition.unique()``. Each value is the ``.values`` array of the
        ``stimulus_string`` entries belonging to that condition.
    """
    return {
        condition: df[df.condition == condition].stimulus_string.values
        for condition in df.condition.unique()
    }

dataset = transform_dataset(df=df)

topo_scales = [1, 5, 10, 50]

## load checkpoints
# Run label -> checkpoint path at `global_step`. "baseline" uses topo_scale=0 and
# every other run is labelled topo_<scale>. (Earlier variants of this script also
# mapped "untrained" -> None and "pretrained" -> "pretrained".)
_runs = [("baseline", 0)] + [(f"topo_{scale}", scale) for scale in topo_scales]
checkpoints_map = {
    run_name: get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=run_scale,
        global_step=global_step,
    )
    for run_name, run_scale in _runs
}


# Diverging blue -> white -> red colormap shared by every saved map; built once
# here instead of on every loop iteration.
custom_cmap = mcolors.LinearSegmentedColormap.from_list(
    "custom_cmap",
    [
        (0, 0, 1),  # blue for low values
        (1, 1, 1),  # white for middle values
        (1, 0, 0),  # red for high values
    ],
    N=256,
)


def _selectivity_map(selectivity, layer_name, category):
    """Return konkle scores of `category` vs. every other condition as a 2D sheet.

    One score is computed per index in ``selectivity.valid_neuron_indices`` of
    ``layer_name``. The scores are reshaped to the rectangle returned by
    ``find_rectangle_dimensions``.
    """
    other_classes = [x for x in dataset.keys() if x != category]
    scores = [
        selectivity.konkle(
            neuron_idx=neuron_idx,
            target_class=category,
            other_classes=other_classes,
            layer_name=layer_name,
            mode="norm",
        )
        for neuron_idx in tqdm(selectivity.valid_neuron_indices)
    ]
    size = find_rectangle_dimensions(area=len(selectivity.valid_neuron_indices))
    return np.array(scores).reshape(size.height, size.width)


def _save_selectivity_map(im, checkpoint_name, category, layer_name):
    """Write `im` to <output-folder>/<checkpoint_name>/<category>/<layer_name>.png and .npy."""
    fig = plt.figure(figsize=(15, 11))
    # Colour limits span this map's own range, so colours are not comparable across maps.
    plt.imshow(im, cmap=custom_cmap, vmin=im.min(), vmax=im.max())
    plt.tight_layout()
    plt.axis("off")

    save_folder = os.path.join(args.output_folder, checkpoint_name, category)
    make_folder_if_does_not_exist(save_folder)

    filename = os.path.join(save_folder, f"{layer_name}.png")
    # NOTE(review): pad_inches only takes effect together with bbox_inches="tight".
    fig.savefig(filename, pad_inches=0)
    # Close explicitly; otherwise one figure per saved map stays open and leaks memory.
    plt.close(fig)
    print(f"Saved\n{filename}")

    numpy_filename = os.path.join(save_folder, f"{layer_name}.npy")
    np.save(file=numpy_filename, arr=im)
    print(f"Saved: {numpy_filename}")


for checkpoint_name in tqdm(checkpoints_map, desc="Main loop"):
    # Load each checkpoint and run inference once. The hook outputs are
    # requested for every name in `layer_names`, so they are reused for all
    # (layer, category) pairs instead of being recomputed for each pair.
    model, tokenizer = get_checkpoint(checkpoints_map[checkpoint_name], device=device)

    print(f"Running inference for: {checkpoint_name}")
    hook_dict = obtain_hook_outputs(
        model,
        layer_names=layer_names,
        dataset=dataset,
        tokenizer=tokenizer,
        device=device,
    )
    c = CategorySelectivity(dataset=dataset, hook_outputs=hook_dict)

    for layer_name in layer_names:
        for category in categories:
            im = _selectivity_map(c, layer_name, category)
            _save_selectivity_map(im, checkpoint_name, category, layer_name)

    # Drop references so this model can be freed before the next one is loaded.
    del model, tokenizer, hook_dict, c