import os
import os.path
import sys
import json
import random
import datasets
datasets.config.HF_DATASETS_OFFLINE = True

from easyeditor import (
    FTHyperParams,
    IKEHyperParams,
    KNHyperParams,
    MEMITHyperParams,
    ROMEHyperParams,
    LoRAHyperParams,
    MENDHyperParams,
    PMETHyperParams,
    SERACHparams,
    FINEHyperParams,
    EMMETHyperParams,
)
from easyeditor import BaseEditor
from easyeditor.models.ike import encode_ike_facts
from sentence_transformers import SentenceTransformer
from easyeditor import KnowEditDataset

import argparse

# Sub-path (relative to --data_dir) of the test split loaded for each --datatype.
DATATYPE_SUBPATHS = {
    'counterfact': 'wiki_counterfact/test_cf.json',
    'zsre': 'ZsRE/ZsRE-test-all.json',
    'recent': 'wiki_recent/recent_test.json',
    'convsent': 'Convsent/blender_test.json',
    'trivia': 'trivia/trivia_qa_test.json',
    # NOTE(review): assumes the KnowEdit release layout for WikiBio; confirm
    # against the files under --data_dir.
    'wikibio': 'WikiBio/wikibio-test-all.json',
}

# Datatypes whose records are read for ground_truth, portability_* and locality_*.
FACT_DATATYPES = ('counterfact', 'recent', 'zsre')

# Datatypes that build_edit_inputs can prepare. 'convsent' and 'trivia' have a
# data path above but no input-preparation logic in this script.
PREPARED_DATATYPES = FACT_DATATYPES + ('wikibio',)

# Ground truth used for a record that has none.
EMPTY_GROUND_TRUTH = '<|endoftext|>'


