import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

from models import NeuralSearch
from models import CAM, IterRefLSTM
from models import MHNfs
from dataloader import FSMDataModule

import hydra
from omegaconf import OmegaConf


@hydra.main(config_path="../configs", config_name='config')
def train(config: OmegaConf):
    """Build the data module, model and Lightning trainer from ``config``, then fit.

    The config is composed by Hydra from ``../configs/config`` (see decorator).

    Config keys read here:
        config.training.seed    -- seed passed to ``seed_everything``.
        config.training.epochs  -- ``max_epochs`` for the trainer.
        config.model.name       -- one of 'neuralsearch', 'CAM', 'IterRefLSTM', 'MHNfs'.
        config.experiment_name  -- run name for the W&B logger.

    Raises:
        ValueError: if ``config.model.name`` is not one of the supported models.
    """
    # Set seed
    seed_everything(config.training.seed)

    # Import data module
    dm = FSMDataModule(config)

    # Map config model names to the model classes imported at the top of the
    # file. Previously the non-default branches referenced names that were
    # never imported (NameError at runtime).
    model_classes = {
        'neuralsearch': NeuralSearch,
        'CAM': CAM,
        'IterRefLSTM': IterRefLSTM,
        'MHNfs': MHNfs,
    }
    model_name = config.model.name
    if model_name not in model_classes:
        raise ValueError('Model not supported! Please change model')
    model = model_classes[model_name](config)

    # Define logger and callbacks
    # NOTE(review): save_dir is relative; Hydra may change the working
    # directory per run depending on its version/config -- confirm the
    # intended log location.
    logger = pl_loggers.WandbLogger(save_dir='../logs/', name=config.experiment_name)
    # Keeps the 3 best checkpoints by "dAUPRC_val"; presumably every model
    # logs this metric during validation -- verify in the model classes.
    checkpoint_callback = ModelCheckpoint(monitor="dAUPRC_val", mode='max', save_top_k=3)
    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    # Define and fit trainer. Only MHNfs uses gradient accumulation over
    # 4 batches; other models use the Trainer default.
    trainer_kwargs = dict(
        gpus=1,
        logger=logger,
        callbacks=[checkpoint_callback, lr_monitor],
        max_epochs=config.training.epochs,
    )
    if model_name == 'MHNfs':
        trainer_kwargs['accumulate_grad_batches'] = 4
    trainer = pl.Trainer(**trainer_kwargs)

    trainer.fit(model, dm)


if __name__ == "__main__":
    # Run only when executed as a script. Hydra's decorator composes the
    # config and supplies it to train(), so it is called without arguments.
    train()

