import argparse
import torch
import os
from nesim.utils.json_stuff import load_json_as_dict
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.figure.figure_1 import exponential_colormap

from nesim.utils.figure.figure_1 import get_model_and_tokenizer, CategorySelectivity, obtain_hook_outputs

from safetensors.torch import load_file


# Command-line interface for the script. Every option is a plain string; the
# only optional ones are --output-filename-numpy, --layer-name and --device.
parser = argparse.ArgumentParser(
    description="Generate a map of top activating category for each neuron in cortical sheet"
)
parser.add_argument(
    "--checkpoint-folder",
    type=str,
    help="Path to the model checkpoint file. It can contain either a pytorch_model.bin file or a model.safetensors file.",
    required=True,
)
parser.add_argument(
    "--dataset-filename",
    type=str,
    help="Path to the dataset json containing dataset and color info",
    required=True,
)
parser.add_argument(
    "--output-filename", type=str, help="filename of saved image", required=True
)
parser.add_argument(
    "--output-filename-numpy",
    type=str,
    default=None,
    help="filename of saved numpy array",
)

# NOTE: this option used to be declared with both `required=True` and a
# default, which made the default unreachable. It is now optional so the
# default actually applies; passing --layer-name explicitly works as before.
parser.add_argument(
    "--layer-name",
    type=str,
    help="name of layer to visualize (default: %(default)s)",
    default="transformer.h.10.mlp.c_fc",
)
parser.add_argument(
    "--target-category", type=str, help="name of category from dataset", required=True
)

parser.add_argument("--device", type=str, default="cuda:0", help="device")

args = parser.parse_args()

# Validate the input paths explicitly instead of with `assert`: assertions are
# removed when Python runs with -O, and a bare AssertionError does not tell the
# user which path is wrong. `parser.error` prints the usage text plus the
# message and exits with status 2.
if not os.path.exists(args.checkpoint_folder):
    parser.error(f"--checkpoint-folder does not exist: {args.checkpoint_folder}")
if not os.path.exists(args.dataset_filename):
    parser.error(f"--dataset-filename does not exist: {args.dataset_filename}")

# Read and parse the dataset JSON once, then pull both sections out of the
# same parsed object (previously the file was loaded three separate times).
data = load_json_as_dict(args.dataset_filename)
dataset_with_sub_categories = data["dataset"]
# `colors` is not referenced again in the visible part of this script; the
# lookup is kept so a file without a "colors" entry still fails here.
colors = data["colors"]

# Not populated or read anywhere in the visible part of this script.
full_categories: list = []

# Flatten the two-level {category: {sub_category: items}} mapping into a
# single-level dict keyed by "<category>.<sub_category>", giving each key its
# own fresh list copy of the items.
dataset = {}
for category, sub_categories in dataset_with_sub_categories.items():
    for sub_category, samples in sub_categories.items():
        dataset[f"{category}.{sub_category}"] = list(samples)

model, tokenizer = get_model_and_tokenizer(name="EleutherAI/gpt-neo-125m")

## load checkpoint for model
# The --checkpoint-folder help text advertises two formats, so look for both:
# a pickled PyTorch state dict or a safetensors file.
bin_checkpoint = os.path.join(args.checkpoint_folder, "pytorch_model.bin")
safetensors_checkpoint = os.path.join(args.checkpoint_folder, "model.safetensors")

if os.path.exists(bin_checkpoint):
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only point this script at checkpoints you trust.
    state_dict = torch.load(bin_checkpoint, map_location=torch.device("cpu"))
elif os.path.exists(safetensors_checkpoint):
    # NOTE(review): safetensors checkpoints may omit tensors that share
    # storage (e.g. tied weights); if load_state_dict below reports missing
    # keys, confirm how the checkpoint was written.
    state_dict = load_file(safetensors_checkpoint, device="cpu")
else:
    raise ValueError(f"Unsupported checkpoint format: {args.checkpoint_folder}")

# Weights are loaded on CPU first and the whole model is moved afterwards.
model.load_state_dict(state_dict)
model.to(args.device)


# Run the dataset through the model and capture the outputs of the requested
# layer via hooks.
hook_dict = obtain_hook_outputs(
    model,
    layer_names=[args.layer_name],
    dataset=dataset,
    tokenizer=tokenizer,
    device=args.device,
)

c = CategorySelectivity(dataset=dataset, hook_outputs=hook_dict)

# Scores that are not strictly greater than this value are replaced by 0.
threshold_value = 0.0

# One entry per valid neuron, in the order of `c.valid_neuron_indices`: the
# neuron's softmax score for the target category, or 0 when the score does not
# exceed the threshold.
# NOTE(review): presumably `softmax_score` returns a Python/NumPy scalar, since
# the list is later converted with `np.array` — confirm.
target_category_values = []
for neuron_idx in tqdm(c.valid_neuron_indices):
    activation = c.softmax_score(
        neuron_idx=neuron_idx,
        target_class=args.target_category,
        layer_name=args.layer_name,
    )
    target_category_values.append(
        activation if activation > threshold_value else 0
    )


# Render the per-neuron values as a 2D heat map, save it as an image, and
# optionally save the underlying array.
fontsize = 18

fig = plt.figure(figsize=(15, 11))
fig.suptitle(
    f"Category selectivity map (softmax on mean activation per class)\nModel: gpt-neo-125m\nLayer: {args.layer_name}\nTarget category: {args.target_category}\nThreshold: {threshold_value}",
    fontsize=fontsize,
)

# Lay the flat list of values out on a height x width grid whose dimensions
# come from `find_rectangle_dimensions` for the number of valid neurons.
size = find_rectangle_dimensions(area=len(c.valid_neuron_indices))
im = np.array(target_category_values).reshape(size.height, size.width)

# EXP cmap
# The colour range spans exactly the data range of this map.
min_value = im.min()
max_value = im.max()
exponent = 2  # You can adjust this value to control the exponential effect

exp_cmap = exponential_colormap("viridis", exponent)
heatmap = plt.imshow(im, cmap=exp_cmap, vmin=min_value, vmax=max_value)
# Bind the colorbar to the image explicitly rather than to "the current image".
plt.colorbar(heatmap)
fig.savefig(args.output_filename)
print(f"Saved: {args.output_filename}")

if args.output_filename_numpy is not None:
    np.save(args.output_filename_numpy, im)
    # np.save appends ".npy" when the given name does not already end with it,
    # so report the path that was actually written.
    saved_numpy_path = args.output_filename_numpy
    if not saved_numpy_path.endswith(".npy"):
        saved_numpy_path += ".npy"
    print(f"Saved: {saved_numpy_path}")

# Release the figure now that everything has been written.
plt.close(fig)
