from datetime import datetime
from itertools import product
import platform
import time
from random import Random
from sys import stdout
from configparser import ConfigParser

from binary_tree import build_binary_tree_with_facilities # build input instance
from binary_tree_brute_force import binary_tree_brute_force # brute-force solver
from binary_tree_dp_exact import binary_tree_dp_v2 as binary_tree_dp_exact # exact DP solver
from binary_tree_dp_pruning import binary_tree_dp_pruning # DP with pruning solver

from data import generate_random_facility_client_df
from metric_tree_embedding import build_binary_tree_from_df
from embedding_evaluation import compare_original_vs_binary_tree_distances
from df_brute_force import brute_force_capacitated_k_median

from datasets import get_dataset_df


def scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output,
                max_num_states=None):
    """Time one solver on random binary-tree instances and log a results table.

    For every ``(n, t, k)`` in the Cartesian product of ``nn``, ``tt`` and
    ``kk``, ``no_of_trials`` instances are built with
    ``build_binary_tree_with_facilities`` and solved with the solver selected
    by ``algo_type``. A commented preamble, a column header and one row per
    instance are written to ``output``.

    Args:
        algo_type: ``"Brute_force"``, ``"DP_exact"`` or ``"DP_pruning"``.
        seed: Seed of the local RNG from which per-instance seeds are drawn.
        nn: Iterable of node counts (n).
        tt: Iterable of group counts (t).
        kk: Iterable of center counts (k).
        no_of_trials: Number of instances per ``(n, t, k)`` configuration.
        facility_prob: Forwarded unchanged to the instance builder.
        output: Writable text stream providing ``write`` and ``flush``.
        max_num_states: State budget forwarded to the pruning DP. ``None`` or
            ``-1`` selects the default ``6 * n * t * k``. Ignored by the other
            solvers.

    Raises:
        ValueError: If ``algo_type`` is not one of the supported names. The
            check happens inside the loop, so it is only raised once a
            configuration is actually processed.
    """
    l_random = Random()
    l_random.seed(seed)

    # print input parameters
    output.write("# Scalability experiments\n")
    output.write(f"# Algorithm: {algo_type}\n")
    output.write(f"# Date: {datetime.now()}\n")
    output.write(f"# Machine: {platform.node()}\n")
    output.write(f"# Seed: {seed}\n")
    output.write(f"# Number of nodes (n): {nn}\n")
    output.write(f"# Number of groups (t): {tt}\n")
    output.write(f"# Number of centers (k): {kk}\n")
    output.write(f"# Number of trials per configuration: {no_of_trials}\n")
    output.write(f"# Facility probability: {facility_prob}\n")
    output.write(f"# Max number of states (for DP pruning): {max_num_states}\n")
    output.write("#\n")
    output.flush()

    # Every column is followed by one space, exactly like the data rows
    # below, so that titles line up with the values underneath them.
    # NOTE: Brute_force rows do not fill the 'iters' column.
    header = (
        f"{'seed':>10s} "
        f"{'n':>4s} {'t':>3s} {'k':>3s} {'rep':>3s} "
        f"{'input':>8s} "
        f"{algo_type:>10s} "
        f"{'cost':>10s} "
        f"{'iters':>10s} "
    )
    if algo_type == "DP_pruning":
        header += f"{'max_states':>10s} "
    output.write(header + "\n")
    output.write("#" + "=" * len(header) + "\n")

    for n, t, k in product(nn, tt, kk):
        for trial in range(no_of_trials):
            # One seed per instance, drawn from the local RNG so the whole
            # sequence of instances is reproducible from `seed`.
            input_seed = l_random.randint(1, 123456789)

            # Write the row prefix first so a long-running solver still
            # leaves a trace of which instance it is working on.
            output.write(f"{input_seed:10d} {n:4d} {t:3d} {k:3d} {trial:3d} ")
            output.flush()

            alpha = (1,) * t
            beta = (k,) * t
            max_capacity = int(n / k)

            start_time = time.time()
            # Build random instance on a tree metric.
            root = build_binary_tree_with_facilities(
                n, t, facility_prob, max_capacity, input_seed
            )
            input_time = time.time() - start_time

            if algo_type == "Brute_force":
                bf_start = time.time()
                bf_cost, bf_combination, bf_assignment = binary_tree_brute_force(root, k, alpha, beta)
                bf_time = time.time() - bf_start

                output.write(
                    f"{input_time:8.2f} "
                    f"{bf_time:10.2f} "
                    f"{bf_cost:10.2f} \n"
                )
            elif algo_type == "DP_exact":
                dp_exact_start = time.time()
                dp_exact_cost, dp_exact_iters = binary_tree_dp_exact(root, n, t, k, alpha, beta)
                dp_exact_time = time.time() - dp_exact_start

                output.write(
                    f"{input_time:8.2f} "
                    f"{dp_exact_time:10.2f} "
                    f"{dp_exact_cost:10.2f} "
                    f"{dp_exact_iters:10d} \n"
                )
            elif algo_type == "DP_pruning":
                # The signature default (None) and the explicit sentinel -1
                # both mean "use the size-dependent default budget".
                if max_num_states is None or max_num_states == -1:
                    max_states = 6 * n * t * k
                else:
                    max_states = max_num_states

                dp_pruning_start = time.time()
                cost, group_counts, num_centers, iters = binary_tree_dp_pruning(
                    root, n, t, k, alpha, beta, max_states
                )
                dp_pruning_time = time.time() - dp_pruning_start
                output.write(
                    f"{input_time:8.2f} "
                    f"{dp_pruning_time:10.2f} "
                    f"{cost:10.2f} "
                    f"{iters:10d} "
                    f"{max_states:10d} \n"
                )
            else:
                raise ValueError(f"Unknown algorithm type: {algo_type}")
            output.flush()

    output.write("=" * len(header) + "\n")
    output.flush()