def parse_args():
    """Parse the command line; exit with a usage error on invalid combinations."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_dir', required=True, type=str,
                        help='Root directory containing the per-datatype test files.')
    parser.add_argument('--ds_size', default=None, type=int)
    parser.add_argument('--metrics_save_dir', default='./output', type=str)
    parser.add_argument('--datatype', default=None, type=str)
    parser.add_argument('--train_data_path', type=str,
                        help='Training data for IKE (required when --editing_method is IKE).')
    parser.add_argument('--pre_file', default=None, type=str,
                        help='Pre-edit cache file. Defaults to '
                             '<metrics_save_dir>/<model>_<datatype>_pre_edit.json.')
    parser.add_argument('--random', action="store_true", help='Currently unused.')

    args = parser.parse_args()
    if args.editing_method == 'IKE' and args.train_data_path is None:
        parser.error('--train_data_path is required when --editing_method is IKE')
    return args


def first_leaf_answer(value):
    """Unwrap nested lists by repeatedly taking the first element.

    Returns None if an empty list is reached, so the caller can skip the pair
    instead of raising IndexError.
    """
    while isinstance(value, list):
        if not value:
            return None
        value = value[0]
    return value


def collect_prompt_answers(records):
    """Flatten per-record lists of ``{"prompt", "ground_truth"}`` dicts.

    Args:
        records: one entry per example. Each entry is either None or an
            iterable of dicts with "prompt" and "ground_truth" keys; the
            ground truth may be a string or a (nested) list whose first leaf
            is used.

    Returns:
        ``(prompts, answers)``, each with ``len(records)`` entries. A None
        record yields None in both lists; otherwise a list of prompts and the
        matching answers, skipping pairs whose answer is missing or blank.
    """
    all_prompts = []
    all_answers = []
    for item in records:
        if item is None:
            all_prompts.append(None)
            all_answers.append(None)
            continue
        item_prompts = []
        item_answers = []
        for pair in item:
            answer = first_leaf_answer(pair["ground_truth"])
            if answer is None or answer.strip() == "":
                continue
            item_prompts.append(pair["prompt"])
            item_answers.append(answer)
        all_prompts.append(item_prompts)
        all_answers.append(item_answers)
    return all_prompts, all_answers


def primary_ground_truth(value):
    """Return the placeholder for None, a string unchanged, else its first element."""
    if value is None:
        return EMPTY_GROUND_TRUTH
    if isinstance(value, str):
        return value
    return value[0]


def build_edit_inputs(datas, datatype):
    """Turn dataset records into the inputs passed to ``BaseEditor.edit``.

    Args:
        datas: iterable of record dicts (e.g. a ``KnowEditDataset``).
        datatype: one of ``PREPARED_DATATYPES``.

    Returns:
        ``(prompts, subjects, target_new, ground_truth, locality_inputs,
        portability_inputs)``. Each locality/portability value maps a category
        name to ``{'prompt': [...], 'ground_truth': [...]}`` with one entry per
        record. ``portability_inputs`` is None for 'wikibio'.

    Raises:
        NotImplementedError: if ``datatype`` has no preparation logic.
    """
    if datatype not in PREPARED_DATATYPES:
        raise NotImplementedError(f"No input preparation for datatype {datatype!r}")

    records = list(datas)
    prompts = [record['prompt'] for record in records]
    subjects = [record['subject'] for record in records]
    target_new = [record['target_new'] for record in records]

    def category_inputs(field):
        field_prompts, field_answers = collect_prompt_answers(
            [record[field] for record in records])
        return {'prompt': field_prompts, 'ground_truth': field_answers}

    if datatype == 'wikibio':
        # WikiBio records are not read for a ground truth here; use the same
        # placeholder as records without one.
        ground_truth = [EMPTY_GROUND_TRUTH] * len(records)
        locality_inputs = {'Relation_Specificity': category_inputs('locality_rs')}
        portability_inputs = None
    else:
        ground_truth = [primary_ground_truth(record['ground_truth']) for record in records]
        locality_inputs = {
            'Relation_Specificity': category_inputs('locality_rs'),
            'Forgetfulness': category_inputs('locality_f'),
        }
        portability_inputs = {
            'Subject_Aliasing': category_inputs('portability_s'),
            'reasoning': category_inputs('portability_r'),
            'Logical_Generalization': category_inputs('portability_l'),
        }
    return prompts, subjects, target_new, ground_truth, locality_inputs, portability_inputs


def main():
    """Load the selected dataset, run the edit, and write the metrics JSON."""
    args = parse_args()

    hparams_classes = {
        'FT': FTHyperParams,
        'IKE': IKEHyperParams,
        'KN': KNHyperParams,
        'MEMIT': MEMITHyperParams,
        'ROME': ROMEHyperParams,
        'MEND': MENDHyperParams,
        'LoRA': LoRAHyperParams,
        'FINE': FINEHyperParams,
        'PMET': PMETHyperParams,
        'EMMET': EMMETHyperParams,
    }
    if args.editing_method not in hparams_classes:
        raise NotImplementedError(f"Unsupported editing method {args.editing_method!r}")
    editing_hparams = hparams_classes[args.editing_method]

    # Fail before loading any data if the datatype cannot be prepared.
    if args.datatype not in DATATYPE_SUBPATHS or args.datatype not in PREPARED_DATATYPES:
        raise NotImplementedError(f"Unsupported datatype {args.datatype!r}")
    args.data_dir = os.path.join(args.data_dir, DATATYPE_SUBPATHS[args.datatype])

    datas = KnowEditDataset(args.data_dir, size=args.ds_size)
    (prompts, subjects, target_new, ground_truth,
     locality_inputs, portability_inputs) = build_edit_inputs(datas, args.datatype)

    hparams = editing_hparams.from_hparams(args.hparams_dir)

    # FT with a float norm_constraint is reported (and saved) under the name FT-L.
    if args.editing_method == 'FT' and isinstance(hparams.norm_constraint, float):
        args.editing_method = 'FT-L'

    print(args.editing_method)

    os.makedirs(args.metrics_save_dir, exist_ok=True)
    model_tag = hparams.model_name.split('/')[-1]
    if args.pre_file is None:
        args.pre_file = os.path.join(
            args.metrics_save_dir, f"{model_tag}_{args.datatype}_pre_edit.json")

    # Reuse an existing pre-edit cache (truncated to ds_size) when present.
    pre_edit = None
    if os.path.exists(args.pre_file):
        print(args.pre_file)
        with open(args.pre_file, 'r') as f:
            pre_edit = json.load(f)[:args.ds_size]
        if len(pre_edit) != len(prompts):
            raise ValueError(
                f"{args.pre_file} has {len(pre_edit)} entries but {len(prompts)} "
                f"prompts were loaded; delete or regenerate the pre-edit file")

    train_ds = None
    if args.editing_method == 'IKE':
        train_ds = KnowEditDataset(args.train_data_path)
        sentence_model = SentenceTransformer(hparams.sentence_model_name).to(f'cuda:{hparams.device}')
        encode_ike_facts(sentence_model, train_ds, hparams)

    editor = BaseEditor.from_hparams(hparams)
    metrics, edited_model, _ = editor.edit(
        prompts=prompts,
        target_new=target_new,
        subject=subjects,
        ground_truth=ground_truth,
        locality_inputs=locality_inputs,
        portability_inputs=portability_inputs,
        train_ds=train_ds,
        keep_original_weight=True,
        pre_file=args.pre_file,
        pre_edit=pre_edit,
        test_generation=True,
    )

    if args.editing_method == "FINE":
        result_name = (f'{args.editing_method}_neuron_{hparams.neuron_num}_{args.datatype}_'
                       f'{model_tag}_Layer_{hparams.layer}_results.json')
    else:
        result_name = f'{args.editing_method}_{args.datatype}_{model_tag}_results.json'
    with open(os.path.join(args.metrics_save_dir, result_name), 'w') as f:
        json.dump(metrics, f, indent=4)


if __name__ == "__main__":
    main()