from baseline.easyeditor import BaseEditor
from baseline.easyeditor import (
    FTHyperParams, 
    IKEHyperParams, 
    KNHyperParams, 
    MEMITHyperParams, 
    ROMEHyperParams, 
    LoRAHyperParams,
    MENDHyperParams,
    MENDTrainingHparams,
    SERACHparams,
    SERACTrainingHparams,
    GraceHyperParams,
    WISEHyperParams,
    summary_metrics,
    ZsreDataset,
    EditTrainer,
)
from omegaconf import DictConfig, OmegaConf
from src.datasets import load_dataset
from pathlib import Path
import numpy as np
import hydra
import json
import os
import torch
import gc

@hydra.main(version_base=None)
def run(cfg: DictConfig):
    """Run one knowledge-editing experiment and write its metrics to disk.

    Reads ``cfg.editing_method``, ``cfg.dataset``, ``cfg.editor.seq`` and
    ``cfg.model.name``; loads the dataset, applies the selected editing
    method through ``BaseEditor.edit`` and stores the outputs under
    ``results/run_<editing_method>_<pid>/``:

    * ``params.yaml`` -- ``cfg`` dumped as YAML,
    * ``<editing_method>.json`` -- the metrics returned by the editor,
    * ``<editing_method>_summary.json`` -- ``summary_metrics`` of those
      metrics, written only when the metrics are non-empty.

    Args:
        cfg: Hydra/OmegaConf configuration for the run.

    Raises:
        ValueError: if ``cfg.editing_method`` is not a supported method.
            Raised before any directory is created or data is loaded.
    """
    # Hyper-parameter class for each supported editing method.
    hparams_by_method = {
        'FT': FTHyperParams,
        'GRACE': GraceHyperParams,
        'MEMIT': MEMITHyperParams,
        'WISE': WISEHyperParams,
        'MEND': MENDHyperParams,
        'SERAC': SERACHparams,
    }

    # Validate first so a typo in the config does not leave an orphan run
    # directory behind or pay for a dataset load.
    editing_method = cfg.editing_method
    if not isinstance(editing_method, str) or editing_method not in hparams_by_method:
        raise ValueError(
            f'Unsupported editing method {editing_method!r}; '
            f'expected one of {sorted(hparams_by_method)}'
        )
    editing_hparams = hparams_by_method[editing_method]

    # prepare output
    # NOTE(review): the directory name is keyed on the process id only, so
    # several runs executed inside one process would share it -- confirm
    # this is acceptable for the way the script is launched.
    run_dir = Path(f'results/run_{editing_method}_{os.getpid()}')
    run_dir.mkdir(parents=True, exist_ok=True)
    (run_dir / 'params.yaml').write_text(OmegaConf.to_yaml(cfg))

    edits, rephrases, locs, subject = load_dataset(cfg.dataset, return_subject=True)
    edits, rephrases, locs = np.array(edits), np.array(rephrases), np.array(locs)

    # Column 0 of each array is used as the prompt, column 1 as its answer.
    prompts = edits[:, 0].tolist()
    rephrase_prompts = rephrases[:, 0].tolist()
    locality_prompts = locs[:, 0].tolist()
    loc_prompts = [row[0] + row[1] for row in locs]

    # Drop the first character of every target / locality answer.
    # NOTE(review): presumably a leading space in the dataset -- confirm.
    target_new = [answer[1:] for answer in edits[:, 1]]
    locality_ans = [answer[1:] for answer in locs[:, 1]]

    locality_inputs = {
        'neighborhood': {
            'prompt': locality_prompts,
            'ground_truth': locality_ans,
        },
    }

    hparams = editing_hparams.from_hparams(
        f'baseline/hparams/{editing_method}/{cfg.model.name}.yaml'
    )
    editor = BaseEditor.from_hparams(hparams)

    # Only the metrics are used; the other two returned values are discarded.
    metrics, _, _ = editor.edit(
        prompts=prompts,
        rephrase_prompts=rephrase_prompts,
        target_new=target_new,
        loc_prompts=loc_prompts,
        subject=subject,
        locality_inputs=locality_inputs,
        sequential_edit=cfg.editor.seq,
        keep_original_weight=True,
    )

    (run_dir / f'{editing_method}.json').write_text(json.dumps(metrics, indent=2))
    # The summary is skipped when the editor returned no metrics.
    if metrics:
        (run_dir / f'{editing_method}_summary.json').write_text(
            json.dumps(summary_metrics(metrics), indent=2)
        )

# Script entry point: `run` is wrapped by `hydra.main`, so it is called
# without arguments here.
if __name__ == '__main__':
    run()