import argparse
import torch
import os
from nesim.utils.json_stuff import load_json_as_dict
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
from nesim.utils.figure.figure_1 import exponential_colormap

from nesim.utils.figure.figure_1 import get_model_and_tokenizer, CategorySelectivity, obtain_hook_outputs


# Command-line interface: every option below is read later via `args.<name>`.
parser = argparse.ArgumentParser(
    description="Generate a map of top activating category for each neuron in cortical sheet"
)
parser.add_argument(
    "--checkpoint-filename",
    type=str,
    help="Path to the model checkpoint file",
    required=True,
)
parser.add_argument(
    "--dataset-filename",
    type=str,
    help="Path to the dataset json containing dataset and color info",
    required=True,
)
parser.add_argument(
    "--output-filename", type=str, help="filename of saved image", required=True
)
# `--layer-name` is optional: combining `required=True` with a `default` made
# the default unreachable, so the flag now falls back to the default when it is
# omitted. Passing it explicitly behaves exactly as before.
parser.add_argument(
    "--layer-name",
    type=str,
    help="name of layer to visualize",
    default="transformer.h.10.mlp.c_fc",
)
parser.add_argument(
    "--target-category", type=str, help="name of category from dataset", required=True
)

parser.add_argument("--device", type=str, default="cuda:0", help="device")

args = parser.parse_args()

# Fail early with a clear message when an input path is missing. Explicit
# raises are used instead of `assert`, which is removed under `python -O`.
for input_path in (args.checkpoint_filename, args.dataset_filename):
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input file does not exist: {input_path}")

# Parse the dataset JSON once and read both top-level entries from it.
data = load_json_as_dict(args.dataset_filename)
dataset = data["dataset"]
# Not used further in this script; still read so a missing "colors" key is
# reported here, as before.
colors = data["colors"]

model, tokenizer = get_model_and_tokenizer(name="EleutherAI/gpt-neo-125m")

## load checkpoint for model
# The state dict is mapped to CPU first; the model is moved to the requested
# device afterwards.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source.
model.load_state_dict(
    torch.load(args.checkpoint_filename, map_location=torch.device("cpu"))
)
model.to(args.device)


# Run the dataset through the model and collect the outputs of the selected
# layer.
hook_dict = obtain_hook_outputs(
    model,
    layer_names=[args.layer_name],
    dataset=dataset,
    tokenizer=tokenizer,
    device=args.device,
)

c = CategorySelectivity(dataset=dataset, hook_outputs=hook_dict)

# Scores that do not exceed this threshold are recorded as 0 in the map.
threshold_value = 0.0
target_category_values = []

# One entry per valid neuron, in the order given by `c.valid_neuron_indices`.
for neuron_idx in tqdm(c.valid_neuron_indices):
    activation = c.softmax_score(
        neuron_idx=neuron_idx,
        target_class=args.target_category,
        layer_name=args.layer_name,
    )
    target_category_values.append(activation if activation > threshold_value else 0)


# Lay the per-neuron values out on a 2D grid whose area equals the number of
# valid neurons.
size = find_rectangle_dimensions(area=len(c.valid_neuron_indices))
im = np.array(target_category_values).reshape(size.height, size.width)

# EXP cmap
# The colour limits span the full range of the data.
min_value, max_value = im.min(), im.max()
exponent = 2  # You can adjust this value to control the exponential effect
exp_cmap = exponential_colormap("viridis", exponent)

fontsize = 18
title = (
    "Category selectivity map (softmax on mean activation per class)\n"
    "Model: gpt-neo-125m\n"
    f"Layer: {args.layer_name}\n"
    f"Target category: {args.target_category}\n"
    f"Threshold: {threshold_value}"
)

fig = plt.figure(figsize=(15, 11))
fig.suptitle(title, fontsize=fontsize)

# Draw the grid on the figure created above, add its colorbar, and write the
# result to the requested output path.
ax = plt.imshow(
    im,
    cmap=exp_cmap,
    vmin=min_value,
    vmax=max_value,
)
plt.colorbar()
fig.savefig(f"{args.output_filename}")


"""
Example usage (the checkpoint path is quoted because it contains "[" and "]",
which shells treat as glob characters):

python3 generate_map_softmax.py --dataset-filename dataset.json \
 --checkpoint-filename "../../../training/gpt_neo_125m/checkpoints/apply_nesim_every_n_steps_10_nesim_config_scale_0.08_shrink_factor_[5.0]_layer_names_index_10_checkpoint_every_n_steps_10_num_warmup_steps_300_batch_size_128_context_length_256/checkpoint-670/pytorch_model.bin" \
 --layer-name "transformer.h.10.mlp.c_fc" \
 --target-category "science" \
 --output-filename temp.png
"""
