import pandas as pd
from nesim.utils.checkpoint import get_checkpoint_path_gpt_neo_125m
from nesim.experiments.gpt_neo_125m import get_checkpoint
from nesim.utils.figure.figure_1 import obtain_hook_outputs
from nesim.utils.figure.figure_1 import CategorySelectivity, obtain_hook_outputs
from tqdm import tqdm
import torch
import argparse
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.colors as mcolors
import numpy as np
import matplotlib.pyplot as plt
import os
from nesim.utils.getting_modules import get_module_by_name

# Training step of the checkpoint(s) that get analysed.
global_step = 8900
# Directory passed to get_checkpoint_path_gpt_neo_125m when resolving checkpoints.
checkpoint_dir = "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/checkpoints"
# Stimulus table; its `condition` and `stimulus_string` columns are read by
# transform_dataset.
df = pd.read_csv("Dataset1_SWJN_Stimuli.csv")
# Modules whose outputs are hooked: `mlp.c_fc` of transformer blocks 0-11.
layer_names = [f"transformer.h.{idx}.mlp.c_fc" for idx in range(12)]
device = "cuda:0"
# Hex colour per stimulus condition.
colors = {
    "JABBERWOCKY": "#C0392B",
    "NONWORDS": "#F1C40F",
    "SENTENCES": "#2980B9",
    "WORDS": "#2ECC71",
}

# Conditions for which a per-layer area is reported; the commented-out
# entries are excluded from the report.
categories = [
    "JABBERWOCKY",
    # "NONWORDS",
    "SENTENCES",
    # "WORDS",
]

parser = argparse.ArgumentParser(
    description="Generate a map of top activating category for each neuron in cortical sheet"
)

# The value is used as the filename of the results file written at the end of
# the script, so the help text describes a file rather than an image folder.
parser.add_argument(
    "--output-filename",
    type=str,
    help="path of the output file where the results should be saved",
    required=True,
)

args = parser.parse_args()


# label_names = df.condition.unique()
# labels = {}
# for idx, name in enumerate(label_names):
#     labels[name] = idx


def transform_dataset(df):
    """Group the stimulus strings of ``df`` by their condition.

    Args:
        df: table with a ``condition`` column and a ``stimulus_string`` column.

    Returns:
        A dict with one key per distinct value of ``df.condition`` (in the
        order produced by ``unique()``), each mapped to the ``.values`` array
        of ``stimulus_string`` for the rows of that condition.
    """
    return {
        condition: df[df.condition == condition].stimulus_string.values
        for condition in df.condition.unique()
    }

# Imported before any heavy work so that a missing helper fails immediately
# instead of after the whole inference run.
from nesim.utils.json_stuff import dict_to_json

# Condition name -> array of stimulus strings for that condition.
dataset = transform_dataset(df=df)

topo_scales = [50]

## load checkpoints
checkpoints_map = {
    # "untrained": None,
    # "pretrained": "pretrained",
    # "baseline": get_checkpoint_path_gpt_neo_125m(
    #     checkpoints_dir=checkpoint_dir, 
    #     topo_scale=0, 
    #     global_step=global_step
    # ),
}

# Layout of `results` as it is filled in below:
# {
#     "<checkpoint name>": {
#         "<category>": [
#             {"layer_name": ..., "category": ..., "category_index": ..., "area": ...},
#             ...  # one entry per layer in `layer_names`
#         ]
#     }
# }
results = {}

# Register one checkpoint per topographic scale under the key "topo_<scale>".
for topo_scale in topo_scales:
    checkpoints_map[f"topo_{topo_scale}"] = get_checkpoint_path_gpt_neo_125m(
        checkpoints_dir=checkpoint_dir,
        topo_scale=topo_scale,
        global_step=global_step,
    )

for checkpoint_name in tqdm(checkpoints_map, desc = "Main loop"):
    results[checkpoint_name] = {category: [] for category in categories}

    model, tokenizer = get_checkpoint(checkpoints_map[checkpoint_name], device=device)

    # The neuron count is read from the first hooked layer and reused for every
    # layer, i.e. all entries of `layer_names` are assumed to have this width.
    num_output_neurons = get_module_by_name(module=model, name=layer_names[0]).weight.shape[0]

    print(f"Running inference for: {checkpoint_name}")
    hook_dict = obtain_hook_outputs(
        model,
        layer_names=layer_names,
        dataset=dataset,
        tokenizer=tokenizer,
        device=device,
    )

    c = CategorySelectivity(dataset=dataset, hook_outputs=hook_dict)

    # argmax below yields a position within `c.categories`, which is not
    # necessarily the position of the same name within the module-level
    # `categories` list (that list may be a subset / differently ordered).
    # NOTE(review): CategorySelectivity is defined elsewhere; confirm how it
    # orders `categories`. Looking positions up by name is correct either way.
    selectivity_categories = list(c.categories)
    selectivity_index = {name: idx for idx, name in enumerate(selectivity_categories)}

    for layer_name in layer_names:
        top_activating_categories = []
        for neuron_idx in tqdm(range(num_output_neurons)):
            activations_for_idx = torch.zeros(len(selectivity_categories))

            # Score this neuron for every category against all other dataset
            # categories, then keep the position of the best-scoring one.
            for class_idx, target_class in enumerate(selectivity_categories):
                activation = c.konkle(
                    neuron_idx=neuron_idx,
                    target_class=target_class,
                    other_classes=[x for x in dataset.keys() if x != target_class],
                    layer_name=layer_name,
                )
                activations_for_idx[class_idx] = activation

            top_activating_categories.append(torch.argmax(activations_for_idx).item())

        top_activating_categories = np.array(top_activating_categories)
        for category_index, category in enumerate(categories):
            # -1 never matches an argmax position, so a category that the
            # selectivity object does not know about gets an area of 0.
            winning_position = selectivity_index.get(category, -1)
            data = {
                "layer_name": layer_name,
                "category": category,
                "category_index": category_index,
                # Number of neurons whose top category is `category`.
                "area": float((top_activating_categories == winning_position).sum()),
            }
            results[checkpoint_name][category].append(data)
            print(data)

dict_to_json(results, filename = args.output_filename)