# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/00_train.ipynb (unless otherwise specified).

__all__ = ['train', 'train_parallel', 'CounterfactualModel2OptimizersW']

# Cell

from .import_essentials import *
from .utils import *
from .training_module import *
from .net import *
from pytorch_lightning.callbacks import EarlyStopping

# pl_logger = logging.getLogger('lightning')

# Cell


def train(module,
          t_configs: dict,
          logger_name: str = 'debug',
          description: str = 'default',
          debug: bool = False,
          logger=None
         ):
    """Fit ``module`` with a freshly built ``pl.Trainer``.

    Args:
        module: Object handed to ``trainer.fit``; its ``hparams`` attribute
            is logged before training starts.
        t_configs: Extra keyword arguments expanded into ``pl.Trainer``.
        logger_name: ``name`` of the fallback ``TestTubeLogger``.
        description: ``description`` of the fallback ``TestTubeLogger``.
        debug: ``debug`` flag of the fallback ``TestTubeLogger``.
        logger: Logger passed to the trainer. When ``None``, a
            ``TestTubeLogger`` writing under ``../log/`` is created from
            ``logger_name``, ``description`` and ``debug``.

    Returns:
        A dict with the trainer under ``'trainer'`` and the module that was
        passed in under ``'module'``.
    """
    # Only build the default logger when the caller did not bring one.
    experiment_logger = logger if logger is not None else pl_loggers.TestTubeLogger(
        Path('../log/'),
        name=logger_name,
        description=description,
        debug=debug,
        log_graph=True,
    )

    # Checkpointing tracks 'val/val_loss' in 'min' mode with save_top_k=3.
    best_ckpt = ModelCheckpoint(monitor='val/val_loss', save_top_k=3, mode='min')

    trainer = pl.Trainer(
        logger=experiment_logger,
        checkpoint_callback=best_ckpt,
        **t_configs
    )

    pl_logger.info(f'hyper parameters: {module.hparams}')

    trainer.fit(module)
    return dict(trainer=trainer, module=module)


def train_parallel(module,
                   t_configs: dict,
                   logger_name: str,
                   description: str = 'default',
                   debug: bool = False,
                   rounds: int = 3,
                   logger=None):
    """Run ``train`` ``rounds`` times through ``Parallel``.

    Every round receives the same arguments; ``Parallel`` is configured with
    ``n_jobs=-1``, ``max_nbytes=None`` and ``verbose=True``.

    Args:
        module: Forwarded to ``train`` as ``module``.
        t_configs: Forwarded to ``train`` as ``t_configs``.
        logger_name: Forwarded to ``train`` as ``logger_name``.
        description: Forwarded to ``train`` as ``description``.
        debug: Forwarded to ``train`` as ``debug``.
        rounds: Number of ``train`` calls to schedule.
        logger: Forwarded to ``train`` as ``logger``.

    Returns:
        Whatever the ``Parallel`` call yields for the scheduled ``train``
        calls (presumably a list of ``train`` results, one per round --
        confirm against the joblib version in use).
    """
    return Parallel(n_jobs=-1, max_nbytes=None, verbose=True)(
        delayed(train)(
            # `train` names this parameter `module`; the former `cf_module`
            # keyword made every call fail with a TypeError.
            module=module,
            t_configs=t_configs,
            logger_name=logger_name,
            description=description,
            debug=debug,
            logger=logger
        )
        for _ in range(rounds)
    )

# Cell

class CounterfactualModel2OptimizersW(CounterfactualModel2Optimizers):
    """``CounterfactualModel2Optimizers`` variant configured with two AdamW optimizers."""

    def configure_optimizers(self):
        """Return a pair of AdamW optimizers over the trainable parameters.

        Both optimizers are built separately from every parameter of
        ``self`` whose ``requires_grad`` is truthy, using ``self.lr`` as the
        learning rate.
        """
        def build_adamw():
            # Collect the trainable parameters afresh for each optimizer.
            trainable = list(filter(lambda p: p.requires_grad, self.parameters()))
            return torch.optim.AdamW(trainable, lr=self.lr)

        first_opt = build_adamw()
        second_opt = build_adamw()
        # NOTE: a ReduceLROnPlateau scheduler per optimizer (patience=50,
        # monitoring 'val/val_loss') was sketched here but is disabled; only
        # the bare optimizers are returned.
        return (first_opt, second_opt)

