import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


def train(model, cfg, results_save_path, device_number=0):
    """Fit ``model`` with PyTorch Lightning, checkpointing at a fixed epoch interval.

    Training runs for ``cfg.epochs * cfg.iter_proxi`` epochs in total. A
    checkpoint whose filename is the global step is written every
    ``cfg.epochs`` epochs, and no checkpoint is discarded (``save_top_k=-1``).

    Args:
        model: The LightningModule to train. ``fit`` is called without any
            dataloaders, so the model presumably defines its own — confirm.
        cfg: Config object read for ``epochs`` (checkpoint interval, in
            epochs) and ``iter_proxi`` (multiplier giving the total epoch
            budget; presumably the number of outer rounds — confirm).
        results_save_path: Used as the Trainer's ``default_root_dir``.
        device_number: Index of the single GPU to train on.

    Returns:
        The ``pl.Trainer`` after ``fit`` returns, so callers can inspect it.
    """
    import inspect

    # The checkpoint-interval keyword was renamed from ``every_n_val_epochs``
    # (PL 1.3) to ``every_n_epochs`` (PL 1.4), and the old name was removed in
    # PL 1.6. Use whichever name the installed version accepts.
    checkpoint_params = inspect.signature(ModelCheckpoint).parameters
    if 'every_n_epochs' in checkpoint_params:
        interval_kwarg = 'every_n_epochs'
    else:
        interval_kwarg = 'every_n_val_epochs'
    checkpoint_callback = ModelCheckpoint(
        save_top_k=-1, filename='{step}', **{interval_kwarg: cfg.epochs})

    # ``gpus=`` was removed from Trainer in PL 2.0 in favour of
    # ``accelerator``/``devices``. Keep ``gpus`` wherever it still exists so
    # older installs get exactly the original arguments.
    if 'gpus' in inspect.signature(pl.Trainer).parameters:
        device_kwargs = {'gpus': [device_number]}
    else:
        device_kwargs = {'accelerator': 'gpu', 'devices': [device_number]}

    trainer = pl.Trainer(
        default_root_dir=results_save_path,
        max_epochs=cfg.epochs * cfg.iter_proxi,
        callbacks=[checkpoint_callback],
        **device_kwargs,
    )
    trainer.fit(model)
    return trainer
