"""
inspired by:
nesim/experiments/gpt_neo_125m/catergory_selectivity/generate_map_hierarchial.py
"""

import argparse
import torch
import os
from nesim.utils.json_stuff import load_json_as_dict
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.figure.figure_1 import exponential_colormap
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m

from nesim.utils.figure.figure_1 import CategorySelectivity, obtain_hook_outputs
from nesim.utils.folder import make_folder_if_does_not_exist


# Command-line interface of the script. Only --dataset-filename and
# --target-class are mandatory; every other option has a default.
parser = argparse.ArgumentParser(
    description="Generate a map of top activating category for each neuron in cortical sheet"
)

parser.add_argument(
    "--dataset-filename",
    type=str,
    help="Path to the dataset json containing dataset and color info",
    required=True,
)

parser.add_argument(
    "--target-class", type=str, help="name of category from dataset", required=True
)

parser.add_argument("--device", type=str, default="cuda:0", help="device")

# `choices` lets argparse reject an unknown mode with a usage message.
parser.add_argument(
    "--mode",
    type=str,
    default="dprime",
    choices=["softmax", "dprime"],
    help="softmax or dprime",
)

parser.add_argument(
    "--layer-index",
    type=int,
    default=None,
    help="index of the transformer block whose mlp.c_fc layer is mapped "
    "(default: all 12 blocks)",
)

parser.add_argument(
    "--checkpoint-name",
    type=str,
    default=None,
    help="key of the checkpoint to analyse, e.g. 'baseline' or 'topo_<scale>' "
    "(default: all checkpoints)",
)

parser.add_argument(
    "--other-classes",
    nargs="+",
    help="names of dataset files (datasets/<name>.json) pooled as the 'other' class "
    "(default: datasets/everything_else.json)",
)

args = parser.parse_args()

# Validate user input explicitly: `assert` statements are removed when Python
# runs with -O, which would let bad arguments through silently.
if args.mode not in ("softmax", "dprime"):
    parser.error(f"--mode must be 'softmax' or 'dprime', got {args.mode!r}")

if not os.path.exists(args.dataset_filename):
    parser.error(f"--dataset-filename does not exist: {args.dataset_filename}")

# Root folder under which the generated maps are written.
output_folder = "./assets"

# Layers to analyse: only the requested block when --layer-index is given,
# otherwise the mlp.c_fc layer of each of the 12 blocks.
layer_indices = range(12) if args.layer_index is None else [args.layer_index]
topo_layer_names = [f"transformer.h.{i}.mlp.c_fc" for i in layer_indices]

# Topo-scale values for which a checkpoint is looked up, in processing order.
topo_scales = [50, 10, 1, 5]
checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"
global_step = 10500


def _checkpoint_path(scale):
    """Return the checkpoint path for `scale` at the configured global step."""
    return get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=scale,
        global_step=global_step,
    )


# Maps a checkpoint name to its path. "baseline" is the topo_scale=0 checkpoint.
# Entries that are currently disabled: "untrained": None, "pretrained": "pretrained".
checkpoints_map = {"baseline": _checkpoint_path(0)}

for topo_scale in topo_scales:
    checkpoints_map[f"topo_{topo_scale}"] = _checkpoint_path(topo_scale)


# "other" is the key of the comparison set below; a target class with the
# same name would be overwritten by it.
if args.target_class == "other":
    raise ValueError(
        "--target-class must not be 'other': that key is reserved for the comparison set"
    )

# Read the target-class file once and reuse the result.
data = load_json_as_dict(args.dataset_filename)

# Two-entry dataset: the target class versus everything it is compared with.
dataset = {args.target_class: data}

if args.other_classes is None:
    dataset["other"] = load_json_as_dict(filename="datasets/everything_else.json")
else:
    # Pool the samples of every requested class file into one "other" list.
    # Assumes each file holds a list of samples -- TODO confirm.
    all_data = []
    for name in args.other_classes:
        all_data.extend(
            load_json_as_dict(filename=os.path.join("datasets", f"{name}.json"))
        )

    dataset["other"] = all_data

from nesim.experiments.gpt_neo_125m import get_checkpoint

# Load checkpoints onto the device requested with --device (default "cuda:0")
# instead of a hard-coded one.
device = args.device

# Analyse the single requested checkpoint, or all of them when none is given.
if args.checkpoint_name is None:
    checkpoint_names = list(checkpoints_map)
