import argparse
from functools import lru_cache
from pathlib import Path

from tqdm import tqdm

from EXP3.get_new_kns import select_kns
from KnowledgeSynapticNetwork.NeuroSynapticEditing import NeuroSynapticEdit
from KnowledgeSynapticNetwork.patch_mlp import *
from KnowledgeSynapticNetwork.utils import *
from EXP1.Case_visualize import draw_active_neuron_heatmaps
random.seed(42)  # import-time seed; re-seeded from --seed once args are parsed


if __name__ == "__main__":
    # ---- Command-line configuration -------------------------------------
    parser = argparse.ArgumentParser()
    # argparse stores this flag as `args.local_rank` (dashes become underscores).
    parser.add_argument(
        "--local-rank", help="local rank for multigpu processing", type=int, default=0
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default='gpt2',
        # default='/netcache/huggingface/llama-7b'
        # default='/netcache/huggingface/Meta-Llama-3-8B-Instruct'
    )
    parser.add_argument('--data_path',
                        type=str,
                        default='Datasets/lama_formatted.json',
                        )
    parser.add_argument('--neurons_results_dir_to_save', type=str, default='ResDebug/0430')
    parser.add_argument('--other_neurons_operation', type=str, default='union')
    parser.add_argument('--ig_method', type=str, default='ig', choices=['ig', 'sig', 'amig'])
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument("--batch_size", type=int, default=10,)
    parser.add_argument("--steps",type=int,default=10,)
    parser.add_argument("--adaptive_threshold_neurons",type=float,default=0.2,)
    parser.add_argument("--adaptive_threshold_synapses",type=float,default=0.5,)
    parser.add_argument("--cas_threshold",type=float,default=0.3,)
    parser.add_argument("--beta1",type=float,default=0.7,)
    parser.add_argument("--beta2",type=float,default=0.3,)
    args = parser.parse_args()
    results_dir = Path(args.neurons_results_dir_to_save)
    os.makedirs(results_dir, exist_ok=True)

    # ---- Data, device and work split --------------------------------------
    random.seed(args.seed)
    dataset = read_lama_json(lama_json=args.data_path)
    gpus_number = torch.cuda.device_count()
    indices = list(range(len(dataset)))
    KEYS = list(dataset.keys())
    torch.cuda.set_device(args.local_rank)  # requires a CUDA device
    model_name = args.model_name_or_path

    # Pick the per-model output directory from a substring of the model name.
    if 'gpt2' in model_name:
        neurons_res_dir = f'{results_dir}/neurons_GPT2'
        # Round-robin shard: rank r handles facts r, r + G, r + 2G, ... (G = #GPUs).
        indices = indices[args.local_rank::gpus_number]
    elif 'Llama-3-8B' in model_name:
        # NOTE(review): the Llama branches do not shard `indices`, so every
        # process handles every fact — presumably meant for a single-process
        # run; confirm before launching one process per GPU.
        neurons_res_dir = f'{results_dir}/neurons_llama3-8b'
    elif 'llama-7b' in model_name:
        neurons_res_dir = f'{results_dir}/neurons_llama2-7b'
    else:
        raise NotImplementedError(
            f"Unsupported model {model_name!r}: expected 'gpt2', 'Llama-3-8B' "
            f"or 'llama-7b' in --model_name_or_path"
        )
    os.makedirs(neurons_res_dir, exist_ok=True)

    n_s_model = NeuroSynapticEdit(model_name_or_path=model_name)

    # @lru_cache(maxsize=None)  # belongs to the disabled get_neurons below, not to process_data
    # def get_neurons(_uuid):
    #     queries = dataset[_uuid]["sentences"]
    #     answer = dataset[_uuid]["obj_label"]
    #     # relation_name = dataset[_uuid]["relation_name"]
    #     neurons_3d = []
    #     for q in queries:
    #         n_2d = n_s_model.get_one_query_neurons(
    #             prompt=q,
    #             ground_truth=answer,
    #             batch_size=args.batch_size,steps=args.steps,adaptive_threshold_neurons=args.adaptive_threshold_neurons
    #         )
    #         neurons_3d.append(n_2d)
    #     return neurons_3d, dataset[_uuid]

    def _aggregate_neurons(neuron_lists, operation):
        """Combine several neuron lists by set union or set intersection.

        Args:
            neuron_lists: one list of neurons per query; each neuron is a
                sequence that is converted to a tuple so it can be hashed.
            operation: ``"union"`` or ``"intersection"``.

        Returns:
            The combined neurons as a list of lists (order unspecified).

        Raises:
            ValueError: if ``operation`` is not one of the two supported values.
        """
        neuron_sets = [{tuple(neuron) for neuron in neurons} for neurons in neuron_lists]
        if operation == "intersection":
            # set().intersection(...) is always empty; intersect the
            # per-query sets with each other instead.
            combined = set.intersection(*neuron_sets) if neuron_sets else set()
        elif operation == "union":
            combined = set().union(*neuron_sets)
        else:
            raise ValueError("Unsupported operation: choose 'intersection' or 'union'")
        return [list(neuron) for neuron in combined]

    def process_data(_uuid, _unrelated_uuid=None, other_neurons_operation=None):
        """Probe one dataset fact and append its results to this rank's JSONL file.

        For each paraphrased query of the fact, this collects the model's
        scores, the neurons returned by ``get_one_query_neurons`` and the most
        attended token/position. It then selects neurons with ``select_kns``
        over all queries' scores. It suppresses and enhances those neurons on
        the first query and writes one ``{_uuid: {...}}`` JSON line to
        ``{neurons_res_dir}/{args.local_rank}.jsonl``.

        Args:
            _uuid: key into ``dataset``.
            _unrelated_uuid: currently unused (only referenced by disabled code).
            other_neurons_operation: ``"union"`` or ``"intersection"`` for
                aggregating other queries' neurons; ``None`` means use
                ``args.other_neurons_operation``.

        Returns:
            None. Facts with fewer than two queries are skipped and nothing is
            written.

        Raises:
            ValueError: if the resolved aggregation operation is unsupported.
        """
        queries = dataset[_uuid]["sentences"]
        if len(queries) <= 1:
            return
        if other_neurons_operation is None:
            # The driver calls process_data(uuid) without an operation; fall
            # back to the CLI flag instead of raising ValueError for every fact.
            other_neurons_operation = args.other_neurons_operation
        answer = dataset[_uuid]["obj_label"]
        # query_unrelated = dataset[_unrelated_uuid]["sentences"][0]
        # answer_unrelated = dataset[_unrelated_uuid]["obj_label"]
        query_results = []  # one result dict per query
        other_neurons = []  # only filled by the disabled block at the end of the loop
        activation_values_all = []  # per-query scores fed to select_kns
        for idx_for_other_query, q in enumerate(queries):
            # random_idx = (idx + random.randint(1, len(queries) - 1)) % len(queries)
            # query_related = queries[random_idx]
            activation_values = n_s_model.get_scores(prompt=q, ground_truth=answer, batch_size=args.batch_size, steps=args.steps,)
            activation_values_all.append(activation_values)
            query_dict = {"query": q}
            n_2d, average_score = n_s_model.get_one_query_neurons(
                prompt=q,
                ground_truth=answer,
                batch_size=args.batch_size, steps=args.steps, adaptive_threshold_neurons=args.adaptive_threshold_neurons
            )
            attentions_one_query = n_s_model.get_attention_weights_for_one_query(query=q)
            token_one_query, idx_one_query = n_s_model.get_most_attended_token(attentions_one_query, q.split())

            # s_2d = n_s_model.get_knowledge_synapses_one_query(
            #     attention_matrices=attentions_one_query, query=q,
            #     adaptive_threshold_synapses=args.adaptive_threshold_synapses
            # )

            query_dict["neurons"] = n_2d
            # query_dict['src_kn_score'] = average_score
            # query_dict["suppress_self_neuron"] =n_s_model.suppress_mlp(query=q, ground_truth=answer, neurons=n_2d)
            # query_dict["enhance_self_neuron"] =n_s_model.enhance_mlp(query=q, ground_truth=answer, neurons=n_2d)
            # query_dict["erase_knowledge"] = n_s_model.erase_knowledge(query=q, answer=answer,
            #                                                           neurons=n_2d, query_related=query_related,
            #                                                           query_unrelated=query_unrelated,
            #                                                           answer_unrelated=answer_unrelated)  # A,G,S, PPL

            # query_dict["synapses"] = s_2d
            # query_dict["suppress_attn"] = n_s_model.suppress_attn(query=q, answer=answer, synapses=s_2d)
            # query_dict["enhance_attn"] = n_s_model.enhance_attn(query=q, answer=answer, synapses=s_2d)
            query_dict["attended_token"] = token_one_query
            query_dict["attended_position"] = idx_one_query
            query_results.append(query_dict)

            # if idx_for_other_query >= 1:
            #     other_neurons.append(n_2d)

        q_for_other = queries[0]
        # Only consumed by the disabled 'suppress/enhance_other_neurons' entries below.
        aggregated_neurons = _aggregate_neurons(other_neurons, other_neurons_operation)
        selected_kns = select_kns(input_matrix=activation_values_all, beta1=args.beta1, beta2=args.beta2,
                                  threshold_factor=args.cas_threshold)

        # cr = n_s_model.calculate_consistency_ratio([q["neurons"] for q in query_results])
        # True when every query attends most strongly to the same position / token.
        position_static = len({q["attended_position"] for q in query_results}) == 1
        token_consistent = len({q["attended_token"] for q in query_results}) == 1

        result_this_uuid = {
            _uuid: {
                'query_results': query_results,
                'relation_name': dataset[_uuid]["relation_name"],
                # 'consistency_ratio': cr,
                'position_static': position_static,
                'token_consistent': token_consistent,
                # 'suppress_other_neurons': n_s_model.suppress_mlp(query=q_for_other, ground_truth=answer, neurons=aggregated_neurons),
                # 'enhance_other_neurons': n_s_model.enhance_mlp(query=q_for_other, ground_truth=answer, neurons=aggregated_neurons),
                'cas_suppress_neurons': n_s_model.suppress_mlp(query=q_for_other, ground_truth=answer, neurons=selected_kns),
                'cas_enhance_neurons': n_s_model.enhance_mlp(query=q_for_other, ground_truth=answer, neurons=selected_kns),
            }
        }

        # Append-only, one file per rank, so concurrent ranks never share a file.
        results_json_path = f"{neurons_res_dir}/{args.local_rank}.jsonl"
        with open(results_json_path, "a") as res_f:
            res_f.write(json.dumps(result_this_uuid) + '\n')


    # Resume support: gather every UUID already written by any rank so a
    # restarted run skips facts that were finished earlier.
    processed_uuids = set()
    for results_file in glob.glob(os.path.join(neurons_res_dir, '*.jsonl')):
        with open(results_file, 'r') as f:
            for line in f:
                if not line.strip():
                    continue
                try:
                    result = json.loads(line)
                except json.JSONDecodeError:
                    # A run killed mid-write can leave a truncated last line;
                    # skip it so that fact is simply recomputed.
                    continue
                # Each line is written as a single-key object {uuid: {...}}.
                if result:
                    processed_uuids.add(next(iter(result)))

    for idx in tqdm(indices, position=args.local_rank):
        uuid = KEYS[idx]
        # json.dumps stores dict keys as strings, so compare in string form.
        if str(uuid) in processed_uuids:
            continue
        # random_uuid_idx = (idx + random.randint(1, len(KEYS) - 1)) % len(KEYS)
        # uuid_unrelated = KEYS[random_uuid_idx]
        process_data(uuid, other_neurons_operation=args.other_neurons_operation)
