import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from mind_the_pad.data.mnist import preprocessor_mnist, letters_mnist_train_dl, letters_mnist_test_dl


def _trainer_hardware(device):
    """Translate a torch-style device spec into ``(accelerator, devices)``.

    Accepted forms are ``None``, a non-negative integer (treated as a GPU
    index), or anything whose ``str()`` looks like ``'cpu'``, ``'cuda'`` or
    ``'cuda:N'`` (which covers ``torch.device`` objects). Anything else falls
    back to automatic selection of a single device instead of raising.
    """
    if device is None:
        return 'auto', 1
    if isinstance(device, int) and not isinstance(device, bool):
        # A bare integer is read as a CUDA ordinal, as torch itself does.
        return ('gpu', [device]) if device >= 0 else ('auto', 1)

    kind, _, index = str(device).strip().lower().partition(':')
    if kind == 'cpu':
        return 'cpu', 1
    if kind in ('cuda', 'gpu'):
        # No explicit index -> GPU 0, which is what this function always
        # used before it honoured ``device``.
        return 'gpu', [int(index) if index.isdigit() else 0]
    return 'auto', 1


def run_training(exprmn_path, model: pl.LightningModule, device, batch_size=32, max_epochs=5000):
    """Fit ``model`` on the letters-MNIST loaders with early stopping and checkpoints.

    Args:
        exprmn_path: Experiment directory; used both as the Trainer's
            ``default_root_dir`` and as the checkpoint directory.
        model: The LightningModule to train. Early stopping monitors the
            ``'val/epoch_loss'`` metric, so the model is expected to log a
            value under that key (NOTE(review): confirm against the model).
        device: Where to train: ``None``, a GPU index, or a torch-style
            device / device string such as ``'cpu'`` or ``'cuda:1'``.
            Previously this argument was ignored and GPU 0 was always used.
        batch_size: Batch size handed to both data-loader factories.
        max_epochs: Upper bound on the number of epochs; early stopping is
            expected to end training well before the default is reached.
    """
    preprocessing_image = preprocessor_mnist()
    train_dl = letters_mnist_train_dl(batch_size, preprocessing_image)
    test_dl = letters_mnist_test_dl(batch_size, preprocessing_image)

    early_stopping = EarlyStopping(monitor='val/epoch_loss')
    # Weights-only checkpoints every 5 epochs, plus an always-current "last".
    checkpoint_saver = ModelCheckpoint(
        dirpath=exprmn_path,
        filename='model_{epoch}',
        save_last=True,
        every_n_epochs=5,
        save_weights_only=True,
    )

    accelerator, devices = _trainer_hardware(device)
    trainer = pl.Trainer(
        default_root_dir=exprmn_path,
        callbacks=[early_stopping, checkpoint_saver],
        max_epochs=max_epochs,
        accelerator=accelerator,
        devices=devices,
    )

    # NOTE(review): the loader named "test" is passed in the validation
    # position, so early stopping is driven by that split.
    trainer.fit(model, train_dl, test_dl)


