import argparse
import os
import random
from functools import lru_cache
from pathlib import Path

from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

from KnowledgeSynapticNetwork.NeuroSynapticEditing import NeuroSynapticEdit
from KnowledgeSynapticNetwork.patch_mlp import *
from KnowledgeSynapticNetwork.utils import *
from EXP1.Case_visualize import draw_active_neuron_heatmaps

# Seed Python's global `random` RNG as soon as the module is loaded.
# NOTE(review): when run as a script this is redundant — the main block reseeds
# with `--seed` right after argument parsing — and when the module is merely
# imported it reseeds the importer's global RNG as a side effect.
random.seed(42)

if __name__ == "__main__":

    # ---- Command-line configuration -------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--local-rank", help="local rank for multigpu processing", type=int, default=0
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default='gpt2',
        # Other checkpoints handled below, e.g.:
        #   '/netcache/huggingface/llama-7b'
        #   '/netcache/huggingface/Meta-Llama-3-8B-Instruct'
    )
    parser.add_argument('--data_path', type=str, default='Datasets/lama_formatted.json')
    parser.add_argument('--neurons_results_dir_to_save', type=str, default='ResDebug/0416')
    parser.add_argument('--other_neurons_operation', type=str, default='union')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=10)
    parser.add_argument("--steps", type=int, default=10)
    parser.add_argument("--adaptive_threshold_neurons", type=float, default=0.2)
    parser.add_argument("--adaptive_threshold_synapses", type=float, default=0.5)
    args = parser.parse_args()

    results_dir = Path(args.neurons_results_dir_to_save)
    os.makedirs(results_dir, exist_ok=True)

    # ---- Data -----------------------------------------------------------------------
    random.seed(args.seed)
    dataset = read_lama_json(lama_json=args.data_path)
    # Fact ids in dataset order; `indices` are positions into this list.
    KEYS = list(dataset.keys())
    indices = list(range(len(dataset)))

    # ---- Device / sharding ----------------------------------------------------------
    gpus_number = torch.cuda.device_count()
    # torch.cuda.set_device raises when CUDA is unavailable, so only pin a device
    # when there is one to pin.
    if torch.cuda.is_available():
        torch.cuda.set_device(args.local_rank)

    model_name = args.model_name_or_path
    if 'gpt2' in model_name:
        neurons_res_dir = f'{results_dir}/neurons_GPT2'
        # Round-robin shard of the facts for this rank. max(..., 1) keeps the slice
        # step non-zero when no GPU is visible (a step of 0 raises ValueError).
        indices = indices[args.local_rank::max(gpus_number, 1)]
    elif 'Llama-3-8B' in model_name:
        # NOTE(review): unlike the gpt2 branch, indices are not sharded per rank
        # here, so every rank processes the full dataset — confirm this is intended.
        neurons_res_dir = f'{results_dir}/neurons_llama3-8b'
    elif 'llama-7b' in model_name:
        neurons_res_dir = f'{results_dir}/neurons_llama2-7b'
    else:
        raise NotImplementedError(f"Unsupported model_name_or_path: {model_name!r}")
    os.makedirs(neurons_res_dir, exist_ok=True)

    # ---- Model ----------------------------------------------------------------------
    n_s_model = NeuroSynapticEdit(model_name_or_path=model_name)


    # Query paraphrases to plot for each showcased fact. A fact matches a marker when
    # the marker substring occurs in its first query; the tuple gives the positions in
    # that fact's "sentences" list whose attribution scores are collected.
    HEATMAP_QUERY_SELECTION = (
        ('Suleiman', (1, 4)),
        ('Christoph', (1, 2)),
    )

    def process_data(_uuid):
        """Collect attribution scores for the showcased queries of one fact.

        For every query position selected by ``HEATMAP_QUERY_SELECTION``, compute
        ``n_s_model.get_scores`` against the fact's ``obj_label`` and append the
        scores and the query text to the enclosing ``probing_scores`` and
        ``queries_for_heatmap`` lists (defined in the main block before this is called).

        Args:
            _uuid: key of the fact in ``dataset``.
        """
        fact = dataset[_uuid]
        queries = fact["sentences"]
        answer = fact["obj_label"]
        if not queries:
            return

        # One entry per matching marker, so a fact matching several markers records a
        # shared position once per marker.
        selections = [
            positions for marker, positions in HEATMAP_QUERY_SELECTION if marker in queries[0]
        ]
        for idx, query in enumerate(queries):
            copies = sum(idx in positions for positions in selections)
            if not copies:
                # Attribution is the expensive step; skip queries that will not be plotted.
                continue
            attribution_scores = n_s_model.get_scores(
                prompt=query,
                ground_truth=answer,
                batch_size=args.batch_size,
                steps=args.steps,
            )
            for _ in range(copies):
                probing_scores.append(attribution_scores)
                queries_for_heatmap.append(query)


    # Accumulators filled by process_data(): entry k of probing_scores holds the
    # attribution scores for the query text at entry k of queries_for_heatmap.
    probing_scores = []
    queries_for_heatmap = []

    # Facts showcased in the heatmap figure.
    HEATMAP_UUIDS = {
        '4e1a8cfa-72ee-4ff6-b3a5-8fbba183dfa2',
        '51179628-b369-43b6-a74b-fbb5f7e9fb49',
    }
    for idx in tqdm(indices, position=args.local_rank):
        fact_id = KEYS[idx]
        if fact_id in HEATMAP_UUIDS:
            process_data(fact_id)

    if not probing_scores:
        # With per-rank sharding this rank may hold none of the showcased facts.
        print(f'[rank {args.local_rank}] no showcased facts in this shard; skipping heatmap.')
    else:
        # Output directory; overridable so the script runs outside the original machine.
        heatmap_res_dir = os.environ.get(
            'HEATMAP_RES_DIR',
            '/home/chenyuheng/chenyuheng/NIPS2024/EXP1/Res/heatmap_with_vmax',
        )
        os.makedirs(heatmap_res_dir, exist_ok=True)
        # Name the figure after the model family chosen above (e.g. "neurons_GPT2" -> "GPT2")
        # so runs with different models do not overwrite each other's figure.
        model_tag = Path(neurons_res_dir).name.split('_', 1)[-1]
        # NOTE(review): with several ranks, each rank writes the same file with only its
        # own shard's entries — confirm whether this should run on a single rank.
        draw_active_neuron_heatmaps(probing_scores=probing_scores, chunk_method='max', queries=queries_for_heatmap,
                                    save_filename=f'{heatmap_res_dir}/heatmap_{model_tag}.pdf',
                                    vmax1=1.6e-7,
                                    vmax2=3e-6
                                    )