######################################################################
def scalability_embedding(input_type,
                          clustering_method,
                          max_swaps,
                          seed,
                          nn,
                          tt,
                          kk,
                          no_of_features,
                          no_of_trials,
                          facility_prob,
                          output):
    """Time the tree embedding on random data frames and log distance statistics.

    For every ``(n, t, k)`` in the product of ``nn``, ``tt`` and ``kk``,
    ``no_of_trials`` data frames are generated with
    ``generate_random_facility_client_df`` and passed to
    ``compare_original_vs_binary_tree_distances``. The generation time, the
    embedding time and a selection of the returned statistics are written to
    ``output`` as one row per trial.

    Args:
        input_type: Label echoed in the preamble only.
        clustering_method: Forwarded to the embedding routine.
        max_swaps: Forwarded to the embedding routine.
        seed: Seed of the local RNG that yields the per-trial seeds.
        nn: Iterable of point counts (n).
        tt: Iterable of group counts (t).
        kk: Iterable of center counts (k).
        no_of_features: Number of feature columns (``f1`` .. ``fN``).
        no_of_trials: Number of trials per configuration.
        facility_prob: Forwarded to the data-frame generator.
        output: Writable text stream providing ``write`` and ``flush``.
    """
    rng = Random()
    rng.seed(seed)

    # Commented preamble describing the run.
    preamble = (
        "# Scalability experiments",
        f"# Input type: {input_type}",
        f"# Clustering_method: {clustering_method}",
        f"# Max swaps: {max_swaps}",
        f"# Date: {datetime.now()}",
        f"# Machine: {platform.node()}",
        f"# Seed: {seed}",
        f"# Number of nodes (n): {nn}",
        f"# Number of groups (t): {tt}",
        f"# Number of centers (k): {kk}",
        f"# Number of trials per configuration: {no_of_trials}",
        f"# Number of features (for embedding): {no_of_features}",
        f"# Facility probability: {facility_prob}",
        "#",
    )
    for line in preamble:
        output.write(line + "\n")
    output.flush()

    # (title, width) of each column; every column is followed by one space.
    columns = (
        ("seed", 10), ("n", 4), ("t", 3), ("k", 3), ("rep", 3),
        ("df", 8), ("embed", 8), ("nodes", 6),
        ("max_diff", 10), ("avg_orig", 10), ("avg_bin", 10), ("diff_avg", 10),
    )
    header = "".join(f"{title:>{width}s} " for title, width in columns)
    rule = "=" * len(header)
    output.write(header + "\n")
    output.write("#" + rule + "\n")

    for n, t, k in product(nn, tt, kk):
        for trial in range(no_of_trials):
            # Two draws per trial, in this order, keep runs reproducible.
            input_seed = rng.randint(1, 123456789)
            local_search_seed = rng.randint(1, 123456789)

            # Row prefix goes out before the (possibly slow) work starts.
            output.write(f"{input_seed:10d} {n:4d} {t:3d} {k:3d} {trial:3d} ")
            output.flush()

            max_capacity = int(n / k)
            feature_cols = [f'f{i+1}' for i in range(no_of_features)]
            group_cols = [f'group{i+1}' for i in range(t)]

            df_start = time.time()
            df = generate_random_facility_client_df(
                n_points=n,
                n_features=no_of_features,
                n_groups=t,
                facility_probability=facility_prob,
                max_capacity=max_capacity,
                seed=input_seed,
            )
            df_time = time.time() - df_start

            # Embed the points and collect the distance statistics.
            embed_start = time.time()
            root, nodes, idx, stats = compare_original_vs_binary_tree_distances(
                df,
                num_clusters=k,
                clustering_method=clustering_method,
                max_swaps=max_swaps,
                feature_cols=feature_cols,
                group_cols=group_cols,
                is_facility_col='is_facility',
                capacity_col='capacity',
                local_search_seed=local_search_seed,
                compute_stats=True,
            )
            embed_time = time.time() - embed_start

            fields = (
                f"{df_time:8.2f}",
                f"{embed_time:8.2f}",
                f"{stats['num_nodes_in_binary_tree']:6d}",
                f"{stats['max_abs_diff']:10.3f}",
                f"{stats['mean_orig_dist']:10.3f}",
                f"{stats['mean_bin_dist']:10.3f}",
                f"{stats['difference_of_means']:10.3f}",
            )
            output.write(" ".join(fields) + " \n")
            output.flush()

    output.write(rule + "\n")
    output.flush()

