from easyeditor import (
    IKEHyperParams, 
    MEMITHyperParams, 
    ROMEHyperParams, 
    LoRAHyperParams,
    GraceHyperParams,
    )
from easyeditor import BaseEditor
import argparse
import numpy as np
import json
import os
from easyeditor.evaluate.evaluate import compute_rewrite_or_rephrase_quality, compute_locality_quality, compute_portability_quality
from tqdm import tqdm

def _percent_mean(values):
    """Return the mean of ``values`` scaled to a percentage, or None if empty."""
    if not values:
        return None
    return sum(values) / len(values) * 100


def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN (no warning) if empty."""
    if not values:
        return float('nan')
    return float(np.mean(values))


def eval_instance(result_path):
    """Print and return aggregate metrics computed from a saved results file.

    The file at ``result_path`` must contain a JSON list of records. From each
    record's ``post`` entry this reads:

    * ``rewrite_acc``  -- list; only its first element is used.
    * ``rephrase_acc`` -- list of scores, averaged per record.
    * ``locality``     -- dict mapping a key to a list of scores; each list is
      averaged, then the per-key averages are averaged per record.
    * ``portability``  -- dict that may hold ``Instance_acc`` (list, averaged)
      and ``Rule_acc`` (list whose first two entries are reported separately
      as ``Rule_understanding_f`` and ``Rule_understanding_d``).

    Records whose section is missing or empty are skipped for that metric
    instead of raising ``ZeroDivisionError``/``IndexError``. A metric with no
    contributing record is reported as NaN.

    Args:
        result_path: Path of the JSON results file.

    Returns:
        A dict mapping each printed metric name to its value (percentages),
        or ``None`` when ``result_path`` does not exist.
    """
    if not os.path.exists(result_path):
        # Keep the original "do not crash" behaviour, but say why nothing is reported.
        print(f'Result file not found, skipping evaluation: {result_path}')
        return None

    with open(result_path, 'r', encoding='utf-8') as file:
        datas = json.load(file)

    edit_successes = [record['post']['rewrite_acc'][0] for record in datas]
    generality_scores = []
    locality_scores = []
    instance_scores = []
    rule_first_scores = []
    rule_second_scores = []

    for record in datas:
        post = record['post']

        generality = _percent_mean(post.get('rephrase_acc'))
        if generality is not None:
            generality_scores.append(generality)

        per_key = [_percent_mean(scores) for scores in (post.get('locality') or {}).values()]
        per_key = [score for score in per_key if score is not None]
        if per_key:
            locality_scores.append(np.mean(per_key))

        portability = post.get('portability') or {}
        instance = _percent_mean(portability.get('Instance_acc'))
        if instance is not None:
            instance_scores.append(instance)

        rule_acc = portability.get('Rule_acc') or []
        # Both rule scores are reported together, so only use records that have both.
        if len(rule_acc) >= 2:
            rule_first_scores.append(rule_acc[0] * 100)
            rule_second_scores.append(rule_acc[1] * 100)

    Edit_Succ = _percent_mean(edit_successes)
    if Edit_Succ is None:
        Edit_Succ = float('nan')
    Overall_generality = _mean_or_nan(generality_scores)
    Overall_locality = _mean_or_nan(locality_scores)
    Instance_portability = _mean_or_nan(instance_scores)
    Rule_understanding_f = _mean_or_nan(rule_first_scores)
    Rule_understanding_d = _mean_or_nan(rule_second_scores)
    Rule_understanding = (Rule_understanding_f + Rule_understanding_d) / 2

    print(f'Edit_Succ: {Edit_Succ:.2f}')
    print(f'Overall_generality: {Overall_generality:.2f}')
    print(f'Overall_locality: {Overall_locality:.2f}')
    print(f'Instance_portability: {Instance_portability:.2f}')
    print(f'Rule_understanding:  {Rule_understanding:.2f}')
    print(f'Rule_understanding_f:  {Rule_understanding_f:.2f}')
    print(f'Rule_understanding_d:  {Rule_understanding_d:.2f}')

    return {
        'Edit_Succ': Edit_Succ,
        'Overall_generality': Overall_generality,
        'Overall_locality': Overall_locality,
        'Instance_portability': Instance_portability,
        'Rule_understanding': Rule_understanding,
        'Rule_understanding_f': Rule_understanding_f,
        'Rule_understanding_d': Rule_understanding_d,
    }


        

def _collect_locality_outputs(model, editor, hparams, data):
    """Run every locality probe listed in ``data`` against ``model``.

    Returns the merged dict produced by ``compute_locality_quality`` for each
    key of ``data['locality']``, or an empty dict when there are no probes.
    """
    outputs = {}
    if 'locality' in data and any(data['locality']):
        for locality_key, probe in data['locality'].items():
            outputs.update(
                compute_locality_quality(model, editor.model_name, hparams, editor.tok, locality_key, probe['prompt'], probe['ground_truth'], device=hparams.device)
            )
    return outputs


def _score_locality(post_locality, pre_locality):
    """Compare post-edit locality outputs with the pre-edit ones.

    For every key, each post-edit answer is compared element-wise with the
    matching pre-edit answer and the fraction of equal elements is recorded.
    Keys ending in ``_output`` are renamed to end in ``_acc``; any other key
    keeps its name.
    """
    scored = {}
    for locality_key, post_outputs in post_locality.items():
        agreement = [
            float(np.mean(np.equal(ans, label)))
            for ans, label in zip(post_outputs, pre_locality[locality_key])
        ]
        if locality_key.endswith('_output'):
            new_key = locality_key.replace('_output', '_acc')
        else:
            # Previously this case reused a stale/undefined name and dropped the result.
            new_key = locality_key
        scored[new_key] = agreement
    return scored


def process_single_data(data, hparams, case_id=None):
    """Edit a fresh model with one dataset record and evaluate the result.

    Steps: build an editor from ``hparams``, record locality outputs of the
    unedited model, apply all prompts of the record via ``editor.edit`` with
    ``sequential_edit=True``, then evaluate rephrase, locality and portability
    on the edited model.

    Args:
        data: One dataset record. Reads ``prompt``, ``subject``,
            ``target_new`` and ``rephrase_prompt``; optionally ``locality``,
            ``Rule_Understanding`` and ``Instance``.
        hparams: Hyper-parameter object passed to ``BaseEditor.from_hparams``
            and to the evaluation helpers; ``hparams.device`` is read.
        case_id: Identifier stored under ``'case_id'``. When omitted, the
            module-level variable ``id`` (set by the script's main loop) is
            used if it exists, otherwise ``None``.

    Returns:
        A dict with ``pre``, ``case_id``, ``requested_rewrite`` and ``post``
        entries that merges the per-prompt edit metrics with the evaluation
        results.
    """
    if case_id is None:
        # Legacy callers rely on the script's loop variable `id`. Look it up in
        # module globals only, so the builtin `id` function (which is not
        # JSON-serializable) is never stored by accident.
        case_id = globals().get('id')

    # === Initialise the editor ===
    editor = BaseEditor.from_hparams(hparams)

    # 1. Edit
    prompts = data["prompt"]
    subjects = [data["subject"]] * len(prompts)
    target_news = data["target_new"]

    # Locality outputs of the unedited model, used as the reference below.
    pre_locality = _collect_locality_outputs(editor.model, editor, hparams, data)

    metrics, edited_model, _ = editor.edit(
        prompts=prompts,
        ground_truth=None,
        target_new=target_news,
        subject=subjects,
        sequential_edit=True  # True: start continuous editing
    )

    # 2. Evaluate
    rephrase_prompts = data["rephrase_prompt"]

    ret = compute_rewrite_or_rephrase_quality(edited_model, editor.model_name, hparams, editor.tok, rephrase_prompts, target_news, device=hparams.device, test_rephrase=True, eval_metric='token_em')
    ret['locality'] = _collect_locality_outputs(edited_model, editor, hparams, data)
    ret['portability'] = {}
    if 'Rule_Understanding' in data:
        ret['portability'].update(
            compute_portability_quality(edited_model, editor.model_name, hparams, editor.tok, "Rule", data["Rule_Understanding"]['prompt'], data["Rule_Understanding"]["target_new"], device=hparams.device)
        )
    # Guard on the key that is actually read (was 'Rule_Understanding').
    if 'Instance' in data:
        ret['portability'].update(
            compute_portability_quality(edited_model, editor.model_name, hparams, editor.tok, "Instance", data["Instance"]['prompt'], data["Instance"]["target_new"], device=hparams.device)
        )

    merged_metrics = {
        'pre': {
            'rewrite_acc': [sum(item['pre']['rewrite_acc'][0] for item in metrics) / len(metrics) if metrics else 0],
            'rewrite_gen_content': [item['pre']['rewrite_gen_content'][0] for item in metrics],
        },
        'case_id': case_id,
        'requested_rewrite': {
            'prompt': [item['requested_rewrite']['prompt'] for item in metrics],
            'target_new': [item['requested_rewrite']['target_new'] for item in metrics],
            'ground_truth': [item['requested_rewrite']['ground_truth'] for item in metrics],
            'subject': [item['requested_rewrite']['subject'] for item in metrics],
            'locality': {},
            'portability': {}
        },
        'post': {
            'rewrite_acc': [sum(item['post']['rewrite_acc'][0] for item in metrics) / len(metrics) if metrics else 0],
            'rewrite_gen_content': [item['post']['rewrite_gen_content'][0] for item in metrics],
            'locality': {},
            'portability': {}
        }
    }
    merged_metrics['post'].update(ret)
    # Replace the raw post-edit outputs with their agreement against the
    # pre-edit outputs; the pre-edit outputs themselves are not reported.
    merged_metrics['post']['locality'] = _score_locality(merged_metrics['post']['locality'], pre_locality)
    return merged_metrics

if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean such as 'true'/'false' or '1'/'0'."""
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_dir', required=True, type=str)
    parser.add_argument('--metrics_save_dir', default='./output', type=str)
    parser.add_argument('--train_data_path', type=str)
    # With type=str, "--rule_edit False" produced the truthy string 'False'.
    parser.add_argument('--rule_edit', default=True, type=_str2bool)
    parser.add_argument('--evaluation_type', default=None, type=str)
    parser.add_argument('--api_key', default=None, type=str)

    args = parser.parse_args()

    # Supported editing methods and their hyper-parameter classes
    # ('ICE' shares the IKE hyper-parameters).
    hparams_classes = {
        'IKE': IKEHyperParams,
        'ICE': IKEHyperParams,
        'MEMIT': MEMITHyperParams,
        'ROME': ROMEHyperParams,
        'LoRA': LoRAHyperParams,
        'GRACE': GraceHyperParams,
    }
    if args.editing_method not in hparams_classes:
        raise NotImplementedError(f'Unsupported editing method: {args.editing_method!r}')
    editing_hparams = hparams_classes[args.editing_method]

    hparams = editing_hparams.from_hparams(args.hparams_dir)
    hparams.rule_edit = args.rule_edit
    hparams.evaluation_type = args.evaluation_type
    hparams.api_key = args.api_key

    # === Load the dataset ===
    with open(args.data_dir, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    # Create the output directory up front so a missing directory cannot
    # discard the results after every edit has already been run.
    os.makedirs(args.metrics_save_dir, exist_ok=True)
    result_path = os.path.join(args.metrics_save_dir, f'{os.path.splitext(os.path.basename(args.data_dir))[0]}_{args.editing_method}_{hparams.model_name.split("/")[-1]}_results.json')

    all_metrics = []
    # NOTE: the loop variable must stay named `id`: process_single_data reads
    # this module-level name to fill in 'case_id'.
    for id, data in enumerate(tqdm(dataset)):
        single_metrics = process_single_data(data, hparams)
        print("Edit metrics:", single_metrics)
        all_metrics.append(single_metrics)

    # Close the file before eval_instance reads it back.
    with open(result_path, 'w', encoding='utf-8') as f:
        json.dump(all_metrics, f, indent=4)
    eval_instance(result_path)
