import yaml

import sys
sys.path.append('.')

from lib.trainers.canny_aug_trainer import CannyAugTrainer
from lib.utils.misc import update_nested_dict, unflatten_dict, Dict


def main(config_file, **kwargs):
    """Load a YAML training config, apply CLI overrides, and run training.

    Args:
        config_file: Path to the base YAML config file.
        **kwargs: Override values supplied on the command line. They are
            converted with ``unflatten_dict`` and merged over the YAML
            config with ``update_nested_dict``. (Presumably flat keys such
            as ``a.b`` become nested ``{'a': {'b': ...}}`` -- confirm
            against ``unflatten_dict``.)
    """
    # Read explicitly as UTF-8 so parsing does not depend on the platform's
    # default locale encoding.
    with open(config_file, 'r', encoding='utf-8') as f:
        conf = yaml.safe_load(f)

    # `yaml.safe_load` returns None for an empty document; merge the
    # overrides into an empty mapping rather than passing None along.
    if conf is None:
        conf = {}

    overrides = unflatten_dict(kwargs)

    final_config = update_nested_dict(conf, overrides)
    print(f'Overriding CLI flags to default yaml config: {overrides}')

    args = Dict(final_config)

    trainer = CannyAugTrainer(args=args)

    # Setup steps are invoked in this fixed order before training starts.
    trainer.prepare_dataset()
    trainer.prepare_models()
    trainer.prepare_trainable_parameters()
    trainer.prepare_optimizer()
    trainer.prepare_for_training()
    trainer.prepare_trackers()

    trainer.train()


if __name__ == '__main__':
    # Expose `main` as a command-line interface via Python Fire: the config
    # path is the positional argument, and extra flags are forwarded to
    # `main` as keyword overrides.
    from fire import Fire

    Fire(main)