def scalability_embedding_plus_DP(input_type,
                          clustering_method,
                          max_swaps,
                          seed,
                          nn,
                          tt,
                          kk,
                          no_of_features,
                          no_of_trials,
                          facility_prob,
                          output):
    """Embed random data frames, run the exact DP and compare with brute force.

    For every ``(n, t, k)`` in the product of ``nn``, ``tt`` and ``kk``,
    ``no_of_trials`` data frames are generated, embedded with
    ``compare_original_vs_binary_tree_distances`` and solved with
    ``binary_tree_dp_exact`` on the resulting tree. When ``n < 100`` the data
    frame is additionally solved with ``brute_force_capacitated_k_median``;
    otherwise the two brute-force columns contain ``na``.

    Args:
        input_type: Label echoed in the preamble only.
        clustering_method: Forwarded to the embedding routine.
        max_swaps: Forwarded to the embedding routine.
        seed: Seed of the local RNG that yields the per-trial seeds.
        nn: Iterable of point counts (n).
        tt: Iterable of group counts (t).
        kk: Iterable of center counts (k).
        no_of_features: Number of feature columns (``f1`` .. ``fN``).
        no_of_trials: Number of trials per configuration.
        facility_prob: Forwarded to the data-frame generator.
        output: Writable text stream providing ``write`` and ``flush``.
    """
    rng = Random()
    rng.seed(seed)

    # Commented preamble describing the run.
    preamble = (
        "# Scalability experiments",
        f"# Input type: {input_type}",
        f"# Clustering_method: {clustering_method}",
        f"# Max swaps: {max_swaps}",
        f"# Date: {datetime.now()}",
        f"# Machine: {platform.node()}",
        f"# Seed: {seed}",
        f"# Number of nodes (n): {nn}",
        f"# Number of groups (t): {tt}",
        f"# Number of centers (k): {kk}",
        f"# Number of trials per configuration: {no_of_trials}",
        f"# Number of features (for embedding): {no_of_features}",
        f"# Facility probability: {facility_prob}",
        "#",
    )
    for line in preamble:
        output.write(line + "\n")
    output.flush()

    # (title, width) of each column; every column is followed by one space.
    columns = (
        ("seed", 10), ("n", 4), ("t", 3), ("k", 3), ("rep", 3),
        ("npp", 4), ("df", 5), ("embed", 6),
        ("avg_orig", 8), ("avg_bin", 8),
        ("DP", 8), ("DPcost", 8), ("BF", 8), ("BFcost", 8),
    )
    header = "".join(f"{title:>{width}s} " for title, width in columns)
    rule = "=" * len(header)
    output.write(header + "\n")
    output.write("#" + rule + "\n")

    for n, t, k in product(nn, tt, kk):
        alpha = (1,) * t
        beta = (k,) * t
        # Brute force is only attempted on small inputs.
        run_brute_force = n < 100

        for trial in range(no_of_trials):
            # Two draws per trial, in this order, keep runs reproducible.
            input_seed = rng.randint(1, 123456789)
            local_search_seed = rng.randint(1, 123456789)

            # Row prefix goes out before the (possibly slow) work starts.
            output.write(f"{input_seed:10d} {n:4d} {t:3d} {k:3d} {trial:3d} ")
            output.flush()

            max_capacity = int(n / k)

            df_start = time.time()
            df = generate_random_facility_client_df(
                n_points=n,
                n_features=no_of_features,
                n_groups=t,
                facility_probability=facility_prob,
                max_capacity=max_capacity,
                seed=input_seed,
            )
            df_time = time.time() - df_start

            # Embed the points into a binary tree.
            embed_start = time.time()
            root, nodes, idx, stats = compare_original_vs_binary_tree_distances(
                df,
                num_clusters=k,
                clustering_method=clustering_method,
                max_swaps=max_swaps,
                feature_cols=[f'f{i+1}' for i in range(no_of_features)],
                group_cols=[f'group{i+1}' for i in range(t)],
                is_facility_col='is_facility',
                capacity_col='capacity',
                local_search_seed=local_search_seed,
                compute_stats=True,
            )
            embed_time = time.time() - embed_start

            dp_start = time.time()
            dp_cost, dp_iters = binary_tree_dp_exact(root, n, t, k, alpha, beta)
            dp_time = time.time() - dp_start

            bf_start = time.time()
            if run_brute_force:
                result = brute_force_capacitated_k_median(
                    df=df,
                    k=k,
                    alpha=alpha,
                    beta=beta,
                    feature_cols=[f'f{i+1}' for i in range(no_of_features)],
                )
            bf_time = time.time() - bf_start

            dp_fields = (
                f"{len(nodes):4d}",
                f"{df_time:5.2f}",
                f"{embed_time:6.2f}",
                f"{stats['mean_orig_dist']:8.3f}",
                f"{stats['mean_bin_dist']:8.3f}",
                f"{dp_time:8.2f}",
                f"{dp_cost:8.2f}",
            )
            output.write(" ".join(dp_fields) + " ")

            if run_brute_force:
                bf_fields = (f"{bf_time:8.2f}", f"{result['best_cost']:8.2f}")
            else:
                bf_fields = (f"{'na':>8s}", f"{'na':>8s}")
            output.write(" ".join(bf_fields) + " \n")
            output.flush()

    output.write(rule + "\n")
    output.flush()