if __name__ == "__main__" and not in_jupyter():
    def _load_config(name):
        """Read ``counterfactual/configs/<name>.json`` and return the parsed JSON."""
        # A context manager closes the handle right after parsing;
        # the previous `json.load(open(...))` left every file open.
        # JSON is read explicitly as UTF-8 rather than the platform default.
        with open(f"counterfactual/configs/{name}.json", encoding="utf-8") as config_file:
            return json.load(config_file)

    # Per-dataset model configurations.
    dummy_config = _load_config("dummy")
    adult_config = _load_config("adult")
    student_config = _load_config("student")
    home_config = _load_config("home")

    # Keyword arguments for the trainer (passed to `train` as `t_configs`).
    t_config = _load_config("trainer")

    # training for dummy data
#     train(
#         cf_module=BaselineModel(dummy_config),
#         t_configs=t_config,
#         logger=pl_loggers.TestTubeLogger(Path('log/'), name="dummy/baseline")
#     )

#     train(
#         cf_module=CounterfactualModel(dummy_config),
#         t_configs=t_config,
#         logger=pl_loggers.TestTubeLogger(Path('log/'), name="dummy/cf")
#     )

#     train(
#         cf_module=CounterfactualModel2Optimizers(dummy_config),
#         t_configs=t_config,
#         logger=pl_loggers.TestTubeLogger(Path('log/'), name="dummy/cf_2opt")
#     )

    # training for adult data
#     train(
#         cf_module=BaselineModel(adult_config),
#         t_configs=t_config,
#         logger=pl_loggers.TestTubeLogger(Path('log/'), name="adult/baseline")
#     )

    # train(
    #     module=CounterfactualModel(adult_config),
    #     t_configs=t_config,
    #     logger=pl_loggers.TestTubeLogger(Path('log/'), name="adult/cf")
    # )

    # for lr in [0.03, 0.02, 0.01]:
    #     adult_config['lr'] = lr
    #     train(
    #         module=CounterfactualModel2OptimizersW(adult_config),
    #         t_configs=t_config,
    #         logger=pl_loggers.TestTubeLogger(Path('log/'), name="adult/115/cf_2opt_adamw")
    #     )

    # train(
    #     module=CounterfactualModel2OptimizersScheduler(adult_config),
    #     t_configs=t_config,
    #     logger=pl_loggers.TestTubeLogger(Path('log/'), name="adult/115/cf_2opt_mse_scheduler")
    # )

    # for grad_clip in [0, 0.1, 0.2, 0.5, 0.7, 0.8]:
    #     t_config['gradient_clip_val'] = grad_clip
    #     train(
    #         module=CounterfactualModel2Optimizers(adult_config),
    #         t_configs=t_config,
    #         logger=pl_loggers.TestTubeLogger(Path('log/'), name="adult/115/cf_2opt_mse_clip")
    #     )

    # train for student dataset
    # for lr in [3e-3, 1e-3, 5e-4, 3e-4, 1e-4]:
    #     student_config['lr'] = lr
    #     train(
    #         module=CounterfactualModel2Optimizers(student_config),
    #         t_configs=t_config,
    #         logger=pl_loggers.TestTubeLogger(Path('log/'), name="student/116/cf_2opt")
    #     )

    # train baselines
    # train(
    #     module=BaselineModel(adult_config),
    #     t_configs=t_config,
    #     logger=pl_loggers.TestTubeLogger(Path('log/'), name="baseline/adult")
    # )

    # train(
    #     module=BaselineModel(student_config),
    #     t_configs=t_config,
    #     logger=pl_loggers.TestTubeLogger(Path('log/'), name="baseline/student")
    # )

    # train(
    #     module=BaselineModel(home_config),
    #     t_configs=t_config,
    #     logger=pl_loggers.TestTubeLogger(Path('log/'), name="baseline/home")
    # )