import argparse
import time
from parsers.parser import Parser
from parsers.config import get_config

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')


def main(work_type_args):
    """Run training or sampling depending on the requested work type.

    Args:
        work_type_args: Namespace with a ``type`` attribute, expected to be
            ``'train'`` or ``'sample'`` (as produced by the ``-t/--type``
            parser in the ``__main__`` block).

    Raises:
        ValueError: If ``work_type_args.type`` is neither ``'train'`` nor
            ``'sample'``.
    """
    work_type = work_type_args.type
    # Reject an unknown work type up front. Otherwise the remaining CLI
    # arguments would be parsed and the config built first, so the user could
    # see an unrelated failure from those steps instead of the real mistake.
    if work_type not in ('train', 'sample'):
        raise ValueError(f'Wrong type : {work_type}')

    # UTC run timestamp; handed to the trainer's train() call below.
    ts = time.strftime('%b%d-%H:%M:%S', time.gmtime())
    args = Parser().parse()
    config = get_config(args.config, args.gpu, args.seed)

    if work_type == 'train':
        # Imports are deferred so only the selected pipeline's modules load.
        if args.config == 'classifier':
            from ctrainer import Trainer_classifier
            trainer = Trainer_classifier(config)
        else:
            from trainer import Trainer
            trainer = Trainer(config)
        trainer.train(ts)

    else:  # work_type == 'sample' (validated above)
        if args.config == 'sample_cond':
            from sampler import Sampler_conditional
            sampler = Sampler_conditional(config)
        else:
            from sampler import Sampler_mol
            sampler = Sampler_mol(config)
        sampler.sample()


if __name__ == '__main__':
    # Only -t/--type is parsed here. parse_known_args() tolerates the other
    # command-line arguments, which are presumably consumed by Parser().parse()
    # inside main() -- confirm against parsers.parser.
    work_type_parser = argparse.ArgumentParser()
    work_type_parser.add_argument(
        '-t', '--type',
        type=str,
        required=True,
        choices=('train', 'sample'),
        help='work to perform: train a model or sample from one',
    )

    main(work_type_parser.parse_known_args()[0])
