from methods.benchmark import TabularSynthesisBenchmark
import json
import os
import argparse

from data.utils import load_dataset

# Datasets of the "small" partition. They are benchmarked with
# configs/model/small.json and configs/data/small.json. `--dataset small`
# runs all of them, and any single --dataset name not listed in LARGE_DS is
# also treated as "small".
SMALL_DS = [
    "iris",
    "wine",
    "california",
    "parkinsons",
    "climate_model_crashes",
    "concrete_compression",
    "yacht_hydrodynamics",
    "airfoil_self_noise",
    "connectionist_bench_sonar",
    "ionosphere",
    "qsar_biodegradation",
    "seeds",
    "glass",
    "ecoli",
    "yeast",
    "libras",
    "planning_relax",
    "blood_transfusion",
    "breast_cancer_diagnostic",
    "connectionist_bench_vowel",
    "concrete_slump",
    "wine_quality_red",
    "wine_quality_white",
    "bean",
    "tictactoe",
    "congress",
    "car",
    # "higgs",  NOTE(review): disabled; the reason is not recorded here — confirm before re-enabling
]
# Datasets of the "big" partition. They are benchmarked with
# configs/model/big.json and configs/data/big.json. `--dataset big` runs
# all of them.
LARGE_DS = [
    "churn",
    "nmes",
    "lending",
    "adult",
    "default",
    "bank",
    "beijing",
    "news",
    "diabetes",
    "covertype",
    "acsincome",
]


def adjust_epochs(model_config, n, max_steps):
    """
    Cap the epoch count of a deep generative model so that training_steps <= max_steps.

    Only "tabddpm", "tvae", "ctgan" and "tabsyn" are adjusted. Any other model
    config is returned untouched. Training steps are estimated as
    ``n / batch_size * epochs``.

    Args:
        model_config: dict with a "name" key and a "params" dict.
            - tabddpm/tvae/ctgan use params "batch_size" and "epochs".
            - ctgan additionally uses "pac".
            - tabsyn uses "diffusion_batch_size" and "diffusion_num_epochs".
        n: number of training rows.
        max_steps: maximum allowed number of training steps.

    Returns:
        The same ``model_config`` object. Its epoch parameter is updated in place.
    """
    name = model_config["name"]
    if name not in ("tabddpm", "tvae", "ctgan", "tabsyn"):
        return model_config

    params = model_config["params"]
    # tabsyn stores its diffusion-stage settings under different keys
    if name == "tabsyn":
        batch_key, epoch_key = "diffusion_batch_size", "diffusion_num_epochs"
    else:
        batch_key, epoch_key = "batch_size", "epochs"
    batch_size = params[batch_key]
    epochs = params[epoch_key]

    # Same test as `n / batch_size * epochs > max_steps`, but multiplied out so
    # the comparison and the floor below are exact instead of float-rounded.
    if n * epochs > max_steps * batch_size:
        epochs = int(max_steps * batch_size // n)

    if name == "ctgan":
        # ctgan epochs must be divisible by the pac parameter. Round down
        # *after* capping so the cap cannot break divisibility; rounding down
        # never increases the step count.
        pac = params["pac"]
        epochs = int(epochs // pac) * pac

    params[epoch_key] = epochs
    return model_config


def parse_args():
    """Parse the benchmark's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes generator, dataset,
        n_random_splits, n_inits, n_generated_datasets, test_size, val_size,
        max_syn_size, max_training_steps and random_state.
    """
    parser = argparse.ArgumentParser(description="Run tabular data generation benchmark")
    add = parser.add_argument

    # What to benchmark.
    add("--generator", type=str, default="xgenboost_diffusion_xddpm",
        help="Name of the generator to use")
    add("--dataset", type=str, default="iris", help="Dataset name to use")

    # How many repetitions.
    add("--n_random_splits", type=int, default=1,
        help="Number of random splits for evaluation")
    add("--n_inits", type=int, default=1, help="Number of initializations")
    add("--n_generated_datasets", type=int, default=20,
        help="Number of generated datasets")

    # Split ratios.
    add("--test_size", type=float, default=0.2, help="Test set ratio")
    add("--val_size", type=float, default=0.2, help="Validation set ratio")

    # Size / compute budgets and seeding.
    add("--max_syn_size", type=int, default=50_000,
        help="Maximum size of synthetic dataset for evaluation")
    add("--max_training_steps", type=int, default=30_000,
        help="Maximum number of training steps for deep generative models")
    add("--random_state", type=int, default=42, help="seed")

    return parser.parse_args()


if __name__ == "__main__":

    args = parse_args()

    datasets = args.dataset
    if datasets == "big":
        datasets = LARGE_DS
        partition = "big"
    elif datasets == "small":
        datasets = SMALL_DS
        partition = "small"
    else:
        partition = "big" if datasets in LARGE_DS else "small"
        datasets = [datasets]

    for dataset in datasets:
        model_config = json.load(open(f"configs/model/{partition}.json"))[
            args.generator
        ]
        data_config = json.load(open(f"configs/data/{partition}.json"))[dataset]
        target_column = data_config["target"]
        discrete_features = data_config["cat_features"]
        X = load_dataset(dataset, data_config)

        # ensure max 30 000 training steps for deep generative models
        training_size = int(len(X) * (1 - args.test_size - args.val_size))
        max_steps = args.max_training_steps
        model_config = adjust_epochs(model_config, training_size, max_steps)

        workspace = f"workspace_{dataset}_{args.generator}"

        benchmark = TabularSynthesisBenchmark(
            generator_name=model_config["name"],
            generator_params=model_config["params"],
            n_random_splits=args.n_random_splits,
            n_inits=args.n_inits,
            n_generated_datasets=args.n_generated_datasets,
            metrics=json.load(open("configs/metric_configs.json")),
            test_size=args.test_size,
            val_size=args.val_size,
            workspace=workspace,
            max_syn_size=args.max_syn_size,
            random_state=args.random_state,
        )

        results = benchmark.run(
            X, target_column, discrete_features, result_format="frame"
        )

        save_path = f"results/{partition}"

        os.makedirs(save_path, exist_ok=True)
        results.to_csv(f"{save_path}/{dataset}_{args.generator}.csv", index=False)
