import argparse

from omegaconf import OmegaConf

from experiments.experiment_arf import Experiment_ARF
from experiments.experiment_cdtd import Experiment_CDTD
from experiments.experiment_ctgan import Experiment_CTGAN
from experiments.experiment_forestdiffusion import Experiment_ForestDiffusion
from experiments.experiment_tabbyflow import Experiment_TabbyFlow
from experiments.experiment_tabcascade import Experiment_TabCascade
from experiments.experiment_tabddpm import Experiment_TabDDPM
from experiments.experiment_tabdiff import Experiment_TabDiff
from experiments.experiment_tabsyn import Experiment_TabSyn
from experiments.experiment_tvae import Experiment_TVAE

# Command-line interface.
#   dataset, cfg_path : required positionals.
#   mode, seed        : optional positionals. ``nargs="?"`` is required for
#                       their ``default`` to take effect -- argparse treats a
#                       positional without it as mandatory and ignores default.
parser = argparse.ArgumentParser()
parser.add_argument("dataset", help="dataset to be used", type=str)
parser.add_argument("cfg_path", help="path to config file", type=str)
parser.add_argument(
    "mode",
    help="mode to run: 'train', 'eval', 'motivation' or 'tune_cdtd'",
    type=str,
    nargs="?",
    default="train",
    choices=["train", "eval", "motivation", "tune_cdtd"],
)
parser.add_argument("seed", help="seed for the experiment", type=int, nargs="?", default=0)
parser.add_argument(
    "-miss_mechanism",
    help="missingness mechanism for simulation: 'mcar', 'mar' or 'mnar' (default: None)",
    type=str,
    default=None,
)
parser.add_argument("-p_miss", help="missingness rate", type=float, default=0.1)
parser.add_argument("-sampling_steps", help="overwrite config sampling steps", type=int, default=None)
parser.add_argument("--exp_path", help="subfolder in experiments folder in which results are saved", type=str)


# Supported ``config.model_name`` values and the experiment class for each.
_EXPERIMENT_CLASSES = {
    "ctgan": Experiment_CTGAN,
    "tvae": Experiment_TVAE,
    "arf": Experiment_ARF,
    "tabddpm": Experiment_TabDDPM,
    "cdtd": Experiment_CDTD,
    "tabdiff": Experiment_TabDiff,
    "tabsyn": Experiment_TabSyn,
    "tabcascade": Experiment_TabCascade,
    "tabbyflow": Experiment_TabbyFlow,
    "forestdiffusion": Experiment_ForestDiffusion,
}

# Values of ``args.mode`` that ``main`` knows how to run.
_MODES = ("train", "eval", "motivation", "tune_cdtd")


def main(args):
    """Build the experiment named in the config file and run the requested mode.

    Args:
        args: Parsed command-line namespace. ``args.cfg_path`` is loaded with
            OmegaConf, ``args.mode`` selects the action, and the whole
            namespace is passed to the experiment constructor.

    Raises:
        ValueError: if ``config.model_name`` is not a supported model, if
            ``args.mode`` is not one of the supported modes, or if
            ``tune_cdtd`` is requested for a model other than ``cdtd``.
    """
    config = OmegaConf.load(args.cfg_path)
    model = config.model_name
    mode = args.mode

    # Validate everything before constructing the experiment, so a bad
    # invocation fails fast instead of after experiment setup.
    if model not in _EXPERIMENT_CLASSES:
        raise ValueError(
            f"Unknown model_name {model!r} in {args.cfg_path}; "
            f"expected one of {sorted(_EXPERIMENT_CLASSES)}"
        )
    if mode not in _MODES:
        raise ValueError(f"Unknown mode {mode!r}; expected one of {list(_MODES)}")
    if mode == "tune_cdtd" and model != "cdtd":
        # Explicit check (not ``assert``) so it still runs under ``python -O``.
        raise ValueError(f"mode 'tune_cdtd' requires model 'cdtd', got {model!r}")

    experiment = _EXPERIMENT_CLASSES[model](config, args)

    if mode == "train":
        experiment.train(save_model=True)
        # Sampling is skipped for an ARF model configured with use_as_lowres.
        if not (model == "arf" and config.model.use_as_lowres):
            experiment.sample_datasets()
    elif mode == "eval":
        experiment.evaluate()
    elif mode == "motivation":
        experiment.evaluate_motivation()
    else:  # mode == "tune_cdtd" (validated above)
        experiment.tune_cdtd()


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
