import os
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
from ising_runner import Runner
from ising_ardm_runner import ARDMRunner


import torch.multiprocessing

# Select the 'file_system' strategy for sharing tensors between processes.
# This runs at import time, before any runner is created.
# NOTE(review): presumably chosen to avoid exhausting file descriptors when
# many worker processes share tensors -- confirm against the data loaders.
torch.multiprocessing.set_sharing_strategy('file_system')

@hydra.main(config_path="config", config_name="config_mol_mb")
def main(cfg: DictConfig):
    """Build the runner selected by ``cfg.model`` and execute ``cfg.mode``.

    Args:
        cfg: Configuration object supplied by Hydra. This function reads
            ``data_dir``, ``gen_dir``, ``model_dir``, ``log_dir``, ``model``
            and ``mode``. ``data_dir`` and ``gen_dir`` are overwritten in
            place with absolute paths.

    Raises:
        NotImplementedError: If ``cfg.model`` or ``cfg.mode`` is not one of
            the supported values. Both are checked before the runner is
            constructed so that a typo fails fast.
    """
    # Supported values of cfg.model and the runner class each one selects.
    runner_classes = {
        'ARDM': ARDMRunner,
        'MarNet': Runner,
    }
    # Supported values of cfg.mode and the runner method each one invokes.
    mode_methods = {
        'train': 'train',
        'generate': 'generate',
        'gen_mol': 'generate_mols',
        'test': 'test',
        'evaluate': 'eval_kl',
    }

    # Convert the data paths to absolute paths.
    cfg.data_dir = hydra.utils.to_absolute_path(cfg.data_dir)
    cfg.gen_dir = hydra.utils.to_absolute_path(cfg.gen_dir)
    # model_dir and log_dir are used as configured; relative values resolve
    # against the current working directory (logged below).
    os.makedirs(cfg.model_dir, exist_ok=True)
    os.makedirs(cfg.log_dir, exist_ok=True)
    logging.info(os.getcwd())
    logging.info(OmegaConf.to_yaml(cfg))

    # Validate both selectors up front: constructing a runner only to reject
    # the mode afterwards wastes whatever work the constructor does.
    if cfg.model not in runner_classes:
        raise NotImplementedError(
            f"Unsupported model {cfg.model!r}; "
            f"expected one of {sorted(runner_classes)}"
        )
    if cfg.mode not in mode_methods:
        raise NotImplementedError(
            f"Unsupported mode {cfg.mode!r}; "
            f"expected one of {sorted(mode_methods)}"
        )

    runner = runner_classes[cfg.model](cfg)
    # Look the method up only now, so a runner lacking the requested method
    # still fails at the same point as a direct attribute call would.
    getattr(runner, mode_methods[cfg.mode])()

# Script entry point. main() is called without arguments on purpose: the
# @hydra.main decorator is what builds the config and passes it in as `cfg`.
if __name__ == "__main__":
    main()
    
