import os.path
import json
import argparse
import torch

from easyeditor import (
    FTHyperParams,
    MENDHyperParams,
    ROMEHyperParams,
    R_ROMEHyperParams,
    MEMITHyperParams,
    GraceHyperParams,
    WISEHyperParams,
    AlphaEditHyperParams,
    IKEHyperParams,
    MELOHyperParams,
    LoRAHyperParams,
    RoseLoRAHyperParams,
    BaseEditor,
    summary_metrics,
    ZsreDataset,
    CounterFactDataset
)
from easyeditor.models.ike import encode_ike_facts

# Dataset formats understood by extract_edit_fields / the --datatype flag.
DATATYPES = ('counterfact', 'zsre', 'qaedit')

# Names of the per-edit columns returned by extract_edit_fields, in order.
EDIT_FIELD_NAMES = (
    'prompts',
    'subject',
    'rephrase_prompts',
    'target_new',
    'locality_prompts',
    'locality_ans',
)


def str2bool(value):
    """Interpret a command-line string flag as a boolean.

    Only the (case-insensitive) string "true" is treated as True; every
    other value, including the empty string, is False.
    """
    return str(value).strip().lower() == 'true'


def build_parser():
    """Build the argument parser for the editing script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_path', required=True, type=str)
    # Previously defaulted to the bool True, which matched no dataset branch
    # and crashed later with a NameError; require an explicit, valid choice.
    parser.add_argument('--datatype', required=True, type=str, choices=DATATYPES)
    parser.add_argument('--ds_size', default=10, type=int)
    parser.add_argument('--sequential_edit', default='False', type=str)
    parser.add_argument('--batch_edit', default='False', type=str)
    parser.add_argument('--evaluation_type', default='LLM-judge', type=str)
    parser.add_argument('--device', default=0, type=int)
    parser.add_argument('--down_eval', default='False', type=str)
    parser.add_argument('--save_model_dir', default="./edited_models/default", type=str)
    return parser


def get_hparams_class(editing_method):
    """Return the hyper-parameter class registered for ``editing_method``.

    Raises:
        NotImplementedError: if the method name is not supported.
    """
    hparams_classes = {
        'FT': FTHyperParams,
        'MEND': MENDHyperParams,
        'ROME': ROMEHyperParams,
        'R-ROME': R_ROMEHyperParams,
        'MEMIT': MEMITHyperParams,
        'GRACE': GraceHyperParams,
        'WISE': WISEHyperParams,
        'AlphaEdit': AlphaEditHyperParams,
        'IKE': IKEHyperParams,
        'MELO': MELOHyperParams,
        'LoRA': LoRAHyperParams,
        'RoseLoRA': RoseLoRAHyperParams,
    }
    try:
        return hparams_classes[editing_method]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported editing method {editing_method!r}; "
            f"expected one of: {', '.join(hparams_classes)}"
        ) from None


def load_edit_data(data_path, ds_size):
    """Load a JSON list of edit records and keep the first ``ds_size`` items."""
    with open(data_path, 'r', encoding='utf-8') as f:
        return json.load(f)[:ds_size]


def _record_fields(record, datatype):
    """Return one record's fields in EDIT_FIELD_NAMES order for ``datatype``."""
    if datatype == 'counterfact':
        return (
            record['prompt'],
            record['subject'],
            record['rephrase_prompt'],
            record['target_new'],
            record['locality_prompt'],
            record['locality_ground_truth'],
        )
    if datatype == 'zsre':
        return (
            record['src'],
            record['subject'],
            record['rephrase'],
            record['alt'],
            record['loc'],
            record['loc_ans'],
        )
    # qaedit: locality is stored as a list; only the first entry is used.
    locality = record['locality'][0]
    return (
        record['prompt'],
        record['subject'],
        record['rephrase'],
        record['target'],
        locality['loc'],
        locality['loc_ans'],
    )


def extract_edit_fields(edit_data, datatype):
    """Split a list of edit records into per-field lists.

    Args:
        edit_data: list of dict records in the ``datatype`` schema.
        datatype: one of DATATYPES.

    Returns:
        dict mapping each name in EDIT_FIELD_NAMES to a list with one entry
        per record (empty lists when ``edit_data`` is empty).

    Raises:
        ValueError: if ``datatype`` is not one of DATATYPES.
        KeyError: if a record is missing a field required by its schema.
    """
    if datatype not in DATATYPES:
        raise ValueError(
            f"Unsupported datatype {datatype!r}; expected one of: {', '.join(DATATYPES)}"
        )
    rows = [_record_fields(record, datatype) for record in edit_data]
    columns = [list(column) for column in zip(*rows)] if rows else [[] for _ in EDIT_FIELD_NAMES]
    return dict(zip(EDIT_FIELD_NAMES, columns))


def main(argv=None):
    """Run one editing experiment as configured by the command line ``argv``."""
    args = build_parser().parse_args(argv)

    sequential_edit = str2bool(args.sequential_edit)
    batch_edit = str2bool(args.batch_edit)
    down_eval = str2bool(args.down_eval)

    editing_hparams = get_hparams_class(args.editing_method)

    # load and process data
    edit_data = load_edit_data(args.data_path, args.ds_size)
    fields = extract_edit_fields(edit_data, args.datatype)

    locality_inputs = {
        'neighborhood': {
            'prompt': fields['locality_prompts'],
            'ground_truth': fields['locality_ans'],
        },
    }

    hparams = editing_hparams.from_hparams(args.hparams_dir)
    hparams.evaluation_type = args.evaluation_type
    hparams.data_path = args.data_path
    hparams.device = args.device

    editor = BaseEditor.from_hparams(hparams)

    # Arguments shared by both the batch and the (sequential) edit entry points.
    edit_kwargs = dict(
        prompts=fields['prompts'],
        rephrase_prompts=fields['rephrase_prompts'],
        subject=fields['subject'],
        target_new=fields['target_new'],
        locality_inputs=locality_inputs,
    )
    if batch_edit:
        metrics, edited_model, _, tokenizer = editor.batch_edit(**edit_kwargs)
    else:
        metrics, edited_model, _, tokenizer = editor.edit(
            **edit_kwargs,
            sequential_edit=sequential_edit,
        )

    # Save the edited model for downstream-task evaluation if requested.
    if down_eval:
        save_directory = args.save_model_dir
        edited_model.save_pretrained(save_directory)
        tokenizer.save_pretrained(save_directory)

    if metrics:
        summary_metrics(metrics)


if __name__ == "__main__":
    main()