elif args.checkpoint_name in checkpoints_map:
    checkpoint_names = [args.checkpoint_name]
else:
    raise ValueError(
        f"Unknown --checkpoint-name {args.checkpoint_name!r}; "
        f"expected one of: {', '.join(checkpoints_map)}"
    )

from matplotlib.colors import LinearSegmentedColormap

# Blue -> white -> red colormap, built once because it does not depend on the
# checkpoint or the layer.
custom_cmap = LinearSegmentedColormap.from_list(
    "custom_cmap",
    [
        (0, 0, 1),  # blue for low values
        (1, 1, 1),  # white for middle values
        (1, 0, 0),  # red for high values
    ],
    N=256,
)

# Tick-label size of the colorbar.
fontsize = 25

# For every checkpoint and layer: score each valid neuron for the target
# class, arrange the scores on a 2D grid and save the grid as an image
# (without and with colorbar) and as a .npy array.
for checkpoint_name in checkpoint_names:
    model, tokenizer = get_checkpoint(
        checkpoint_filename=checkpoints_map[checkpoint_name], device=device
    )

    model.to(args.device)

    for layer_name in topo_layer_names:
        # hook_dict looks like this (shape as noted by the original author):
        # {
        #     layer_name: {
        #         category_name: [
        #             tensor1,  # shape = 1, sequence_length, 3072
        #             tensor2,
        #             ...  # one entry per sample in the dataset
        #         ]
        #     }
        # }
        hook_dict = obtain_hook_outputs(
            model,
            layer_names=[layer_name],
            dataset=dataset,
            tokenizer=tokenizer,
            device=args.device,
        )

        c = CategorySelectivity(dataset=dataset, hook_outputs=hook_dict)

        # One selectivity score per valid neuron, in the order of
        # c.valid_neuron_indices.
        target_category_values = []

        for neuron_idx in tqdm(c.valid_neuron_indices):
            if args.mode == "softmax":
                score = c.softmax_score(
                    neuron_idx=neuron_idx,
                    target_class=args.target_class,
                    layer_name=layer_name,
                )
            elif args.mode == "dprime":
                score = c.konkle(
                    neuron_idx=neuron_idx,
                    target_class=args.target_class,
                    other_classes=[x for x in dataset if x != args.target_class],
                    layer_name=layer_name,
                    mode="norm",
                )
            else:
                raise ValueError(f"Invalid args.mode: {args.mode}")
            target_category_values.append(score)

        fig = plt.figure(figsize=(15, 11))

        # Lay the flat list of scores out on a height x width grid.
        size = find_rectangle_dimensions(area=len(c.valid_neuron_indices))
        im = np.array(target_category_values).reshape(size.height, size.width)

        # The color range spans the data range, so white marks the midpoint
        # between the smallest and largest score (not necessarily zero).
        min_value = im.min()
        max_value = im.max()

        image = plt.imshow(im, cmap=custom_cmap, vmin=min_value, vmax=max_value)

        plt.tight_layout()
        plt.axis("off")

        save_folder = os.path.join(
            output_folder,
            f"{checkpoint_name}",
            f"{layer_name}",
            f"{args.mode}",
        )
        make_folder_if_does_not_exist(folder=save_folder)

        image_filename = os.path.join(save_folder, f"{args.target_class}.png")
        numpy_filename = os.path.join(save_folder, f"{args.target_class}.npy")

        # First save: the bare map, before the colorbar is added to the figure.
        fig.savefig(
            image_filename,
            bbox_inches="tight",
            dpi=300,
            pad_inches=0,
        )
        print(f"Saved: {image_filename}")

        # Second save: the same figure with a colorbar attached.
        cbar = plt.colorbar(image, fraction=0.046, pad=0.04)  # size and padding
        cbar.ax.tick_params(labelsize=fontsize, labelcolor="black")
        cbar.outline.set_visible(False)

        image_filename_colorbar = os.path.join(
            save_folder,
            f"{args.target_class}_with_colorbar.png",
        )
        fig.savefig(
            image_filename_colorbar,
            bbox_inches="tight",
            dpi=300,
            pad_inches=0,
        )

        print(f"Saved: {image_filename_colorbar}")

        np.save(numpy_filename, im)

        print(f"Saved: {numpy_filename}")

        # Close this figure explicitly so figures do not accumulate across
        # iterations.
        plt.close(fig)
