import argparse
from omegaconf import OmegaConf
import os
from experiments import Experiment_CDTD, Experiment_TabDiff, Experiment_TabDDPM

# Command-line interface for main(). main() itself raises ValueError for an
# unknown --model or --mode; the `choices` below make the CLI reject such
# values immediately, before any config or data is touched.
parser = argparse.ArgumentParser(description="Train and/or evaluate a tabular generative model.")
parser.add_argument("--data", type=str, required=True, help="dataset name")
parser.add_argument(
    "--model", type=str, required=True, choices=["cdtd", "tabdiff", "tabddpm"],
    help="generative model: cdtd, tabdiff, tabddpm",
)
parser.add_argument("--mode", type=str, default="train", choices=["train", "eval"], help="train or eval")
parser.add_argument(
    "--cfg_path", type=str,
    help="path to yaml config (default: configs/<model>/default.yaml)",
)
parser.add_argument("--exp_path", type=str, help="subfolder for results")
parser.add_argument("--gpu", type=int, required=True, help="CUDA device index (used as cuda:<gpu>)")
parser.add_argument("--preproc", type=str, default="", help="preprocessing / imputation method")
# main() maps unrecognised --cov values to the "clean" dataset-directory tag.
parser.add_argument("--cov", type=str, default="both", help="noise coverage: both, x_only, y_only")
parser.add_argument("--pattern", type=str, default="NU1", help="missingness pattern")
parser.add_argument("--p", type=float, default=0.0, help="missingness ratio")
parser.add_argument("--noise_seed", type=int, default=0)
parser.add_argument("--extra", type=str, default="", help="run tag, e.g. aug_full, aug_mask")
parser.add_argument("--strategy", type=int, default=0, help="0=aug_full, 1=obs_mask, 2=aug_mask")
# NOTE(review): --breaks is parsed but not read by main() in this file.
parser.add_argument("--breaks", type=int, default=30000)
parser.add_argument("--beta", type=str, default="0p7")
parser.add_argument("--use_log", type=str, default="switch")


def _noise_dir_tag(cov):
    """Return the dataset-directory tag for a ``--cov`` value.

    ``"x_only"`` -> ``"sup_x"``, ``"y_only"`` -> ``"semi_y"``,
    ``"both"`` -> ``"semi_xy"``; any other value falls back to ``"clean"``.
    """
    return {"x_only": "sup_x", "y_only": "semi_y", "both": "semi_xy"}.get(cov, "clean")


def _sample_file_name(dt, pattern, p, seed, preproc, beta, use_log):
    """Return the synthetic-sample CSV file name for a noisy (p != 0) run.

    The ``LGB_S`` preprocessing method additionally encodes ``beta`` and
    ``use_log`` in the name, so runs with different settings get distinct files.
    """
    stem = f"{dt}_{pattern}_p{p}_{seed}{preproc}"
    if preproc == "LGB_S":
        stem += f"_beta{beta}_use_log{use_log}"
    return f"{stem}.csv"


def main(args):
    """Train and/or evaluate one generative model from parsed CLI arguments.

    Resolves the config, input-data and synthetic-sample paths from ``args``,
    builds the experiment object for ``args.model``, then either trains and
    evaluates it (``mode == "train"``) or only evaluates it (``mode == "eval"``).

    Args:
        args: namespace with the attributes defined by the module-level
            ``parser``. Mutated in place: ``args.cfg_path`` is filled in when
            it is None, and ``args.preproc`` is set to ``"m"`` when
            ``args.strategy == 1``.

    Raises:
        ValueError: if ``args.mode`` or ``args.model`` is not recognised.
            Both are checked before the config is loaded or the experiment
            is built.
    """
    # Validate first so bad input does not pay for config loading / experiment setup.
    if args.mode not in ("train", "eval"):
        raise ValueError(f"Unknown mode: {args.mode}")

    # model name -> (experiment class, default results subfolder when --exp_path is unset)
    experiments = {
        "cdtd": (Experiment_CDTD, ""),
        "tabdiff": (Experiment_TabDiff, "tabdiff"),
        "tabddpm": (Experiment_TabDDPM, "tabddpm"),
    }
    if args.model not in experiments:
        raise ValueError(f"Unknown model: {args.model}. Use cdtd, tabdiff, or tabddpm.")
    experiment_cls, default_exp_path = experiments[args.model]

    if args.cfg_path is None:
        args.cfg_path = os.path.join("configs", args.model, "default.yaml")
    config = OmegaConf.load(args.cfg_path)

    # Strategy 1 ("obs_mask" per the --strategy help) always uses the "m"
    # preprocessing tag; args.preproc is overwritten as well as the local.
    if args.strategy == 1:
        args.preproc = "m"

    dataname = args.data
    preproc = args.preproc
    noise_p = args.p
    noise_pattern = args.pattern
    noise_seed = args.noise_seed
    beta = args.beta
    use_log = args.use_log
    extra = args.extra
    device = f"cuda:{args.gpu}"

    run_name = f"{args.model}-{extra}_{preproc}_{noise_pattern}_{noise_p}_{noise_seed}"
    # <repo>/data, where <repo> is the parent of the directory containing this script.
    data_root = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "data")
    synthetic_dir = os.path.join(data_root, "synthetic", f"{args.model}-{extra}")
    # makedirs creates all missing parents, so this also creates synthetic_dir.
    os.makedirs(os.path.join(synthetic_dir, dataname), exist_ok=True)

    if noise_p == 0:
        data_path = os.path.join(data_root, dataname)
        sample_path = os.path.join(synthetic_dir, dataname, "clean.csv")
    else:
        dt = _noise_dir_tag(args.cov)
        data_path = os.path.join(data_root, f"{dataname}_{dt}", f"{noise_pattern}_p{noise_p}_{noise_seed}")
        sample_path = os.path.join(
            synthetic_dir,
            dataname,
            _sample_file_name(dt, noise_pattern, noise_p, noise_seed, preproc, beta, use_log),
        )

    experiment = experiment_cls(
        data_path, sample_path, config, args.exp_path or default_exp_path, args.data, device, preproc,
        strategy=args.strategy, run_name=run_name, beta=beta, use_log=use_log,
    )

    if args.mode == "train":
        experiment.train(save_model=True, plot_figures=(args.model == "cdtd"))
        if noise_p > 0:
            # Point evaluation back at this run's sample file, in case train()
            # changed experiment.sample_path (NOTE(review): confirm whether it does).
            experiment.sample_path = sample_path
    experiment.evaluate_generative_model()


if __name__ == "__main__":
    # Script entry point: parse the command line, then hand off to main().
    cli_args = parser.parse_args()
    main(cli_args)
