import argparse
from omegaconf import OmegaConf
import os
from experiments import *

parser = argparse.ArgumentParser(
    description="Run the experiment specified by a config file on a dataset."
)
parser.add_argument("dataset", help="dataset to be used", type=str)
parser.add_argument("cfg_path", help="path to config file", type=str)
# A positional argument only falls back to `default` when it is optional,
# i.e. declared with nargs="?". Without it argparse makes the argument
# mandatory and the default is dead code. Options are best passed after the
# positionals: interleaving them can leave an optional positional at its default.
parser.add_argument(
    "mode",
    nargs="?",
    default="train",
    type=str,
    help=(
        "what to run: 'train' (train, then evaluate), 'eval', 'motivation' "
        "or 'tune_cdtd' (default: %(default)s)"
    ),
)
parser.add_argument(
    "seed",
    nargs="?",
    default=0,
    type=int,
    help="seed for the experiment (default: %(default)s)",
)
# Expected values: 'mcar', 'mar', 'mnar', or omitted (None).
# The historical single-dash spelling is kept for existing callers; the
# conventional double-dash form is accepted as an alias (same dest).
parser.add_argument(
    "-miss_mechanism",
    "--miss_mechanism",
    help="missingness mechanism for simulation",
    type=str,
    default=None,
)
parser.add_argument(
    "--exp_path",
    help="subfolder in experiments folder in which results are saved",
    type=str,
)



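# Example namespace for running main() without the command line (e.g. from a
# debugger or notebook): build it directly and call main(args).
# args = argparse.Namespace(dataset = 'adult',
#                           cfg_path = 'experiments/configs/tabcascade/default.yaml',
#                           seed = 0,
#                           miss_mechanism = 'mnar',
#                           mode = 'train',
#                           exp_path = 'test')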

def main(args):
    """Build the experiment selected by the config file and run ``args.mode``.

    The config at ``args.cfg_path`` is loaded with OmegaConf, and its
    ``model_name`` field selects the ``Experiment_*`` class. Both the config
    and ``args`` are passed to that class's constructor.

    Modes:
        ``train``      -- train (saving the model), then ``evaluate``.
        ``eval``       -- ``evaluate`` only.
        ``motivation`` -- train (saving the model), then ``evaluate_motivation``.
        ``tune_cdtd``  -- ``tune_cdtd``; only valid when ``model_name`` is ``cdtd``.

    Args:
        args: Parsed CLI namespace. ``cfg_path`` and ``mode`` are read here,
            and the whole namespace is forwarded to the experiment.

    Raises:
        ValueError: If the mode is unknown, if ``model_name`` is unknown, or
            if ``tune_cdtd`` is requested for a model other than ``cdtd``.
    """
    config = OmegaConf.load(args.cfg_path)
    model = config.model_name

    # Validate the mode before building the experiment. A typo then fails fast
    # instead of silently doing nothing after (possibly costly) setup.
    valid_modes = ("train", "eval", "motivation", "tune_cdtd")
    if args.mode not in valid_modes:
        raise ValueError(
            f"Unknown mode {args.mode!r}; expected one of {', '.join(valid_modes)}"
        )
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if args.mode == "tune_cdtd" and model != "cdtd":
        raise ValueError(
            f"mode 'tune_cdtd' requires model_name 'cdtd', got {model!r}"
        )

    if model == 'ctgan':
        experiment = Experiment_CTGAN(config, args)
    elif model == 'tvae':
        experiment = Experiment_TVAE(config, args)
    elif model == 'arf':
        experiment = Experiment_ARF(config, args)
    elif model == 'tabddpm':
        experiment = Experiment_TabDDPM(config, args)
    elif model == 'cdtd':
        experiment = Experiment_CDTD(config, args)
    elif model == 'tabdiff':
        experiment = Experiment_TabDiff(config, args)
    elif model == 'tabsyn':
        experiment = Experiment_TabSyn(config, args)
    elif model == 'tabcascade':
        experiment = Experiment_TabCascade(config, args)
    elif model == 'lowres':
        experiment = Experiment_LowRes(config, args)
    else:
        # Without this branch an unknown name surfaced later as an opaque
        # UnboundLocalError on `experiment`.
        raise ValueError(
            f"Unknown model_name {model!r} in config {args.cfg_path!r}"
        )

    if args.mode == "train":
        experiment.train(save_model=True)
        experiment.evaluate()
    elif args.mode == "eval":
        experiment.evaluate()
    elif args.mode == "motivation":
        experiment.train(save_model=True)
        experiment.evaluate_motivation()
    elif args.mode == "tune_cdtd":
        experiment.tune_cdtd()

if __name__ == "__main__":
    # Script entry point: parse the command line (sys.argv[1:]) and run.
    # `args` stays bound at module level, as before, for interactive inspection.
    args = parser.parse_args(None)
    main(args=args)