import argparse
import os
from copy import deepcopy
from pprint import pprint
from typing import Iterable, List, Optional

import yaml

from experiments import ExperimentHpars
from models import ModelHpars
from models import ModelHpars


def get_validation_configs(base_conf: dict, models: Optional[Iterable] = None) -> List[dict]:
    """Expand a base config into one config per model (and adversarial setting).

    For every model name, ``base_conf`` is deep-copied and
    ``config["model_hpars"]["name"]`` is set to that model. If
    ``base_conf["continual"]`` is falsy, two copies are produced per model:
    ``config["adversarial"]`` is set to ``False`` and then to ``True``. If it is
    truthy, one copy is produced per model and ``"adversarial"`` is left as
    in the base config.

    Args:
        base_conf: Parsed experiment config. It is read via the keys
            ``"continual"`` and ``"model_hpars"``, and it is never modified.
        models: Model names to sweep over. ``None`` (the default) means
            ``ModelHpars.OPTIONS``.

    Returns:
        A list of independent (deep-copied) config dicts. They are ordered by
        model, then by adversarial setting (``False`` before ``True``).

    Raises:
        KeyError: If at least one model is swept and ``base_conf`` lacks
            ``"continual"`` or ``"model_hpars"``.
    """
    if models is None:
        models = ModelHpars.OPTIONS

    new_configs = []
    for model in models:
        # Continual runs are not swept over the adversarial flag, so their
        # base value is kept. Non-continual runs try both settings.
        if base_conf["continual"]:
            variants = [{}]
        else:
            variants = [{"adversarial": False}, {"adversarial": True}]
        for overrides in variants:
            new_config = deepcopy(base_conf)
            new_config["model_hpars"]["name"] = model
            new_config.update(overrides)
            new_configs.append(new_config)
    return new_configs


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train and evaluate model.')
    parser.add_argument("-cf", "--config_file", metavar="config_file", type=str, required=True,
                        help="Path to the YAML config file containing the parameter settings.")
    parser.add_argument("-v", "--validate", action="store_true", required=False, default=False,
                        help="Whether to run hyperparameter validation.")
    args = parser.parse_args()
    config_filepath, validate = args.config_file, args.validate
    with open(config_filepath, "r", encoding="utf-8") as config_file:
        base_conf = yaml.safe_load(config_file)

    experiment_configs = get_validation_configs(base_conf) if validate else [base_conf, ]
    for experiment_config in experiment_configs:
        experiment_hpars = ExperimentHpars.from_dict(experiment_config)
        pprint(experiment_config)
        experiment = experiment_hpars.make()
        # Save config
        with open(experiment.output_folder + "/config.yml", "w", encoding="utf-8") as config_file:
            yaml.safe_dump(experiment_config, config_file)
        experiment.run()