######################################################################
# Brute-force scalability experiments

def scalability_brute_force_nn():
    """Run the brute-force solver for increasing n and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    algo_type = "Brute_force"
    seed = 123456789
    nn = [30, 40, 60, 80]
    tt = [3]
    kk = [5]
    no_of_trials  = 5
    facility_prob = 0.5

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_brute_force_nn_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output)

def scalability_brute_force_kk():
    """Run the brute-force solver for increasing k and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    algo_type = "Brute_force"
    seed = 123456789
    nn = [255]
    tt = [3]
    kk = [3, 4, 5, 6]
    no_of_trials  = 5
    facility_prob = 0.5

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_brute_force_kk_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output)

def scalability_brute_force_tt():
    """Run the brute-force solver for increasing t and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    algo_type = "Brute_force"
    seed = 123456789
    nn = [127]
    tt = [3, 4, 5]
    kk = [5]
    no_of_trials  = 5
    facility_prob = 0.5

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_brute_force_tt_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output)


######################################################################
def scalability_dp_exact_nn():
    """Run the exact DP solver for increasing n and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    algo_type = "DP_exact"
    seed = 123456789
    nn = [63, 127, 255, 511, 1023]
    tt = [3]
    kk = [6]
    no_of_trials  = 5
    facility_prob = 0.5

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_dp_exact_nn_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output)

