from easyeditor import (
    IKEHyperParams, 
    MEMITHyperParams, 
    ROMEHyperParams, 
    LoRAHyperParams,
    GraceHyperParams,
    )
from easyeditor import BaseEditor
import argparse
import numpy as np
import json
import os

def _safe_mean(values):
    """Return the arithmetic mean of *values* as a float, or NaN if it is empty.

    Replaces ``sum(v) / len(v)`` (ZeroDivisionError on empty input) and a bare
    ``np.mean([])`` (RuntimeWarning) so an empty results file or metric list
    still produces a printable summary.
    """
    if len(values) == 0:
        return float('nan')
    return float(np.mean(values))


def _portability_score(posts, key):
    """Average, over records that report *key*, the per-record mean of that
    portability metric, expressed as a percentage (NaN if no record has it)."""
    scores = [
        _safe_mean(post['portability'][key]) * 100
        for post in posts
        if post.get('portability', {}).get(key)
    ]
    return _safe_mean(scores)


def eval(result_path):
    """Print (and return) aggregate editing metrics from a saved results file.

    Note: the name shadows the builtin ``eval``; it is kept for compatibility
    with existing callers.

    ``result_path`` must hold a JSON list of records, each with a ``'post'``
    dict containing ``'rewrite_acc'`` and ``'rephrase_acc'`` lists and
    ``'locality'`` / ``'portability'`` dicts mapping a metric name to a list of
    numeric scores. The scores are presumably fractions in [0, 1], which is why
    they are scaled by 100. Confirm this against the EasyEdit version in use.

    Metrics (all percentages, NaN when there is nothing to average):
      * Edit_Succ: mean of each record's first ``rewrite_acc`` value.
      * Overall_generality: mean of each record's first ``rephrase_acc`` value.
      * Overall_locality: per record, the mean over its non-empty locality
        metrics, then averaged over records that have any.
      * Instance_portability / Rule_understanding: per-record mean of the
        ``'Instance_acc'`` / ``'Rule_acc'`` portability lists, averaged over
        the records that report them.

    Args:
        result_path: Path to the JSON results file.

    Returns:
        A dict mapping each metric name above to its value, or ``None`` (with
        nothing printed) if ``result_path`` does not exist.
    """
    if not os.path.exists(result_path):
        return None

    with open(result_path, 'r', encoding='utf-8') as file:
        datas = json.load(file)
    posts = [record['post'] for record in datas]

    edit_succ = _safe_mean([post['rewrite_acc'][0] for post in posts]) * 100
    print(f'Edit_Succ: {edit_succ:.2f}')

    overall_generality = _safe_mean([post['rephrase_acc'][0] for post in posts]) * 100
    print(f'Overall_generality: {overall_generality:.2f}')

    locality_scores = []
    for post in posts:
        # Average each locality metric within the record first, so every
        # record carries equal weight in the overall locality score.
        case_scores = [
            _safe_mean(values) * 100
            for values in post.get('locality', {}).values()
            if values
        ]
        if case_scores:
            locality_scores.append(_safe_mean(case_scores))
    overall_locality = _safe_mean(locality_scores)
    print(f'Overall_locality: {overall_locality:.2f}')

    instance_portability = _portability_score(posts, 'Instance_acc')
    rule_understanding = _portability_score(posts, 'Rule_acc')
    print(f'Instance_portability: {instance_portability:.2f}')
    print(f'Rule_understanding:  {rule_understanding:.2f}')

    return {
        'Edit_Succ': edit_succ,
        'Overall_generality': overall_generality,
        'Overall_locality': overall_locality,
        'Instance_portability': instance_portability,
        'Rule_understanding': rule_understanding,
    }
       

def _str2bool(value):
    """Convert a command-line flag value such as 'True'/'false'/'1'/'no' to a bool.

    With ``type=str``, ``--rule_edit False`` would stay the non-empty (and
    therefore truthy) string ``'False'``; this parses it into a real bool,
    matching the type of the flag's default (``True``).

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognized boolean word.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_dir', required=True, type=str)
    parser.add_argument('--metrics_save_dir', default='./output', type=str)
    parser.add_argument('--train_data_path', type=str)  # accepted but not used by this script
    parser.add_argument('--rule_edit', default=True, type=_str2bool)
    parser.add_argument('--evaluation_type', default=None, type=str)
    parser.add_argument('--api_key', default=None, type=str)

    args = parser.parse_args()

    # Supported editing methods -> hyper-parameter class (ICE reuses IKE's).
    hparams_classes = {
        'IKE': IKEHyperParams,
        'ICE': IKEHyperParams,
        'MEMIT': MEMITHyperParams,
        'ROME': ROMEHyperParams,
        'LoRA': LoRAHyperParams,
        'GRACE': GraceHyperParams,
    }
    if args.editing_method not in hparams_classes:
        raise NotImplementedError(f'Unsupported editing method: {args.editing_method}')
    editing_hparams = hparams_classes[args.editing_method]

    hparams = editing_hparams.from_hparams(args.hparams_dir)
    hparams.rule_edit = args.rule_edit
    hparams.evaluation_type = args.evaluation_type
    hparams.api_key = args.api_key

    # === 1. Load the dataset ===
    with open(args.data_dir, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    # === 2. Initialize the editor ===
    editor = BaseEditor.from_hparams(hparams)

    # === 3. Collect per-record fields ===
    # The dataset is a list of dicts; the keys read below ('prompt', 'subject',
    # 'target_new', 'rephrase_prompt', 'locality', 'Instance',
    # 'Rule_Understanding') are required, while 'ground_truth' is optional.
    prompts = [data["prompt"] for data in dataset]
    subjects = [data["subject"] for data in dataset]
    ground_truths = [data.get("ground_truth", "") for data in dataset]
    target_news = [data["target_new"] for data in dataset]
    rephrase_prompts = [data["rephrase_prompt"] for data in dataset]

    # Locality probes are checked against their own ground truth.
    locality_inputs = {
        name: {
            'prompt': [data["locality"][name]["prompt"] for data in dataset],
            'ground_truth': [data["locality"][name]["ground_truth"] for data in dataset],
        }
        for name in ('neighborhood', 'distracting')
    }
    # Portability probes are checked against each probe's 'target_new' answer.
    portability_inputs = {
        'Instance': {
            'prompt': [data["Instance"]["prompt"] for data in dataset],
            'ground_truth': [data["Instance"]["target_new"] for data in dataset],
        },
        'Rule': {
            'prompt': [data["Rule_Understanding"]["prompt"] for data in dataset],
            'ground_truth': [data["Rule_Understanding"]["target_new"] for data in dataset],
        },
    }

    # Call edit only once, for the whole dataset (sequential_edit=False).
    metrics, edited_model, _ = editor.edit(
        prompts=prompts,
        subject=subjects,
        train_ds=None,
        ground_truth=ground_truths,
        target_new=target_news,
        rephrase_prompts=rephrase_prompts,
        locality_inputs=locality_inputs,
        portability_inputs=portability_inputs,
        # test_generation=True,
        sequential_edit=False,
        keep_original_weight=True
    )

    os.makedirs(args.metrics_save_dir, exist_ok=True)
    dataset_name = os.path.splitext(os.path.basename(args.data_dir))[0]
    model_name = hparams.model_name.split("/")[-1]
    result_path = os.path.join(
        args.metrics_save_dir,
        f'{dataset_name}_{args.editing_method}_{model_name}_results.json',
    )
    # Use a context manager so the file is closed (and flushed) before
    # eval() re-opens it for reading.
    with open(result_path, 'w') as f:
        json.dump(metrics, f, indent=4)
    eval(result_path)

