import warnings
warnings.filterwarnings('ignore')

import sys
import torch
import random
import logging
import numpy as np
from tqdm import tqdm
import horovod.torch as hvd
from config import get_config
from models import load_model
from functools import partialmethod
from core.trainer import DockingTrainer
from core.dataset import get_dataloaders

# Send log records to stdout using a compact "yy/mm/dd HH:MM:SS message" layout.
# Note: basicConfig is a no-op if the root logger already has handlers.
handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    handlers=[handler],
    format="%(asctime)s %(message)s",
    datefmt="%y/%m/%d %H:%M:%S",
)
# basicConfig receives no `level`, so the root threshold is raised to INFO here.
logging.getLogger().setLevel(logging.INFO)


def set_seed(seed: int) -> None:
    """Seed every RNG used for training and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU,
    plus every CUDA device when CUDA is available). It then asks cuDNN to use
    deterministic kernels and disables its benchmark autotuner.

    Args:
        seed: value passed to each RNG's seeding function.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():  # GPU operations have separate seed
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Spelling matters: a misspelled name here would silently create a new,
    # unused attribute on the module instead of enabling deterministic kernels.
    torch.backends.cudnn.deterministic = True
    # The autotuner may choose different kernels from run to run; turn it off.
    torch.backends.cudnn.benchmark = False


def main(config) -> None:
    """Build the 'rcsb' data loaders and the configured model, then train.

    Args:
        config: run configuration. Its ``model`` attribute is formatted to a
            string and passed to ``load_model`` to select the model class.
            That class is instantiated with ``config``.
    """
    loaders = get_dataloaders(config, dataset_name='rcsb')
    model_cls = load_model(f'{config.model}')
    DockingTrainer(loaders, model_cls(config), config).train()


if __name__ == "__main__":
    # Root logger handle; it is not referenced below.
    logger = logging.getLogger()
    config = get_config()

    serial = config.serial
    if not serial:
        # Horovod run: start the runtime and pin this process to its local GPU.
        hvd.init()
        torch.cuda.set_device(hvd.local_rank())

    # Echo the configuration once: from the lone serial process, or rank 0.
    if serial or hvd.rank() == 0:
        logging.info("===> Configurations")
        for key in vars(config):
            logging.info("    {}: {}".format(key, getattr(config, key)))

    set_seed(config.rand_seed)

    # Unless explicitly unmuted, make every tqdm bar default to disable=True.
    if not config.unmute_tqdm:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    main(config)


