import torch
import argparse
import time
from parsers.parser import Parser
from parsers.config import get_config
from trainer import Trainer
from sampler import Sampler, Sampler_mol


# Datasets sampled with `Sampler`.
_GRAPH_SAMPLER_DATASETS = (
    "community_small", "grid", "ENZYMES", "ego_small", "planar", "sbm",
    "tree", "cora", "citeseer", "pubmed",
)
# Datasets sampled with `Sampler_mol`.
_MOL_SAMPLER_DATASETS = ("QM9", "ZINC250k")
# Values accepted for `work_type_args.type`.
_WORK_TYPES = ("train", "sample")


def _build_sampler(config):
    """Return the sampler that matches ``config.data.data``.

    Raises:
        ValueError: If the dataset name belongs to neither sampler's list.
    """
    dataset = config.data.data
    if dataset in _GRAPH_SAMPLER_DATASETS:
        return Sampler(config)
    if dataset in _MOL_SAMPLER_DATASETS:
        return Sampler_mol(config)
    # Previously this case fell through and crashed later with an
    # UnboundLocalError on `sampler`; fail with an explicit message instead.
    raise ValueError(f'Unsupported dataset for sampling : {dataset}')


def main(work_type_args):
    """Train a model (optionally sampling afterwards) or sample from a checkpoint.

    Args:
        work_type_args: Namespace whose ``type`` attribute selects the mode,
            either ``'train'`` or ``'sample'``.

    Raises:
        ValueError: If ``work_type_args.type`` is not a known mode, or if
            sampling is requested for a dataset no sampler is registered for.
    """
    work_type = work_type_args.type
    # Reject an invalid mode before parsing arguments or loading the config.
    if work_type not in _WORK_TYPES:
        raise ValueError(f'Wrong type : {work_type}')

    ts = time.strftime('%b%d-%H-%M-%S', time.gmtime())
    args = Parser().parse()
    config = get_config(args.config, args.seed)
    config.type = args.beta_type
    config.checkpoint = args.checkpoint

    # -------- Train --------
    if work_type == 'train':
        trainer = Trainer(config)
        ckpt = trainer.train(ts)
        # Sample right after training only when the config has a 'sample' entry.
        if 'sample' in config.keys():
            config.ckpt = ckpt
            _build_sampler(config).sample()
        return

    # -------- Generation --------
    print("in sample mode")
    print("config.data.data:", config.data.data)
    # Sampling-only mode uses the checkpoint given on the command line.
    config.ckpt = config.checkpoint
    _build_sampler(config).sample()

if __name__ == '__main__':
    # Only --type is declared here; parse_known_args() tolerates any other
    # command-line options instead of rejecting them, so main() can parse
    # them with the project Parser.
    cli = argparse.ArgumentParser()
    cli.add_argument('--type', type=str, default="sample", help="train or sample")
    known_args, _unrecognized = cli.parse_known_args()
    main(known_args)
