import argparse
import torch
import os
from nesim.utils.json_stuff import load_json_as_dict
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.colors as mcolors

from nesim.utils.figure.figure_1 import get_model_and_tokenizer, CategorySelectivity, obtain_hook_outputs

import matplotlib as mpl

# Use font type 42 (TrueType) for both the PDF and PostScript backends so
# that text in saved vector figures is embedded as real fonts rather than
# the default Type 3 outlines.
mpl.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# Command-line interface.
#
# Two checkpoints (a "topo" one and a "baseline" one) are rendered to two
# separate image files; a single dataset JSON supplies both the category
# data and the colour assigned to each category.
parser = argparse.ArgumentParser(description="Generate assets for figure 1")
parser.add_argument(
    "--baseline-checkpoint-filename",
    type=str,
    help="Path to the model checkpoint file",
    required=True,
)

parser.add_argument(
    "--topo-checkpoint-filename",
    type=str,
    help="Path to the model checkpoint file",
    required=True,
)

parser.add_argument(
    "--dataset-filename",
    type=str,
    help="Path to the dataset json containing dataset and color info",
    required=True,
)

parser.add_argument(
    "--baseline-output-filename",
    type=str,
    help="filename of saved image",
    required=True,
)

parser.add_argument(
    "--topo-output-filename", type=str, help="filename of saved image", required=True
)

# NOTE: this option used to be declared with both `required=True` and a
# default, which made the default unreachable. It is now optional so the
# default applies when the flag is omitted; passing it explicitly behaves
# exactly as before.
parser.add_argument(
    "--layer-name",
    type=str,
    help="name of layer to visualize",
    default="transformer.h.10.mlp.c_fc",
)

# `--colorbar-text` only affects the labels of the colorbar drawn by
# `--colorbar`.
parser.add_argument("--colorbar", action="store_true", help="Enable colorbar")
parser.add_argument("--colorbar-text", action="store_true", help="Enable colorbar text")

parser.add_argument("--device", type=str, default="cuda:0", help="device")

args = parser.parse_args()

# Fail early, with a readable message, if any input file is missing.
# An explicit raise is used instead of `assert` so the check still runs
# when Python is started with -O (which strips assert statements).
for _required_path in (
    args.topo_checkpoint_filename,
    args.baseline_checkpoint_filename,
    args.dataset_filename,
):
    if not os.path.exists(_required_path):
        raise FileNotFoundError(f"Input file does not exist: {_required_path}")

# Parse the dataset JSON once and reuse the result.
# "dataset" is a mapping whose keys are the category names; "colors" is a
# mapping from label to colour, used below to build the colormap and the
# colorbar tick labels.
data = load_json_as_dict(args.dataset_filename)
dataset = data["dataset"]
colors = data["colors"]

model, tokenizer = get_model_and_tokenizer(name="EleutherAI/gpt-neo-125m")

# Checkpoints and their output images are paired by position; both lists
# must stay in the same (topo, baseline) order.
checkpoint_filenames = [args.topo_checkpoint_filename, args.baseline_checkpoint_filename]
output_filenames = [args.topo_output_filename, args.baseline_output_filename]

# For each (checkpoint, output image) pair: load the weights, record the
# activations of the requested layer, find for every valid neuron the
# category with the highest `konkle` score, and save the resulting map of
# category indices as a colour-coded image.
for checkpoint_filename, output_filename in zip(checkpoint_filenames, output_filenames):
    ## load checkpoint for model
    # Weights are deserialised onto the CPU first and the model is then
    # moved to the requested device.
    state_dict = torch.load(
        checkpoint_filename, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.to(args.device)

    hook_dict = obtain_hook_outputs(
        model,
        layer_names=[args.layer_name],
        dataset=dataset,
        tokenizer=tokenizer,
        device=args.device,
    )

    c = CategorySelectivity(dataset=dataset, hook_outputs=hook_dict)

    # One entry per valid neuron: index (into `c.categories`) of the
    # category with the highest score for that neuron.
    top_activating_categories = []

    for neuron_idx in tqdm(c.valid_neuron_indices):
        activations_for_idx = torch.zeros(len(c.categories))

        for class_idx, target_class in enumerate(c.categories):
            # Score the target category against every other category.
            activations_for_idx[class_idx] = c.konkle(
                neuron_idx=neuron_idx,
                target_class=target_class,
                other_classes=[x for x in dataset.keys() if x != target_class],
                layer_name=args.layer_name,
            )

        top_activating_categories.append(torch.argmax(activations_for_idx).item())

    # Lay the flat list of neurons out on a 2D grid.
    size = find_rectangle_dimensions(area=len(top_activating_categories))
    im = np.array(top_activating_categories).reshape(size.height, size.width)

    # NOTE(review): assumes the order of `colors` matches the order of
    # `c.categories`, so that category index k maps to the k-th colour.
    num_colors = len(colors)
    cmap = mcolors.ListedColormap(list(colors.values()))
    fig = plt.figure(figsize=(15, 8))

    # Pin the colour limits so that index k always lands in the k-th colour
    # bin. Without explicit limits imshow rescales to the min/max present in
    # the data, which shifts the colours whenever some category is never the
    # top category of any neuron.
    image = plt.imshow(im, cmap=cmap, vmin=-0.5, vmax=num_colors - 0.5)
    plt.axis("off")

    if args.colorbar:
        colorbar = plt.colorbar(image, ticks=np.arange(num_colors))
        # Tick labels need an existing colorbar, so `--colorbar-text` is only
        # honoured together with `--colorbar` (it used to raise NameError
        # when passed alone).
        if args.colorbar_text:
            colorbar.set_ticklabels(list(colors.keys()))

    plt.tight_layout(pad=0)  # Remove extra padding
    fig.savefig(output_filename, bbox_inches="tight", pad_inches=0, dpi=300)
    # Release the figure so figures do not accumulate across checkpoints.
    plt.close(fig)
    print(f"Saved: {output_filename}")