def scalability_dp_exact_kk():
    """Run the exact DP solver for increasing k and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    algo_type = "DP_exact"
    seed = 123456789
    nn = [255]
    tt = [4]
    kk = [3, 4, 5, 6, 7, 8, 9, 10]
    no_of_trials  = 5
    facility_prob = 0.5

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_dp_exact_kk_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output)

def scalability_dp_exact_tt():
    """Run the exact DP solver for increasing t and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    algo_type = "DP_exact"
    seed = 123456789
    # NOTE(review): the other presets use n = 2^m - 1 (63, 127, 255, ...);
    # 128 may be intentional here -- confirm.
    nn = [128]
    tt = [3, 4, 5, 6, 7]
    kk = [5]
    no_of_trials  = 5
    facility_prob = 0.5

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_dp_exact_tt_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output)

#####################################################################
def scalability_dp_pruning_nn():
    """Run the pruning DP solver for increasing n and save a timestamped log.

    ``max_num_states = -1`` is the sentinel that makes ``scalability`` derive
    the state budget from the instance size; the raw value is also embedded
    in the file name.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    algo_type = "DP_pruning"
    seed = 123456789
    nn = [63, 127, 255, 511, 1023]
    tt = [3]
    kk = [6]
    no_of_trials  = 5
    facility_prob = 0.5
    max_num_states = -1

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_dp_pruning_nn_st{max_num_states}_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability(algo_type, seed, nn, tt, kk, no_of_trials, facility_prob, output, max_num_states)


#####################################################################
def scalability_tree_embedding():
    """Run the tree-embedding experiment for increasing n and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    input_type = "dataframe"
    clustering_method = "k-median"
    seed = 123456789
    nn = [30, 40, 60, 80, 100, 200, 400, 800, 1600, 3200]
    tt = [3]
    kk = [5]
    no_of_trials   = 5
    no_of_features = 10
    facility_prob = 0.5
    max_swaps = 10000

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_tree_embedding_nn_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability_embedding(input_type,
                              clustering_method,
                              max_swaps,
                              seed,
                              nn,
                              tt,
                              kk,
                              no_of_features,
                              no_of_trials,
                              facility_prob,
                              output)

def scalability_tree_embedding_plus_DP_nn():
    """Run the embedding + exact DP experiment for increasing n and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator.
    """
    input_type = "dataframe"
    clustering_method = "k-median"
    seed = 123456789
    # nn = [20, 30, 40, 60, 80, 100, 200, 400, 800, 1600]
    nn = [20, 30, 40, 60, 80]
    tt = [3]
    kk = [5]
    no_of_trials   = 5
    no_of_features = 10
    facility_prob  = 0.5
    max_swaps = 10000

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_tree_embedding_plus_DP_nn_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        scalability_embedding_plus_DP(input_type,
                                      clustering_method,
                                      max_swaps,
                                      seed,
                                      nn,
                                      tt,
                                      kk,
                                      no_of_features,
                                      no_of_trials,
                                      facility_prob,
                                      output)

