import elements
from trainers.dreamer_trainer import DreamerTrainer
from trainers.on_policy_trainer import OnPolicyTrainer
from utils.tools import set_seed, get_task_name

import ruamel.yaml as yaml

import sys
from typing import List

def main(argv: List[str]):
    """Assemble the run configuration, prepare the log directory, and train.

    The configuration is layered in this order:

    1. The ``env`` and ``trainer`` flags (defaults ``"meltingpot"`` and
       ``"dreamer"``) are parsed from ``argv``; unrecognised arguments are
       kept aside.
    2. ``configs/trainer_configs/<trainer>.yaml`` is merged in.
    3. ``configs/env_configs/<env>.yaml`` is merged in.
    4. The arguments kept aside in step 1 are parsed against the merged
       configuration.

    The resulting configuration is saved as ``config.yaml`` inside a freshly
    created, timestamped log directory before training starts.

    Args:
        argv: Command-line arguments to parse.

    Raises:
        ValueError: If ``config.trainer`` is neither ``"dreamer"`` nor
            ``"on_policy"``.
    """
    banner_rows = (
        r"---   ____  __  __    ___        ____  __  ---",
        r"---  |  _ \|  \/  |  / \ \      / /  \/  | ---",
        r"---  | | | | |\/| | / _ \ \ /\ / /| |\/| | ---",
        r"---  | |_| | |  | |/ ___ \ V  V / | |  | | ---",
        r"---  |____/|_|  |_/_/   \_\_/\_/  |_|  |_| ---",
    )
    for row in banner_rows:
        elements.print(row)

    # Every config file and the log directory live next to this script.
    script_dir = elements.Path(__file__).parent

    def read_yaml(relative_path):
        """Parse one YAML file (relative to this script) with the safe loader."""
        text = (script_dir / relative_path).read()
        return yaml.YAML(typ='safe').load(text)

    # First pass: only `env` and `trainer` are known at this point, so every
    # other argument is handed back untouched for the second pass below.
    bootstrap_flags = elements.Flags(env="meltingpot", trainer="dreamer")
    config, leftover_args = bootstrap_flags.parse_known(argv)

    # Merge the trainer-specific file first, then the env-specific one.
    config = config.update(read_yaml(f'configs/trainer_configs/{config.trainer}.yaml'))
    config = config.update(read_yaml(f'configs/env_configs/{config.env}.yaml'))

    # Second pass: parse the remaining arguments against the merged config.
    config = elements.Flags(config).parse(leftover_args)

    # Create the run directory and store the final configuration in it.
    task_name = get_task_name(config)
    run_dir = (
        script_dir / 'logs' / config.env / task_name / config.name
        / elements.timestamp()
    )
    # NOTE(review): a Path object (not a str) is stored under `logdir` here;
    # presumably the config converts it -- confirm against elements.Config.
    config = config.update(logdir=run_dir)
    run_dir.mkdir()
    config.save(run_dir / 'config.yaml')

    set_seed(config)

    # Switch the global timer on or off according to the logging settings.
    elements.timer.global_timer.enabled = config.logging.timer

    # Pick the trainer class by name and run it.
    trainer_classes = {
        "dreamer": DreamerTrainer,
        "on_policy": OnPolicyTrainer,
    }
    trainer_cls = trainer_classes.get(config.trainer)
    if trainer_cls is None:
        raise ValueError(f"Trainer {config.trainer} not found")
    trainer = trainer_cls(config)
    trainer.train()

if __name__ == "__main__":
    # Run only when executed as a script; drop the program name (argv[0])
    # so that main() receives just the user-supplied arguments.
    main(argv=sys.argv[1:])