import os.path
import sys
import json
import argparse
import torch
from pathlib import Path
import os 
# os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
sys.path.append('..')
from easyeditor import (
    AlphaEditHyperParams,
    FTHyperParams,
    GraceHyperParams,
    MEMITHyperParams,
    ROMEHyperParams,
    MENDHyperParams,
    WISEHyperParams,
    BaseEditor,
    PMETHyperParams,
    KNHyperParams,
    summary_metrics,
)


# Remote mirror that the dataset files are downloaded from when absent locally.
REMOTE_ROOT_URL = "https://memit.baulab.info/data/dsets"

# File name of each dataset on the mirror, keyed by the short dataset name.
_REMOTE_DATASET_FILES = {
    "MCF": "multi_counterfact.json",
    "CF": "counterfact.json",
    "ZsRE": "zsre_mend_eval.json",
    "ZsRE_train": "zsre_mend_train.json",
}

# Full download URL for each short dataset name.
URL_DICT = {
    short_name: f"{REMOTE_ROOT_URL}/{file_name}"
    for short_name, file_name in _REMOTE_DATASET_FILES.items()
}

# Example invocation (sequential AlphaEdit on 100 ZsRE records):
# python run_AlphaEdit_llama3.py \
#     --editing_method=AlphaEdit \
#     --hparams_dir=../hparams/AlphaEdit/llama3-8b.yaml \
#     --data_dir=../data/ZsRE \
#     --ds_size=100 \
#     --data_type=ZsRE \
#     --sequential_edit

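# Hard-coded alternative to the command-line arguments, kept for reference.
# NOTE: re-enabling it alone has no effect, because the __main__ block below
# reassigns `args` from parser.parse_args().
# def set_args():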
#     class Args:
#         editing_method = 'AlphaEdit'
#         hparams_dir = '../hparams/AlphaEdit/llama-7b.yaml' 
#         data_dir = '../data/ZsRE'
#         data_type = 'ZsRE'
#         output_dir = './outputs'
#         ds_size = 500
#         sequential_edit = True
#     return Args()
 
# args = set_args()
 
# Default source of the ZsRE `loc_prompts` (a training split whose `loc` /
# `loc_ans` pairs are concatenated); override with --loc_data_path.
_DEFAULT_ZSRE_LOC_DATA = './code/KnowledgeEdit/AlphaEdit/data/ZsRE/zsre_mend_train.json'


def _resolve_hparams_class(editing_method):
    """Return the hyper-parameter class registered for *editing_method*.

    Raises:
        NotImplementedError: if *editing_method* is not supported by this script.
    """
    hparams_by_method = {
        'FT': FTHyperParams,
        'KN': KNHyperParams,
        'MEMIT': MEMITHyperParams,
        'PMET': PMETHyperParams,
        'ROME': ROMEHyperParams,
        'MEND': MENDHyperParams,
        'GRACE': GraceHyperParams,
        'WISE': WISEHyperParams,
        'AlphaEdit': AlphaEditHyperParams,
    }
    try:
        return hparams_by_method[editing_method]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported editing method {editing_method!r}; "
            f"expected one of {sorted(hparams_by_method)}"
        ) from None


def _load_json(path):
    """Read and return the JSON document stored at *path* (UTF-8)."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _ensure_downloaded(path, url):
    """Return *path*, first downloading *url* to it if the file is missing."""
    if not path.exists():
        print(f"{path} does not exist. Downloading from {url}")
        # The data directory may not exist yet on a fresh checkout.
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.hub.download_url_to_file(url, str(path))
    return path


def _build_zsre_inputs(data_dir, ds_size):
    """Build the ``editor.edit`` keyword arguments for the ZsRE dataset.

    Records whose ``alt`` (new target) is missing or blank are dropped before
    the first *ds_size* records are taken.
    """
    zsre_loc = data_dir / 'zsre_mend_train_10000.json'
    if not zsre_loc.exists():
        # The mirror only hosts the eval split (URL_DICT['ZsRE']). Store it
        # under its own name instead of labelling it as the local 10k train
        # subset, so later runs are not misled about which split they use.
        fallback = data_dir / 'zsre_mend_eval.json'
        print(f"{zsre_loc} does not exist. Falling back to {fallback}")
        zsre_loc = _ensure_downloaded(fallback, URL_DICT['ZsRE'])

    raw = [item for item in _load_json(zsre_loc) if (item.get('alt') or '').strip()]
    edit_data = raw[:ds_size]
    return {
        'prompts': [d['src'] for d in edit_data],
        'subject': [d['subject'] for d in edit_data],
        'rephrase_prompts': [d['rephrase'] for d in edit_data],
        'target_new': [d['alt'] for d in edit_data],
        'locality_inputs': {
            'neighborhood': {
                'prompt': [d['loc'] for d in edit_data],
                'ground_truth': [d['loc_ans'] for d in edit_data],
            },
        },
    }


def _build_zsre_loc_prompts(loc_data_path, ds_size):
    """Return ``loc + ' ' + loc_ans`` for the first *ds_size* records of *loc_data_path*."""
    loc_data = _load_json(loc_data_path)[:ds_size]
    return [d['loc'] + ' ' + d['loc_ans'] for d in loc_data]


def _build_counterfact_inputs(json_path, url, ds_size):
    """Build the ``editor.edit`` keyword arguments for CounterFact-style data.

    Shared by CF and MCF, whose records have the same layout. Every
    neighborhood prompt of a record is paired with that record's original
    (``target_true``) answer as its locality ground truth.
    """
    edit_data = _load_json(_ensure_downloaded(json_path, url))[:ds_size]
    rewrites = [d['requested_rewrite'] for d in edit_data]
    return {
        'prompts': [r['prompt'].format(r['subject']) for r in rewrites],
        'subject': [r['subject'] for r in rewrites],
        'rephrase_prompts': [d['paraphrase_prompts'][0] for d in edit_data],
        'target_new': [r['target_new']['str'] for r in rewrites],
        'locality_inputs': {
            'neighborhood': {
                'prompt': [d['neighborhood_prompts'] for d in edit_data],
                'ground_truth': [
                    [d['requested_rewrite']['target_true']['str']] * len(d['neighborhood_prompts'])
                    for d in edit_data
                ],
            },
        },
    }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--editing_method', required=True, type=str)
    parser.add_argument('--hparams_dir', required=True, type=str)
    parser.add_argument('--data_dir', required=True, type=str)
    parser.add_argument('--data_type', required=True, type=str,
                        choices=['ZsRE', 'MCF', 'CF'])
    parser.add_argument('--output_dir', default='./outputs', type=str)
    parser.add_argument('--ds_size', default=3, type=int)
    parser.add_argument('--sequential_edit', action="store_true")
    # Only used for --data_type=ZsRE (source of `loc_prompts`).
    parser.add_argument('--loc_data_path', default=_DEFAULT_ZSRE_LOC_DATA, type=str)

    args = parser.parse_args()

    editing_hparams = _resolve_hparams_class(args.editing_method)

    data_dir = Path(args.data_dir)
    if args.data_type == 'ZsRE':
        edit_inputs = _build_zsre_inputs(data_dir, args.ds_size)
        # Only ZsRE has a source for `loc_prompts`; for CF/MCF the keyword is
        # omitted from the editor.edit call altogether.
        edit_inputs['loc_prompts'] = _build_zsre_loc_prompts(args.loc_data_path, args.ds_size)
    elif args.data_type == 'CF':
        edit_inputs = _build_counterfact_inputs(
            data_dir / 'counterfact.json', URL_DICT['CF'], args.ds_size)
    else:  # 'MCF' -- argparse `choices` rules out any other value
        edit_inputs = _build_counterfact_inputs(
            data_dir / 'multi_counterfact.json', URL_DICT['MCF'], args.ds_size)

    hparams = editing_hparams.from_hparams(args.hparams_dir)
    model_tag = hparams.model_name.split('/')[-1]

    # Cached pre-edit results live under --output_dir alongside the final
    # results (with the default './outputs' the path is unchanged).
    args.pre_file = os.path.join(
        args.output_dir, 'prefile', args.editing_method,
        f"{model_tag}_{args.data_type}_N={args.ds_size}"
        f"_Sequential={args.sequential_edit}_pre_edit.json",
    )
    os.makedirs(os.path.dirname(args.pre_file), exist_ok=True)
    print("pre_file path:", args.pre_file)
    if os.path.exists(args.pre_file):
        pre_edit = _load_json(args.pre_file)
        # A cache whose length differs from the loaded prompts cannot be
        # paired with them one-to-one.
        if len(pre_edit) != len(edit_inputs['prompts']):
            raise ValueError(
                f"{args.pre_file} holds {len(pre_edit)} pre-edit entries but "
                f"{len(edit_inputs['prompts'])} prompts were loaded"
            )
    else:
        pre_edit = None

    # NOTE(review): the results filename omits data_type, so runs on different
    # datasets with the same method/model/N overwrite each other. Confirm that
    # no downstream tooling depends on this exact name before changing it.
    method_dir = os.path.join(args.output_dir, args.editing_method)
    os.makedirs(method_dir, exist_ok=True)
    output_file = os.path.join(
        method_dir,
        f"{model_tag}_{args.editing_method}_N={args.ds_size}"
        f"_Sequential={args.sequential_edit}_base.json",
    )
    print("See results at: ", output_file)

    editor = BaseEditor.from_hparams(hparams)
    metrics, _, _ = editor.edit(
        **edit_inputs,
        keep_original_weight=True,
        sequential_edit=args.sequential_edit,
        test_generation=True,
        pre_file=args.pre_file,
        pre_edit=pre_edit,
        args=args,
    )

    if metrics:
        # Summarize before opening the file so a failure here does not leave
        # an empty results file behind.
        output_data = {
            "detailed_metrics": metrics,
            "summary_metrics": summary_metrics(metrics),
        }
        with open(output_file, 'w') as f:
            json.dump(output_data, f, indent=4)
        print("Summary metrics saved to:", output_file)
