import argparse
import os

import torch

from emm.evaluation.mixture_runner import run_test_suite, TrainingConfig

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run synthetic evaluation for Emm")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="results/synthetic",
        help="Directory to save results.",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="data/synthetic",
        help="Directory where synthetic datasets are located.",
    )
    parser.add_argument(
        "--method",
        type=str,
        nargs="+",
        default=["emm_gmm", "emm_nsf"],
        help="Methods to evaluate. Options: emm_gmm, emm_nsf",
    )
    parser.add_argument(
        "--experiment_names",
        type=str,
        nargs="+",
        default=[
            "scaling5",
            "scaling20",
            "noiseY",
            "noise_features",
            "overlap",
            "runtime_sample_scaling",
            "runtime_feature_scaling",
        ],
        help="Names of experiments to run.",
    )
    data_dir = parser.parse_args().data_dir
    output_dir = parser.parse_args().output_dir
    methods = parser.parse_args().method
    # use cuda if available
    device = "cuda" if torch.cuda.is_available() else "cpu"

    remix_config_large = TrainingConfig(
        use_model_finder=True,
        model_finder_component_range=[100],
        device=device,
        use_gmm_remix=True,
        n_gmm_components=30,
        component_scoring="bic",
        and_layer_entropy=0.005,
        partition_weight=0.1,
        min_responsibility_threshold=0.01,
        merge_components=True,
        merge_settle_epochs=40,
        check_responsibility_every=25,
    )
    remix_config = TrainingConfig(
        use_model_finder=True,
        model_finder_component_range=[10, 100],
        device=device,
        use_gmm_remix=True,
        n_gmm_components=30,
        component_scoring="bic",
        and_layer_entropy=0.005,
        partition_weight=0.1,
        min_responsibility_threshold=0.005,
        merge_components=True,
        merge_settle_epochs=40,
        check_responsibility_every=25,
    )

    nsf_config = TrainingConfig(
        use_model_finder=True,
        model_finder_component_range=[10, 100],
        device=device,
        flow_gen=("zuko_nsf", {}),
        use_gmm_remix=False,
        and_layer_entropy=0.005,
        partition_weight=0.1,
        min_responsibility_threshold=0.005,
        lr_flow=10e-3,
        merge_components=True,
        merge_settle_epochs=40,
        check_responsibility_every=25,
    )

    for experiment_name in parser.parse_args().experiment_names:
        if "emm_gmm" in methods:
            run_test_suite(
                configs=[remix_config_large]
                if experiment_name == "scaling20"
                else [remix_config],
                data_dir=f"{data_dir}/{experiment_name}",
                results_file=f"{output_dir}/emm_gmm_{experiment_name}.csv",
                overwrite=True,
            )
        if "emm_nsf" in methods:
            run_test_suite(
                configs=[nsf_config],
                data_dir=f"{data_dir}/{experiment_name}",
                results_file=f"{output_dir}/emm_nsf_{experiment_name}.csv",
                overwrite=True,
            )
