from baseline.easyeditor import (
    MENDTrainingHparams,
    SERACTrainingHparams,
    ZsreDataset,
    CounterFactDataset,
    EditTrainer,
)
from omegaconf import DictConfig 
import hydra
import logging

# Default root-logger verbosity for this script; the level is given by name,
# which ``basicConfig`` resolves to ``logging.INFO``.
# NOTE(review): Hydra may apply its own job-logging config when ``run`` starts;
# confirm this INFO default still takes effect under Hydra.
logging.basicConfig(level="INFO")

@hydra.main(version_base=None)
def run(cfg: DictConfig) -> None:
    """Train a model editor (MEND or SERAC) on a preprocessed dataset split.

    Everything is driven by the Hydra config ``cfg``, which is read for:

    * ``cfg.editing_method`` -- ``'MEND'`` or ``'SERAC'``. Any other value
      logs ``'No trainer'`` and returns without training.
    * ``cfg.model.name`` -- selects the hparams file
      ``baseline/hparams/TRAINING/<method>/<model>.yaml``.
    * ``cfg.dataset.name`` -- ``'qa'`` (``ZsreDataset``) or ``'cf'``
      (``CounterFactDataset``).
    * ``cfg.dataset.n_edits`` -- picks ``train_<n>.json`` / ``val_<n>.json``
      under ``processed/<dataset>/``.

    Raises:
        ValueError: if ``cfg.dataset.name`` is not a supported dataset.
    """
    hparams_classes = {
        'MEND': MENDTrainingHparams,
        'SERAC': SERACTrainingHparams,
    }
    dataset_classes = {
        'qa': ZsreDataset,
        'cf': CounterFactDataset,
    }

    editing_method = cfg.editing_method
    hparams_cls = hparams_classes.get(editing_method)
    if hparams_cls is None:
        logging.info('No trainer')
        return

    training_hparams = hparams_cls.from_hparams(
        f'baseline/hparams/TRAINING/{editing_method}/{cfg.model.name}.yaml'
    )

    dataset_name = cfg.dataset.name
    dataset_cls = dataset_classes.get(dataset_name)
    if dataset_cls is None:
        # Explicit raise instead of ``assert``: asserts are stripped under
        # ``python -O``, which would defer the failure to an obscure
        # "NoneType is not callable" error further down.
        raise ValueError(
            f'Unsupported dataset {dataset_name!r}; '
            f'expected one of {sorted(dataset_classes)}'
        )

    def load_split(split):
        # Build the dataset for one split ('train' / 'val') of this config.
        return dataset_cls(
            f'processed/{dataset_name}/{split}_{cfg.dataset.n_edits}.json',
            config=training_hparams,
        )

    trainer = EditTrainer(
        config=training_hparams,
        train_set=load_split('train'),
        val_set=load_split('val'),
    )
    trainer.run()

if __name__ == "__main__":
    # No argument is passed: the ``hydra.main`` decorator builds ``cfg``
    # from the command line and injects it into ``run``.
    run()