######################################################################
def realworld_data(input_type, 
                          clustering_method, 
                          max_swaps,
                          datasets,
                          tt,
                          kk,
                          seed, 
                          facility_prob, 
                          output):
    """Embed named datasets into a binary tree, run the exact DP and log results.

    For every ``(dataset, t, k)`` in the product of ``datasets``, ``tt`` and
    ``kk`` the data frame is loaded with ``get_dataset_df``, embedded with
    ``compare_original_vs_binary_tree_distances`` and solved with
    ``binary_tree_dp_exact``. One row per combination is written to
    ``output``, in three flushed pieces so that progress is visible while the
    embedding and the DP are running.

    Args:
        input_type: Label echoed in the preamble only.
        clustering_method: Forwarded to the embedding routine.
        max_swaps: Forwarded to the embedding routine.
        datasets: Iterable of dataset names understood by ``get_dataset_df``.
        tt: Iterable of group counts (t).
        kk: Iterable of center counts (k).
        seed: Seeds the local RNG (local-search seeds) and is passed to
            ``get_dataset_df``; it is also the value logged in the ``seed``
            column.
        facility_prob: Forwarded to ``get_dataset_df``.
        output: Writable text stream providing ``write`` and ``flush``.
    """
    l_random = Random()
    l_random.seed(seed)

    # print input parameters
    output.write("# Scalability experiments\n")
    output.write(f"# Input type: {input_type}\n")
    output.write(f"# Clustering_method: {clustering_method}\n")
    output.write(f"# Max swaps: {max_swaps}\n")
    output.write(f"# Date: {datetime.now()}\n")
    output.write(f"# Machine: {platform.node()}\n")
    output.write(f"# Seed: {seed}\n")
    output.write(f"# Datasets: {datasets}\n")
    output.write(f"# Number of groups (t): {tt}\n")
    output.write(f"# Number of centers (k): {kk}\n")
    output.write(f"# Facility probability: {facility_prob}\n")
    output.write("#\n")
    output.flush()

    header = (
        f"{'seed':>10s} "
        f"{'dataset':>12s} {'n':>5s} {'t':>3s} {'k':>3s} "
        f"{'npp':>6s} "
        f"{'n_f':>6s} "
        f"{'df':>4s} "
        f"{'embed':>4s} "
        f"{'DP':>8s} "
        f"{'max_diff':>8s} "
        f"{'avg_orig':>8s} "
        f"{'avg_bin':>8s} "
        f"{'diff_avg':>8s} "
        f"{'DPcost':>8s} "
    )
    output.write(header + "\n")
    output.write("#" + "=" * len(header) + "\n")

    for dataset, t, k in product(datasets, tt, kk):
        # Log the seed that is actually handed to get_dataset_df below,
        # rather than a hard-coded constant.
        input_seed = seed
        local_search_seed = l_random.randint(1, 123456789)

        start_time = time.time()
        df, df_stats = get_dataset_df(n_groups=t, 
                                      dataset_name=dataset,
                                      max_capacity=-1,
                                      facility_probability=facility_prob,
                                      k = k,
                                      seed = seed,
                                      write_to_file=False)
        df_time = time.time() - start_time

        # Instance size and feature count come from the loader's statistics.
        n = df_stats['num_points']
        no_of_features = df_stats['num_features']
        alpha = (1,) * t
        beta = (k,) * t

        output.write(f"{input_seed:10d} {dataset:>12s} {n:5d} {t:3d} {k:3d} ")
        output.flush()
        
        # Embed the dataset into a binary tree and collect distance statistics.
        embed_start = time.time()
        root, nodes, idx, stats = compare_original_vs_binary_tree_distances(df,
                                                          num_clusters=k,
                                                          clustering_method=clustering_method,
                                                          max_swaps=max_swaps,
                                                          feature_cols=[f'f{i+1}' for i in
                                                                        range(no_of_features)],
                                                          group_cols=[f'group{i+1}' for i in
                                                                      range(t)],
                                                          is_facility_col='is_facility',
                                                          capacity_col='capacity',
                                                          local_search_seed=local_search_seed,
                                                          compute_stats=True
                                                        )
        embed_time = time.time() - embed_start

        output.write(
            f"{stats['num_nodes_in_binary_tree']:6d} "
            f"{df_stats['num_facilities']:6d} "
            f"{df_time:4.2f} "
            f"{embed_time:5.2f} "
            )
        output.flush()

        dp_start = time.time()
        dp_cost, dp_iters = binary_tree_dp_exact(root, n, t, k, alpha, beta)
        dp_time = time.time() - dp_start

        output.write(
            f"{dp_time:8.2f} "
            f"{stats['max_abs_diff']:8.3f} "
            f"{stats['mean_orig_dist']:8.3f} "
            f"{stats['mean_bin_dist']:8.3f} "
            # f"{stats['mean_signed_diff']:10.3f} "
            # f"{stats['mean_abs_diff']:10.3f} "
            # f"{stats['std_signed_diff']:10.3f} "
            f"{stats['difference_of_means']:8.3f} "
            f"{dp_cost:8.3f} "
            f"\n"
            )
        output.flush()

    output.write("=" * len(header) + "\n")
    output.flush()

