import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint

import gcip.utils.io as playbook_io
from gcip.utils.my_early_stopping import MyEarlyStopping
from pytorch_lightning.callbacks import EarlyStopping

def _resolve_gpu_devices(cfg):
    """Return how many CUDA devices to use, or ``None`` when GPUs are not requested.

    GPUs are counted only when CUDA is available and ``cfg.device`` is
    ``'gpu'`` or ``'auto'``; each detected device is logged.
    """
    if not (torch.cuda.is_available() and cfg.device in ['gpu', 'auto']):
        return None
    devices = torch.cuda.device_count()

    playbook_io.print_info(f'Using {devices} GPUs')
    for i in range(devices):
        playbook_io.print_info(f'[{i}] {torch.cuda.get_device_name(i)}')

    # Defensive: is_available() should imply at least one device.
    return max(devices, 1)


def _resolve_accelerator(device, gpu_devices):
    """Return the Lightning accelerator name consistent with the device selection.

    ``gpu_devices`` is the result of ``_resolve_gpu_devices``. When it is not
    ``None``, GPUs were requested and detected, so the accelerator must be
    CUDA too. Otherwise an explicit ``"cuda"`` in ``device`` still selects
    CUDA (Lightning raises if CUDA turns out to be unavailable).
    """
    if gpu_devices is not None or "cuda" in device:
        return "cuda"
    return "cpu"


def load_trainer(cfg,
                 dirpath,
                 logger_dir=None,
                 include_logger=True,
                 model_checkpoint=True,
                 cfg_early=None,
                 preparator=None):
    """Build a ``pl.Trainer`` configured from ``cfg``.

    Args:
        cfg: Config object; reads ``cfg.device``, ``cfg.root_dir`` and the
            ``cfg.train.*`` fields forwarded to the Trainer below.
        dirpath: Directory where checkpoints are written.
        logger_dir: Directory for the CSV logs; defaults to ``dirpath``.
        include_logger: If true, attach a ``CSVLogger`` and also return it.
        model_checkpoint: If true, add a ``ModelCheckpoint`` callback that
            keeps the best checkpoint (by the monitored metric) plus the last one.
        cfg_early: Optional early-stopping config with ``activate``,
            ``min_delta``, ``patience`` and ``verbose`` attributes.
        preparator: Optional object whose ``monitor()`` returns the
            ``(metric_name, mode)`` pair used by checkpointing and early stopping.

    Returns:
        ``(trainer, logger)`` when ``include_logger`` is true, else ``trainer``.

    Raises:
        ValueError: If early stopping is activated but no metric to monitor
            is available (``preparator`` is ``None``).
    """
    if logger_dir is None:
        logger_dir = dirpath

    devices = _resolve_gpu_devices(cfg)
    # Derive the accelerator from the same decision that counted the GPUs, so
    # a detected multi-GPU setup is never paired with accelerator="cpu".
    accelerator_type = _resolve_accelerator(cfg.device, devices)

    if preparator is not None:
        monitor, mode = preparator.monitor()
    else:
        monitor = None
        mode = 'min'

    callbacks = []
    if model_checkpoint:
        checkpoint_callback = ModelCheckpoint(dirpath=dirpath,
                                              every_n_epochs=None,
                                              save_last=True,
                                              save_top_k=1,
                                              monitor=monitor,
                                              mode=mode,
                                              save_weights_only=True)
        callbacks.append(checkpoint_callback)

    if cfg_early is not None and cfg_early.activate:
        if monitor is None:
            # Without a metric Lightning would only fail at the first
            # validation run; fail fast with an actionable message instead.
            raise ValueError(
                'Early stopping is activated but no metric to monitor is '
                'available; pass a `preparator` whose monitor() returns one.')
        early_stop_callback = EarlyStopping(monitor=monitor,
                                            min_delta=cfg_early.min_delta,
                                            patience=cfg_early.patience,
                                            verbose=cfg_early.verbose,
                                            mode=mode,
                                            # Evaluate after validation, where the metric is logged.
                                            check_on_train_epoch_end=False)
        callbacks.append(early_stop_callback)

    from pytorch_lightning.loggers import CSVLogger
    if include_logger:
        logger = CSVLogger(save_dir=logger_dir,
                           name='logs')
    else:
        logger = None

    # NOTE: auto_select_gpus and auto_scale_batch_size are Lightning 1.x
    # Trainer arguments; they were removed in Lightning 2.0.
    trainer = pl.Trainer(
        default_root_dir=cfg.root_dir,
        callbacks=callbacks,
        logger=logger,
        deterministic=False,
        devices=devices,
        auto_select_gpus=True,
        accelerator=accelerator_type,
        auto_scale_batch_size=cfg.train.auto_scale_batch_size,
        max_epochs=cfg.train.max_epochs,
        profiler=cfg.train.profiler,
        enable_progress_bar=cfg.train.enable_progress_bar,
        max_time=cfg.train.max_time,
        # auto_lr_find=cfg.train.auto_lr_find,
        limit_train_batches=cfg.train.limit_train_batches,
        limit_val_batches=cfg.train.limit_val_batches,
        fast_dev_run=False,
        inference_mode=cfg.train.inference_mode,
    )
    if include_logger:
        return trainer, logger
    return trainer
