from easyeditor import (
    IKEHyperParams, 
    MEMITHyperParams, 
    ROMEHyperParams, 
    LoRAHyperParams,
    GraceHyperParams,
    )
from easyeditor import BaseEditor
import argparse
import numpy as np
import json
import os

def eval_hierarchical(result_path):
    """Print (and return) edit success and hierarchical portability for a results file.

    Reads a JSON list of per-edit metric records. Each record is expected to
    contain ``record['post']['rewrite_acc']`` (a list whose first element is
    used as that edit's success score) and may contain
    ``record['post']['portability']['Hierarchical_acc']`` (a list of
    per-probe accuracies that are averaged per record).

    Args:
        result_path: Path to the metrics JSON file.

    Returns:
        ``(edit_succ, hierarchical_portability)`` as percentages. A metric with
        no data is reported as ``nan`` instead of crashing. Returns ``None``
        if ``result_path`` does not exist.
    """
    if not os.path.exists(result_path):
        print(f'Result file not found: {result_path}')
        return None

    with open(result_path, 'r') as file:
        records = json.load(file)

    # Edit success: mean of each record's first rewrite accuracy, in percent.
    edit_succ_scores = [record['post']['rewrite_acc'][0] for record in records]
    if edit_succ_scores:
        edit_succ = sum(edit_succ_scores) / len(edit_succ_scores) * 100
    else:
        edit_succ = float('nan')  # empty results file: nothing to average
    print(f'Edit_Succ: {edit_succ:.2f}')

    # Hierarchical portability: average within each record first, then across
    # records, so every edit has equal weight regardless of its probe count.
    hierarchical_scores = []
    for record in records:
        accs = record['post'].get('portability', {}).get('Hierarchical_acc')
        if accs:  # skip records with missing or empty hierarchical probes
            hierarchical_scores.append(sum(accs) / len(accs) * 100)
    if hierarchical_scores:
        hierarchical_portability = float(np.mean(hierarchical_scores))
    else:
        hierarchical_portability = float('nan')
    print(f'Hierarchical_portability:  {hierarchical_portability:.2f}')

    return edit_succ, hierarchical_portability

def str2bool(value):
    """Convert a command-line flag value to ``bool``.

    With ``type=str`` any non-empty string, including ``"False"``, is truthy,
    so the flag could never be switched off from the command line. This
    accepts the usual spellings of true/false and passes real bools through
    (e.g. the argparse default).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')


def build_edit_inputs(dataset):
    """Turn dataset records into keyword arguments for ``BaseEditor.edit``.

    Each record must provide ``prompt``, ``subject``, ``target_new`` and a
    ``hierarchical`` dict with ``prompt`` and ``target_new``. ``ground_truth``
    is optional and defaults to ``""``.

    Returns:
        dict with ``prompts``, ``subject``, ``ground_truth``, ``target_new``
        and ``portability_inputs`` (hierarchical probes under ``'Hierarchical'``).
    """
    prompts, subjects, ground_truths, target_news = [], [], [], []
    hierarchical_prompts, hierarchical_truths = [], []
    for record in dataset:
        prompts.append(record["prompt"])
        subjects.append(record["subject"])
        ground_truths.append(record.get("ground_truth", ""))
        target_news.append(record["target_new"])
        hierarchical_prompts.append(record["hierarchical"]["prompt"])
        hierarchical_truths.append(record["hierarchical"]["target_new"])

    return {
        'prompts': prompts,
        'subject': subjects,
        'ground_truth': ground_truths,
        'target_new': target_news,
        # NOTE(review): presumably reported by the editor as "Hierarchical_acc",
        # the key eval_hierarchical reads back — confirm if renaming.
        'portability_inputs': {
            'Hierarchical': {
                'prompt': hierarchical_prompts,
                'ground_truth': hierarchical_truths,
            }
        },
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_dir', required=True, type=str)
    parser.add_argument('--metrics_save_dir', default='./output', type=str)
    parser.add_argument('--train_data_path', type=str)  # accepted but not used below
    parser.add_argument('--rule_edit', default=True, type=str2bool)
    parser.add_argument('--evaluation_type', default=None, type=str)
    parser.add_argument('--api_key', default=None, type=str)

    args = parser.parse_args()

    # Hyper-parameter class for each supported --editing_method
    # (ICE reuses the IKE hyper-parameters).
    hparams_classes = {
        'IKE': IKEHyperParams,
        'ICE': IKEHyperParams,
        'MEMIT': MEMITHyperParams,
        'ROME': ROMEHyperParams,
        'LoRA': LoRAHyperParams,
        'GRACE': GraceHyperParams,
    }
    if args.editing_method not in hparams_classes:
        raise NotImplementedError(f'Unsupported editing method: {args.editing_method}')
    editing_hparams = hparams_classes[args.editing_method]

    hparams = editing_hparams.from_hparams(args.hparams_dir)
    hparams.rule_edit = args.rule_edit
    hparams.evaluation_type = args.evaluation_type
    hparams.api_key = args.api_key

    # === 1. Load the dataset (a JSON list of records) ===
    with open(args.data_dir, "r") as f:
        dataset = json.load(f)

    # === 2. Initialise the editor ===
    editor = BaseEditor.from_hparams(hparams)

    # === 3. Run all edits in a single, non-sequential call ===
    edit_inputs = build_edit_inputs(dataset)
    metrics, _, _ = editor.edit(
        train_ds=None,
        sequential_edit=False,
        keep_original_weight=True,
        **edit_inputs,
    )

    # === 4. Save the metrics, then evaluate from the saved file ===
    os.makedirs(args.metrics_save_dir, exist_ok=True)
    result_path = os.path.join(args.metrics_save_dir, f'{os.path.splitext(os.path.basename(args.data_dir))[0]}_{args.editing_method}_{hparams.model_name.split("/")[-1]}_results.json')
    # Close (and flush) the file before eval_hierarchical re-opens it.
    with open(result_path, 'w') as f:
        json.dump(metrics, f, indent=4)
    eval_hierarchical(result_path)