def scalability_realworld_data():
    """Run the real-world dataset experiment and save a timestamped log.

    The output location is the ``results`` option of the ``PATH`` section in
    ``config.ini`` (read relative to the current working directory). The
    value is used as a plain string prefix, so it must end with a path
    separator. The ``t4`` tag in the file name is hard-coded and matches
    ``tt = [4]`` below.
    """
    input_type = "realworld"
    clustering_method = "k-median"
    seed = 123456789
    # datasets = ["heart", "student-mat", "student-por", "NPHA"]
    datasets = ["NPHA"]
    tt = [4]
    kk = [5, 6, 7, 8, 9, 10]
    facility_prob  = 0.3
    max_swaps = 10000

    # read results directory from config file
    config = ConfigParser()
    config.read("config.ini")
    results_dir = config.get("PATH", "results")

    # add date time stamp to filename to avoid overwriting
    filename = results_dir + f"scalability_real_world_data_t4_{datetime.now().strftime('%d-%m-%Y_%H-%M-%S')}.txt"
    # The context manager closes the file even if the experiment raises.
    with open(filename, "w") as output:
        realworld_data(input_type,
                       clustering_method,
                       max_swaps,
                       datasets,
                       tt,
                       kk,
                       seed,
                       facility_prob,
                       output)


#####################################################################
if __name__ == "__main__":
    # Command-line entry point: lets reviewers run only the experiments
    # they are interested in.
    #
    # Examples:
    #   python scalability.py --list
    #   python scalability.py --experiment dp_exact_nn
    #   python scalability.py --experiment dp_pruning_nn
    #   python scalability.py --experiment tree_embedding
    #   python scalability.py --experiment realworld
    #   python scalability.py --all
    #
    import argparse

    # Experiment name (as typed on the command line) -> runner function.
    experiments = {
        "brute_force_nn": scalability_brute_force_nn,
        "brute_force_kk": scalability_brute_force_kk,
        "brute_force_tt": scalability_brute_force_tt,
        "dp_exact_nn": scalability_dp_exact_nn,
        "dp_exact_kk": scalability_dp_exact_kk,
        "dp_exact_tt": scalability_dp_exact_tt,
        "dp_pruning_nn": scalability_dp_pruning_nn,
        "tree_embedding": scalability_tree_embedding,
        "tree_embedding_plus_dp_nn": scalability_tree_embedding_plus_DP_nn,
        "realworld": scalability_realworld_data,
    }
    experiment_names = sorted(experiments)

    # Order in which --all runs the experiments; brute_force_kk and
    # brute_force_tt are left out of the full suite.
    full_suite = (
        "brute_force_nn",
        "dp_exact_nn",
        "dp_exact_kk",
        "dp_exact_tt",
        "dp_pruning_nn",
        "tree_embedding",
        "tree_embedding_plus_dp_nn",
        "realworld",
    )

    parser = argparse.ArgumentParser(
        description="Run scalability experiments and write results to ../results/ (see config.ini)."
    )
    parser.add_argument("--list", action="store_true",
                        help="List available experiments and exit.")
    parser.add_argument("--experiment", choices=experiment_names,
                        help="Run a single experiment.")
    parser.add_argument("--all", action="store_true",
                        help="Run the full suite (may take a long time).")
    args = parser.parse_args()

    # --list takes precedence over --all, which takes precedence over
    # --experiment.
    if args.list:
        for name in experiment_names:
            print(name)
        raise SystemExit(0)

    if args.all:
        for name in full_suite:
            print(f"Running {name}...")
            experiments[name]()
        raise SystemExit(0)

    # Nothing requested: show usage and exit with status 2.
    if args.experiment is None:
        parser.print_help()
        raise SystemExit(2)

    experiments[args.experiment]()
