from nesim.utils.ffn_sandbox import load_json_as_dict
from transformers import (
    AutoTokenizer,
)
from nesim.experiments.gpt_neo_125m import GPTNeoWorldKnowledgeConfig

from tqdm import tqdm
from nesim.utils.ffn_sandbox import dict_to_json, ResultFolder
import torch
import os
import argparse

# Command-line interface: optional bounds of the layer range to process.
parser = argparse.ArgumentParser(
    description="mention the start layer idx and the end layer idx"
)
for flag_name in ("start", "end"):
    # Registers "-s/--start" and "-e/--end"; both are optional.
    parser.add_argument(
        f"-{flag_name[0]}", f"--{flag_name}", help=flag_name, required=False
    )
parsed_namespace = parser.parse_args()
args = vars(parsed_namespace)

# Echo the requested bounds (None when a flag was not given).
for flag_name in ("start", "end"):
    print(f"{flag_name}:", args[flag_name])

# Experiment configuration and the tokenizer for the configured model.
config = GPTNeoWorldKnowledgeConfig.from_json("config.json")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token


def get_top_activating_examples_for_neuron(
    all_activations,
    neuron_idx=20,
    num_top_dataset_samples=3,
):
    """Return the dataset samples on which one neuron activates the most.

    Relies on the module-level ``tokenizer`` and ``results`` objects to turn a
    dataset row index back into text.

    Args:
        all_activations: 2-D tensor indexed as ``[dataset_idx, neuron_idx]``.
        neuron_idx: Column of ``all_activations`` to rank.
        num_top_dataset_samples: Maximum number of samples to return.

    Returns:
        A list of ``{"text": ..., "activation": ...}`` dicts, ordered from the
        highest activation to the lowest. The list is shorter than
        ``num_top_dataset_samples`` when fewer dataset rows are available.
    """
    neuron_column = all_activations[:, neuron_idx]

    # torch.topk raises when k exceeds the number of rows; clamp so that a
    # small (e.g. heavily filtered) dataset still yields a result.
    k = min(num_top_dataset_samples, neuron_column.shape[0])
    top_values, top_dataset_indices = torch.topk(neuron_column, k=k)

    top_activating_examples = []
    # topk already returns the activation values, so there is no need to index
    # back into all_activations for each selected row.
    for dataset_idx, activation in zip(
        top_dataset_indices.tolist(), top_values.tolist()
    ):
        # Newlines are stripped from the decoded text.
        text = tokenizer.decode(results[dataset_idx]["input_ids"]).replace("\n", "")
        top_activating_examples.append({"text": text, "activation": activation})
    return top_activating_examples


# Loaded from the JSON file named in the config.
# NOTE(review): presumably the dataset indices that survived filtering — confirm.
filtered_indices = load_json_as_dict(config.filtered_indices_filename)
# Collection of per-item results. The rest of this script uses it via
# len(results) and results[idx], and reads the "input_ids" and
# "ffn_activations" entries of each item.
results = ResultFolder(
    config.output_json_dir,
    num_items=None,  ## None = iter over all dataset items
    filtered_indices=filtered_indices,
)

"""
iterate over each layer idx and save top activating dataset samples for each neuron idx
"""


# Layer range to process: defaults to every layer recorded for the first item.
start = 0 if args["start"] is None else int(args["start"])
end = len(results[0]["ffn_activations"]) if args["end"] is None else int(args["end"])

for layer_idx in range(start, end):
    # One dense activation row per dataset item, stacked after the loop.
    per_item_rows = []

    for filtered_dataset_idx in tqdm(
        range(len(results)), desc=f"[layer idx: {layer_idx}] collecting activations"
    ):
        layer_record = results[filtered_dataset_idx]["ffn_activations"][layer_idx]
        key_indices = layer_record["next_token_top_activating_key_indices"]
        values = layer_record["next_token_top_activating_values"]

        # The stored values are ordered by key_indices; scatter them back so
        # that position i holds the activation of neuron i. The row must be
        # wide enough for every neuron of the layer, not just for the number
        # of stored indices, otherwise a limited top-k list would index past
        # the end of the row. Neurons that are not listed keep the value 0.
        row_width = max(config.num_neurons_in_target_layers, len(key_indices))
        actual_values_in_order = torch.zeros(row_width)
        actual_values_in_order[key_indices] = torch.tensor(values)
        per_item_rows.append(actual_values_in_order)

    # Shape: (number of dataset items, row width).
    all_activations = torch.stack(per_item_rows, dim=0)

    single_layer_neuron_samples = []

    for neuron_idx in tqdm(
        range(config.num_neurons_in_target_layers),
        desc=f"[layer idx: {layer_idx}] collecting max activation samples",
    ):
        top_activating_examples = get_top_activating_examples_for_neuron(
            all_activations=all_activations,
            neuron_idx=neuron_idx,
            num_top_dataset_samples=config.num_top_activating_text_samples_per_neuron,
        )
        single_layer_neuron_samples.append(
            {
                "neuron_idx": neuron_idx,
                "top_activating_examples": top_activating_examples,
            }
        )

    # One JSON file per layer.
    filename = os.path.join(
        config.top_activating_samples_json_folder, f"layer_idx_{layer_idx}.json"
    )
    dict_to_json(single_layer_neuron_samples, filename=filename)
    print("saved:", filename)
