from torch.utils.data import DataLoader
from src.datasets import load_dataset
from src.models import load_model
from src.gate import train_gate
from src.editor import Editor
from omegaconf import DictConfig, OmegaConf
from pathlib import Path
import numpy as np
import os
import json
import hydra
import logging

logger = logging.getLogger(__name__)

@hydra.main(version_base=None, config_path='conf', config_name='config')
def do_training(cfg: DictConfig) -> None:
    """Train the gate when the configured task is ``'gate'``.

    For any other value of ``cfg.task`` this entry point does nothing.

    Args:
        cfg: Hydra-composed configuration; ``cfg.task`` selects the task and
            ``cfg.experiment.gate`` is handed to ``train_gate`` unchanged.
    """
    if cfg.task == 'gate':
        logger.info('do training')
        train_gate(cfg.experiment.gate)

def _edit_and_evaluate(editor, metrics, edit_loader, rephrase_loader, loc_loader):
    """Apply one round of edits and append its scores to ``metrics`` in place.

    Calls ``editor.edit`` on ``edit_loader``, then ``editor.test`` on each of
    the three loaders, extending the matching list in ``metrics`` with whatever
    each ``editor.test`` call returns (presumably a list of per-sample
    scores -- confirm against ``Editor.test``).

    Args:
        editor: Object exposing ``edit(loader)`` and
            ``test(loader, is_local=...)``.
        metrics: Dict with list values under the keys ``'reliability'``,
            ``'generality'`` and ``'locality'``; mutated in place.
        edit_loader: Loader over the edit samples; used both to apply the
            edits and to measure reliability.
        rephrase_loader: Loader over the rephrased samples (generality).
        loc_loader: Loader over the locality samples; tested with
            ``is_local=True``.
    """
    editor.edit(edit_loader)
    metrics['reliability'] += editor.test(edit_loader)
    metrics['generality'] += editor.test(rephrase_loader)
    metrics['locality'] += editor.test(loc_loader, is_local=True)


@hydra.main(version_base=None, config_path='conf', config_name='config')
def do_edit(cfg: DictConfig) -> None:
    """Run the editing experiment when the configured task is ``'edit'``.

    Writes the resolved configuration to ``params.yaml`` and the collected
    metric lists to ``MLP_summary.json`` inside
    ``results/run_MLP_<pid>``, then logs the rounded mean of every metric.
    For any other value of ``cfg.task`` this entry point does nothing.

    Args:
        cfg: Hydra-composed configuration. ``cfg.experiment.dataset``,
            ``cfg.experiment.model`` and ``cfg.experiment.editor`` are passed
            to ``load_dataset``, ``load_model`` and ``Editor`` respectively;
            ``cfg.experiment.editor.seq`` selects batched (truthy) versus
            one-sample-at-a-time (falsy) editing.
    """
    if cfg.task != 'edit':
        return
    logger.info('do edit')
    metrics = {
        'reliability': [],
        'generality': [],
        'locality': [],
    }
    # prepare output; the process id keeps separate runs from sharing a directory
    run_dir = Path(f'results/run_MLP_{os.getpid()}')
    run_dir.mkdir(parents=True, exist_ok=True)
    (run_dir / 'params.yaml').write_text(OmegaConf.to_yaml(cfg))
    # do editing
    edits, rephrases, locs = load_dataset(cfg.experiment.dataset)
    model, tokenizer, sim = load_model(cfg.experiment.model)
    editor = Editor(model, tokenizer, sim, cfg.experiment.editor)
    # either sequential or one-by-one
    if cfg.experiment.editor.seq:
        # all edits are applied first (in batches of 5), then everything is tested
        _edit_and_evaluate(
            editor,
            metrics,
            DataLoader(edits, batch_size=5),
            DataLoader(rephrases, batch_size=5),
            DataLoader(locs, batch_size=5),
        )
    else:
        # each sample is edited and tested before the next edit is applied
        # NOTE(review): zip stops at the shortest of the three collections, so
        # surplus samples are silently skipped -- confirm they have equal length.
        for edit, rephrase, loc in zip(edits, rephrases, locs):
            _edit_and_evaluate(
                editor,
                metrics,
                DataLoader([edit], batch_size=1),
                DataLoader([rephrase], batch_size=1),
                DataLoader([loc], batch_size=1),
            )
    # write result
    (run_dir / 'MLP_summary.json').write_text(json.dumps(metrics, indent=2))
    for x in metrics:
        scores = metrics[x]
        # np.mean of an empty list warns before returning nan; skip the warning
        mean = np.mean(scores) if scores else float('nan')
        # use the module logger (not the root logger) like the rest of the file
        logger.info(f'{x}: {round(mean, 2)}')

if __name__ == '__main__':
    # Both entry points are always invoked, in this order; each one checks
    # cfg.task itself and returns immediately when the task is not its own.
    for entry_point in (do_training, do_edit):
        entry_point()
