from spaghettini import quick_register

from src.mains.task_getters import get_system_and_trainer


@quick_register
def train(cfg, cfg_dir, tmp_dir):
    """Build the PyTorch Lightning system and trainer from ``cfg``, then fit.

    The function is registered via ``spaghettini.quick_register``.

    Args:
        cfg: Experiment configuration, forwarded unchanged to
            ``get_system_and_trainer``.
        cfg_dir: Forwarded to ``get_system_and_trainer`` as its ``cfg_path``
            argument.
        tmp_dir: Forwarded to ``get_system_and_trainer`` as ``tmp_dir``.

    Returns:
        None. Training happens as a side effect of ``trainer.fit``.
    """
    # get_system_and_trainer also reports whether a checkpoint was found.
    # That flag is not needed here; presumably any resumption is already wired
    # into the returned trainer/system (NOTE(review): confirm in
    # get_system_and_trainer), so it is explicitly discarded.
    pl_system, trainer, _ = get_system_and_trainer(
        cfg=cfg, cfg_path=cfg_dir, tmp_dir=tmp_dir
    )

    # Train, or continue training if the trainer was set up to resume.
    trainer.fit(pl_system)

