import argparse
import torch
import os
from nesim.utils.json_stuff import load_json_as_dict
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.colors as mcolors

from nesim.utils.figure.figure_1 import get_model_and_tokenizer, CategorySelectivity, obtain_hook_outputs

import matplotlib as mpl
# Use font type 42 for both the PDF and PostScript backends.
mpl.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# Command-line interface: three required paths (checkpoint, dataset JSON,
# output image) plus optional layer, colorbar and device settings.
parser = argparse.ArgumentParser(
    description="Generate a map of top activating category for each neuron in cortical sheet"
)
parser.add_argument(
    "--checkpoint-filename",
    type=str,
    help="Path to the model checkpoint file",
    required=True,
)
parser.add_argument(
    "--dataset-filename",
    type=str,
    help="Path to the dataset json containing dataset and color info",
    required=True,
)
parser.add_argument(
    "--output-filename", type=str, help="filename of saved image", required=True
)
# Optional: the default below is used when the flag is omitted. Combining
# required=True with a default would make that default unreachable, so the
# argument is intentionally not marked as required.
parser.add_argument(
    "--layer-name",
    type=str,
    help="name of layer to visualize",
    default="transformer.h.10.mlp.c_fc",
)

parser.add_argument("--colorbar", action="store_true", help="Enable colorbar")
parser.add_argument("--colorbar-text", action="store_true", help="Enable colorbar text")

parser.add_argument("--device", type=str, default="cuda:0", help="device")

args = parser.parse_args()

# Validate the input paths explicitly: `assert` statements are removed when
# Python runs with -O, and a usage error is clearer than an AssertionError.
for flag, path in (
    ("--checkpoint-filename", args.checkpoint_filename),
    ("--dataset-filename", args.dataset_filename),
):
    if not os.path.exists(path):
        parser.error(f"{flag}: path does not exist: {path}")

# Parse the dataset JSON once and pull both sections out of the same dict.
# "dataset" is indexed by category name further down; "colors" supplies the
# colormap entries and colorbar labels.
data = load_json_as_dict(args.dataset_filename)
dataset = data["dataset"]
colors = data["colors"]

# Obtain the model/tokenizer pair for gpt-neo-125m, then replace the model's
# weights with the ones stored in the checkpoint file.
model, tokenizer = get_model_and_tokenizer(name="EleutherAI/gpt-neo-125m")

# Deserialize the checkpoint onto the CPU first; the model as a whole is moved
# to the requested device only after the weights have been loaded.
checkpoint_state = torch.load(
    args.checkpoint_filename,
    map_location=torch.device("cpu"),
)
model.load_state_dict(checkpoint_state)
model.to(args.device)


# Collect hook outputs for the single layer being visualized.
hook_dict = obtain_hook_outputs(
    model,
    layer_names=[args.layer_name],
    dataset=dataset,
    tokenizer=tokenizer,
    device=args.device,
)

c = CategorySelectivity(dataset=dataset, hook_outputs=hook_dict)

# For every valid neuron, score each category against all remaining dataset
# categories with `konkle`, and remember the index (into c.categories) of the
# highest-scoring one.
top_activating_categories = []

for neuron_idx in tqdm(c.valid_neuron_indices):
    scores = torch.zeros(len(c.categories))

    for position, category in enumerate(c.categories):
        remaining = [name for name in dataset if name != category]
        scores[position] = c.konkle(
            neuron_idx=neuron_idx,
            target_class=category,
            other_classes=remaining,
            layer_name=args.layer_name,
        )

    winner = torch.argmax(scores).item()
    top_activating_categories.append(winner)

# Arrange the per-neuron category indices on a 2-D grid and render them with
# one discrete color per category.
size = find_rectangle_dimensions(area=len(top_activating_categories))

cmap = mcolors.ListedColormap(list(colors.values()))
fig = plt.figure(figsize=(15, 8))

category_grid = np.array(top_activating_categories).reshape(size.height, size.width)

# Pin the color limits so that integer index i always lands in the i-th color
# band. Without explicit limits imshow scales to the data's own min/max, which
# assigns wrong colors whenever the lowest or highest category index is absent
# from the grid. The half-step offsets also center each colorbar tick on its
# band.
# NOTE(review): assumes the i-th entry of `colors` belongs to c.categories[i]
# and that both have the same length -- confirm against the dataset JSON.
num_colors = len(colors)
image = plt.imshow(category_grid, cmap=cmap, vmin=-0.5, vmax=num_colors - 0.5)
plt.axis("off")

if args.colorbar:
    colorbar = plt.colorbar(image, ticks=np.arange(num_colors))
    # Category names can only be attached to an existing colorbar, so the
    # label step lives inside this branch.
    if args.colorbar_text:
        colorbar.set_ticklabels(list(colors.keys()))
elif args.colorbar_text:
    # Previously this combination crashed with a NameError on `colorbar`.
    print("--colorbar-text has no effect without --colorbar; skipping tick labels")

fig.savefig(args.output_filename)
print(f"Saved\n{args.output_filename}")

"""
example:
python3 generate_map_multiple_categories.py --output-filename temp.jpg --dataset-filename dataset.json --layer-name "transformer.h.10.mlp.c_fc" --checkpoint-filename ../../../training/gpt_neo_125m/checkpoints/supreme_topo_scale_50/checkpoint-1700/pytorch_model.bin
"""
