from emm.data_gen.mixture.data_generator import (
    DataSuitGenerator,
    create_component_scaling_experiment,
    create_distnoiseY_experiment,
    create_overlap_experiment,
    create_noise_features_experiment,
    ComponentGeneratorRegistry,
    ExperimentConfig,
)
import os
import argparse


def generate_test_suit(output_dir="data"):
    """Configure the full synthetic-data experiment suite and run it.

    Every experiment uses the ``"tree_based"`` component generator.

    * ``scaling5`` / ``scaling20``: sweep ``component_values`` at 5 and 20
      features respectively.
    * ``noiseY``: sweep ``noise_std_values``.
    * ``noise_features``: sweep ``noise_values``.
    * ``overlap``: sweep ``overlap_values``.
    * ``runtime_sample_scaling`` / ``runtime_feature_scaling``: plain
      ``ExperimentConfig`` objects that vary ``n_samples`` / ``n_features``
      for runtime measurements.

    All configs are handed, in the order listed above, to
    ``DataSuitGenerator().run_experiments`` together with ``output_dir``.
    This function does not create ``output_dir`` itself.

    Args:
        output_dir: Destination directory passed to ``run_experiments``.
            Defaults to ``"data"``; the command-line entry point supplies
            its own ``--output_dir`` value.
    """
    generator_name = "tree_based"
    min_leaf_fraction = 0.2
    # Leaf-shape settings shared by most tree-based experiments; the key
    # order matches the keyword order used at each call site.
    leaf_kwargs = {
        "min_leaf_size_fraction": min_leaf_fraction,
        "empty_leaf_probability": 0.5,
    }

    # Component-count sweeps at two feature dimensionalities.
    scaling_experiments = [
        create_component_scaling_experiment(
            name=experiment_name,
            generator_name=generator_name,
            n_samples=600,
            n_features=feature_count,
            component_values=[2, 4, 10, 15, 20],
            n_replications=5,
            **leaf_kwargs,
        )
        for experiment_name, feature_count in (("scaling5", 5), ("scaling20", 20))
    ]

    noise_y = create_distnoiseY_experiment(
        name="noiseY",
        generator_name=generator_name,
        n_samples=3000,
        n_features=5,
        n_components=5,
        noise_std_values=[0, 0.05, 0.1, 0.2, 0.5, 1.0],
        n_replications=4,
        **leaf_kwargs,
    )
    noise_features = create_noise_features_experiment(
        name="noise_features",
        generator_name=generator_name,
        n_samples=3000,
        n_features=5,
        n_components=5,
        noise_values=[0, 2, 5, 10, 20],
        n_replications=5,
        **leaf_kwargs,
    )
    # NOTE(review): unlike the experiments above, this one does not pass
    # empty_leaf_probability, so whatever default the factory uses applies.
    # Confirm that the asymmetry is intentional.
    overlap = create_overlap_experiment(
        name="overlap",
        generator_name=generator_name,
        n_samples=3000,
        n_features=5,
        n_components=5,
        overlap_values=[0.01, 0.1, 0.2, 0.3, 0.5],
        n_replications=5,
        min_leaf_size_fraction=min_leaf_fraction,
    )

    def runtime_config(
        name,
        description,
        n_samples,
        n_features,
        variable_param,
        variable_values,
        n_replications,
    ):
        """Build a runtime-scaling ExperimentConfig.

        The fixed mixture settings are shared by both runtime sweeps.
        """
        return ExperimentConfig(
            name=name,
            description=description,
            base_params={
                "n_samples": n_samples,
                "n_features": n_features,
                "n_components": 5,
                "distribution_overlap": 0.1,
                "distributions": ["normal", "gamma", "uniform", "exponential"],
                "empty_leaf_probability": 0.5,
                "use_n_samples_per_component": False,
            },
            variable_param=variable_param,
            variable_values=variable_values,
            # A fresh registry per config, as each config looks up its own
            # generator configuration.
            generator_config=ComponentGeneratorRegistry().get(generator_name),
            n_replications=n_replications,
        )

    runtime_samples = runtime_config(
        name="runtime_sample_scaling",
        description="Vary number of samples to test runtime scaling.",
        n_samples=10000,
        n_features=5,
        variable_param="n_samples",
        variable_values=[200, 500, 2000, 10000, 50000],
        n_replications=3,
    )
    runtime_features = runtime_config(
        name="runtime_feature_scaling",
        description="Vary number of features to test runtime scaling.",
        n_samples=3000,
        n_features=2,
        variable_param="n_features",
        variable_values=[2, 5, 10, 20, 100],
        n_replications=2,
    )

    all_experiments = [
        *scaling_experiments,
        noise_y,
        noise_features,
        overlap,
        runtime_samples,
        runtime_features,
    ]
    DataSuitGenerator().run_experiments(all_experiments, output_dir)


if __name__ == "__main__":
    # Command-line entry point: parse the destination directory, make sure
    # it exists, then generate the full experiment suite into it.
    parser = argparse.ArgumentParser(
        description="Generate test datasets for mixture models."
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="data/synthetic",
        help="Directory to save generated datasets.",
    )
    args = parser.parse_args()
    # exist_ok=True makes directory creation idempotent and avoids the
    # check-then-create race of a separate os.path.exists() test.
    os.makedirs(args.output_dir, exist_ok=True)
    generate_test_suit(output_dir=args.output_dir)
