import pytorch_lightning as pl
import src.constants as cst


def callback_save_model(config, run_name):
    """Build a ModelCheckpoint callback that keeps only the best model of a run.

    The monitored metric is ``config.EARLY_STOPPING_METRIC``. The optimisation
    direction is derived from it, consistently with ``early_stopping``:
    ``'trainingloss'`` is minimised and every other metric is maximised.
    (Previously ``mode='max'`` was hard-coded, so monitoring the training loss
    kept the highest-loss checkpoint.)

    Args:
        config: run configuration exposing ``EARLY_STOPPING_METRIC`` and
            ``WANDB_SWEEP_NAME``.
        run_name: identifier of the current run; embedded in the checkpoint
            filename.

    Returns:
        A configured ``pl.callbacks.ModelCheckpoint`` instance.
    """
    monitor_var = config.EARLY_STOPPING_METRIC
    # A loss must go down; any other supported metric (e.g. an F1 score) must go up.
    mode = 'min' if monitor_var == 'trainingloss' else 'max'
    # Lightning fills "{epoch}" and "{<metric>:.5f}" at save time, so these
    # braces must reach ModelCheckpoint literally rather than being formatted here.
    metric_template = "{" + monitor_var + ":.5f}"
    filename = config.WANDB_SWEEP_NAME + "-run=" + run_name + "-{epoch}-" + metric_template
    check_point_callback = pl.callbacks.ModelCheckpoint(
        monitor=monitor_var,
        verbose=True,
        save_top_k=1,
        mode=mode,
        dirpath=cst.DIR_SAVED_MODEL + config.WANDB_SWEEP_NAME,
        filename=filename,
    )
    return check_point_callback


def early_stopping(config):
    """Build an EarlyStopping callback that stops training once the model stops improving.

    ``config.EARLY_STOPPING_METRIC`` selects the monitored metric. It also
    determines the improvement threshold (``min_delta``) and the direction
    (``mode``). ``config.CHOSEN_DATASET`` determines the patience, in checks
    without improvement.

    Args:
        config: run configuration exposing ``EARLY_STOPPING_METRIC`` and
            ``CHOSEN_DATASET``.

    Returns:
        A configured ``pl.callbacks.EarlyStopping`` instance.

    Raises:
        ValueError: if ``EARLY_STOPPING_METRIC`` is not a supported metric.
        SystemExit: if ``CHOSEN_DATASET`` is not CHF, FI or BTC.
    """
    monitor_var = config.EARLY_STOPPING_METRIC
    # A loss must decrease; the F1 score must increase.
    if monitor_var == 'trainingloss':
        delta = 1
        mode = 'min'
    elif monitor_var == 'validation-epoch-last_FI_f1':
        delta = 0.005
        mode = 'max'
    else:
        # Fail fast. Previously this only printed a message, and the function
        # then crashed with an UnboundLocalError on `delta` further down.
        raise ValueError(f"no valid metric chosen: {monitor_var!r}")

    if config.CHOSEN_DATASET == cst.DatasetFamily.CHF:
        p = 2
    elif config.CHOSEN_DATASET == cst.DatasetFamily.FI:
        p = 7
    elif config.CHOSEN_DATASET == cst.DatasetFamily.BTC:
        p = 4
    else:
        print("did not specify dataset correctly. patience not chosen")
        # Explicit SystemExit instead of the site-provided `exit()`. It also uses
        # a non-zero status, because this is a configuration error.
        raise SystemExit(1)

    return pl.callbacks.EarlyStopping(
        monitor=monitor_var,
        min_delta=delta,
        patience=p,
        verbose=True,
        mode=mode,
    )

def new_progress_bar():
    """Create the Rich-based progress bar callback.

    The bar is updated every 10 batches (Lightning measures ``refresh_rate``
    in batches).
    """
    refresh_every_n_batches = 10
    bar = pl.callbacks.RichProgressBar(refresh_rate=refresh_every_n_batches)
    return bar