# Create the figures and LaTeX tables to include in the paper.
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import math

# Render all figure text with LaTeX; the legend labels used below contain
# LaTeX markup such as \textsc{...}.
plt.rcParams.update({
    "text.usetex": True,
})

# Similarity-graph algorithms, in the order they are plotted and tabulated.
# Each name is also the stem of that algorithm's results CSV files.
sim_alg_names = [
    "fc",
    "knn",
    "dynamicCPS",
]
# Line colour used for each similarity-graph algorithm in the figures.
sim_alg_colors = {
    "fc": 'black',
    "knn": 'blue',
    "dynamicCPS": 'red',
}

# Datasets used in the similarity-graph experiments. Each value is the tuple
#   (n, d, a, batch size, num_runs, references, license, info)
# n and d are printed in the dataset tables, batch size is the x-axis step of
# the figures, num_runs is the number of result files read per algorithm,
# references are citation keys placed inside \cite{...}, and info is a LaTeX
# description (a LaTeX `\\` inside it forces a line break in the table cell).
# `a` is unpacked but never used by the code in this file.
# NOTE(review): MNIST images are 28x28 = 784 pixels; confirm that d = 728 is
# intended for 'mnist'.
sim_datasets = {
    'blobs': (20000, 10, 0.1, 1000, 3, ['scikit-learn'], 'BSD',
              'Synthetic clusters from a \\\\ mixture of Gaussian distributions.'),
    'cifar10': (50000, 2048, 0.0001, 1000, 2, ['resnet', 'cifar'],
                '-',
                'ResNet-50 embeddings of images.'),
    'mnist': (70000, 728, 0.000001, 1000, 2, ['lecun_mnist_1998'], 'CC BY-SA 3.0',
              'Images of handwritten digits.'),
}

# Datasets used in the KDE experiments. Each value is the tuple
#   (n, d, a, batch size, num_runs, references, license, info)
# with the same layout as the entries of sim_datasets: n and d are printed in
# the dataset tables, batch size is the x-axis step of the figures, num_runs
# is the number of result files read per algorithm, references are citation
# keys placed inside \cite{...}, and info is a LaTeX description.
# `a` is unpacked but never used by the code in this file.
kde_datasets = {
    'shuttle': (58000, 9, 0.01, 1000, 5, ['shuttle_dataset'], "CC BY 4.0",
                'Numerical data from NASA \\\\ space shuttle sensors.'),
    # 'mnist': (70000, 728, 0.000001, 1000, 5, ['lecun_mnist_1998'], 'CC BY-SA 3.0',
    #           'Images of handwritten digits.'),
    'aloi': (108000, 128, 0.01, 1000, 5, ['aloi_dataset'],
             '-',
             'Images of objects under a variety \\\\ of lighting conditions.'),
    'msd': (515345, 90, 0.000001, 1000, 1, ['msd_uci', 'msd-main-ref'], 'CC BY 4.0',
            'Numerical and categorical \\\\ features of songs.'),
    'covtype': (581012, 54, 0.000005, 1000, 3, ['covtype_uci',
                                                'covtype_dataset'], 'CC BY 4.0',
                'Cartographic features used to predict \\\\ forest cover type.'),
    'glove': (1193514, 100, 0.1, 10000, 1, ['glove_dataset'], 'PDDL 1.0',
              'Word embedding vectors.'),
    'census': (2458285, 68, 0.01, 10000, 2, ['us_census_data'], 'CC BY 4.0',
               'Categorical and numerical data from \\\\ the 1990 US census.'),
}

# Datasets reported in the extended abstract. The dataset-size table writes
# one row per entry (in this order); the results tables use the list as a
# filter on which datasets to include.
datasets_in_extended_abstract = [
    "blobs",
    "mnist",
    "cifar10",
    "aloi",
    "msd",
    "covtype",
    "glove",
    "census",
]

# LaTeX display name of each similarity-graph algorithm (figure legends and
# table headers).
sim_alg_latex_names = {
    "fc": "\\textsc{FullyConnected}",
    "knn": "\\textsc{kNN}",
    "dynamicCPS": "\\textsc{Our Algorithm}",
}

# LaTeX display name of each KDE algorithm (figure legends and table headers).
kde_alg_latex_names = {
    'DynamicExactKDE': '\\textsc{Exact}',
    'NaiveDynamicCKNS': '\\textsc{CKNS}',
    'DynamicRandomSamplingKDE': '\\textsc{DynamicRS}',
    'DynamicCKNS': '\\textsc{Our Algorithm}',
}

# All KDE algorithms, in the order they are read and plotted.
kde_alg_names = [
    "DynamicExactKDE",
    "NaiveDynamicCKNS",
    "DynamicRandomSamplingKDE",
    "DynamicCKNS",
]

# KDE algorithms that get both a "Time (s)" and an "Err" column in the
# results tables.
kde_algs_in_table_full = [
    "NaiveDynamicCKNS",
    "DynamicRandomSamplingKDE",
    "DynamicCKNS",
]

# KDE algorithms that get only a "Time (s)" column in the appendix results
# table (currently none).
kde_algs_in_table_time_only = [
    # "DynamicExactKDE",
]

# Stem of the results CSV file names for each KDE algorithm.
kde_alg_filenames = {
    'DynamicExactKDE': 'exact',
    'NaiveDynamicCKNS': 'naiveCKNS',
    'DynamicRandomSamplingKDE': 'rs',
    'DynamicCKNS': 'dynamicCKNS',
}

# Line colour used for each KDE algorithm in the figures.
kde_alg_colors = {
    'DynamicExactKDE': 'black',
    'NaiveDynamicCKNS': 'blue',
    'DynamicRandomSamplingKDE': 'green',
    'DynamicCKNS': 'red',
}

def get_kde_num_runs(dataset):
    """Return the number of runs configured for the KDE dataset ``dataset``."""
    (_n, _d, _a, _batch_size, runs,
     _refs, _license, _info) = kde_datasets[dataset]
    return runs


def get_sim_num_runs(dataset):
    """Return the number of runs configured for the similarity-graph dataset ``dataset``."""
    (_n, _d, _a, _batch_size, runs,
     _refs, _license, _info) = sim_datasets[dataset]
    return runs


def get_kde_batch_size(dataset):
    """Return the batch size configured for the KDE dataset ``dataset``."""
    (_n, _d, _a, size, _num_runs,
     _refs, _license, _info) = kde_datasets[dataset]
    return size


def get_sim_batch_size(dataset):
    """Return the batch size configured for the similarity-graph dataset ``dataset``."""
    (_n, _d, _a, size, _num_runs,
     _refs, _license, _info) = sim_datasets[dataset]
    return size


def get_kde_running_times_s(dataset):
    """Return the mean total running time of every KDE algorithm on ``dataset``.

    For each algorithm in ``kde_alg_names`` and each configured run, the
    results CSV is read and the largest ``total_time`` value divided by 1000
    (presumably milliseconds to seconds) is taken as that run's running time.
    A run whose file cannot be read is reported on stdout and skipped.

    Returns:
        A pair ``(avg_running_times, two_sigma_errors)`` of dicts keyed by
        algorithm name: the mean over the readable runs, and twice their
        standard deviation (numpy default, i.e. population std). Both values
        are NaN for an algorithm with no readable run.
    """
    all_running_times = {alg: [] for alg in kde_alg_names}
    for alg in kde_alg_names:
        for run in range(get_kde_num_runs(dataset)):
            try:
                results_filename = f"../results/{dataset}_kde/{kde_alg_filenames[alg]}.0{run + 1}.csv"
                data = pd.read_csv(results_filename, skipinitialspace=True)
                all_running_times[alg].append(data.total_time.max() / 1000)
            except Exception as err:
                # Best effort: only this run is skipped. Catching Exception
                # (not a bare except) lets Ctrl-C / SystemExit stop the script,
                # and the cause is printed so a bad column name is not
                # mistaken for a missing file.
                print(f"Couldn't read data for dataset {dataset}, algorithm {alg}, run {run + 1}. "
                      f"({type(err).__name__}: {err})")

    avg_running_times = {}
    two_sigma_errors = {}
    for alg in kde_alg_names:
        run_times = all_running_times[alg]
        # Guard the empty case explicitly: numpy would also return NaN, but
        # only after emitting a RuntimeWarning.
        avg_running_times[alg] = np.mean(run_times) if run_times else np.nan
        two_sigma_errors[alg] = 2 * np.std(run_times) if run_times else np.nan

    return avg_running_times, two_sigma_errors


def get_kde_final_value(dataset, col_name):
    """Return the mean final value of ``col_name`` for every KDE algorithm.

    For each algorithm in ``kde_alg_names`` and each configured run, the
    results CSV is read and the value of ``col_name`` is taken from the first
    row whose ``total_time`` equals the maximum ``total_time`` of the file.
    A run that cannot be read is reported on stdout and skipped.

    Args:
        dataset: key of ``kde_datasets``.
        col_name: name of the CSV column to read (e.g. ``'rel_err'``).

    Returns:
        A pair ``(avg_values, two_sigma_errors)`` of dicts keyed by algorithm
        name: the mean over the readable runs, and twice their standard
        deviation. Both values are NaN for an algorithm with no readable run.
    """
    all_values = {alg: [] for alg in kde_alg_names}

    for alg in kde_alg_names:
        for run in range(get_kde_num_runs(dataset)):
            try:
                results_filename = f"../results/{dataset}_kde/{kde_alg_filenames[alg]}.0{run + 1}.csv"
                data = pd.read_csv(results_filename, skipinitialspace=True)
                final_running_time = data.total_time.max()
                final_result = data[data['total_time'] == final_running_time]
                all_values[alg].append(final_result.iloc[0][col_name])
            except Exception as err:
                # Best effort: only this run is skipped. Catching Exception
                # (not a bare except) lets Ctrl-C / SystemExit stop the script,
                # and the cause is printed so a bad column name is not
                # mistaken for a missing file.
                print(f"Couldn't read data for dataset {dataset}, algorithm {alg}, run {run + 1}. "
                      f"({type(err).__name__}: {err})")

    avg_values = {}
    two_sigma_errors = {}
    for alg in kde_alg_names:
        run_values = all_values[alg]
        # Guard the empty case explicitly: numpy would also return NaN, but
        # only after emitting a RuntimeWarning.
        avg_values[alg] = np.mean(run_values) if run_values else np.nan
        two_sigma_errors[alg] = 2 * np.std(run_values) if run_values else np.nan

    return avg_values, two_sigma_errors


def get_sim_running_times_s(dataset):
    """Return the mean total running time of every similarity-graph algorithm.

    For each algorithm in ``sim_alg_names`` and each configured run, the
    results CSV is read and the largest ``total_time`` value divided by 1000
    (presumably milliseconds to seconds) is taken as that run's running time.
    A run whose file cannot be read is reported on stdout and skipped.

    Returns:
        A pair ``(avg_running_times, two_sigma_errors)`` of dicts keyed by
        algorithm name, in ``sim_alg_names`` order: the mean over the readable
        runs, and twice their standard deviation. Both values are NaN for an
        algorithm with no readable run.
    """
    all_running_times = {alg: [] for alg in sim_alg_names}
    for alg in sim_alg_names:
        for run in range(get_sim_num_runs(dataset)):
            try:
                results_filename = f"../results/{dataset}_sg/{alg}.0{run + 1}.csv"
                data = pd.read_csv(results_filename, skipinitialspace=True)
                all_running_times[alg].append(data.total_time.max() / 1000)
            except Exception as err:
                # Best effort: only this run is skipped. Catching Exception
                # (not a bare except) lets Ctrl-C / SystemExit stop the script,
                # and the cause is printed so a bad column name is not
                # mistaken for a missing file.
                print(f"Couldn't read data for dataset {dataset}, algorithm {alg}, run {run + 1}. "
                      f"({type(err).__name__}: {err})")

    avg_running_times = {}
    two_sigma_errors = {}
    for alg in sim_alg_names:
        run_times = all_running_times[alg]
        # Guard the empty case explicitly: numpy would also return NaN, but
        # only after emitting a RuntimeWarning.
        avg_running_times[alg] = np.mean(run_times) if run_times else np.nan
        two_sigma_errors[alg] = 2 * np.std(run_times) if run_times else np.nan

    return avg_running_times, two_sigma_errors


def get_sim_final_values(dataset, column):
    """Return the mean final value of ``column`` for every similarity-graph algorithm.

    For each algorithm in ``sim_alg_names`` and each configured run, the
    results CSV is read and the value of ``column`` is taken from the first
    row whose ``total_time`` equals the maximum ``total_time`` of the file.
    A run that cannot be read is reported on stdout and skipped.

    Args:
        dataset: key of ``sim_datasets``.
        column: name of the CSV column to read (e.g. ``'ari'`` or ``'nmi'``).

    Returns:
        A pair ``(avg_values, two_sigma_errors)`` of dicts keyed by algorithm
        name: the mean over the readable runs, and twice their standard
        deviation. Both values are NaN for an algorithm with no readable run.
    """
    all_values = {alg: [] for alg in sim_alg_names}

    for alg in sim_alg_names:
        for run in range(get_sim_num_runs(dataset)):
            try:
                results_filename = f"../results/{dataset}_sg/{alg}.0{run + 1}.csv"
                data = pd.read_csv(results_filename, skipinitialspace=True)
                final_running_time = data.total_time.max()
                final_result = data[data['total_time'] == final_running_time]
                all_values[alg].append(final_result.iloc[0][column])
            except Exception as err:
                # Best effort: only this run is skipped. Catching Exception
                # (not a bare except) lets Ctrl-C / SystemExit stop the script,
                # and the cause is printed so a bad column name is not
                # mistaken for a missing file.
                print(f"Couldn't read data for dataset {dataset}, algorithm {alg}, run {run + 1}. "
                      f"({type(err).__name__}: {err})")

    avg_values = {}
    two_sigma_errors = {}
    for alg in sim_alg_names:
        run_values = all_values[alg]
        # Guard the empty case explicitly: numpy would also return NaN, but
        # only after emitting a RuntimeWarning.
        avg_values[alg] = np.mean(run_values) if run_values else np.nan
        two_sigma_errors[alg] = 2 * np.std(run_values) if run_values else np.nan

    return avg_values, two_sigma_errors


def create_dataset_size_info_table(dataset_info,
                                   table_filename,
                                   table_caption,
                                   table_label):
    """Write a LaTeX ``wraptable`` listing n and d of selected datasets.

    One row is written for each name in ``datasets_in_extended_abstract``, in
    that order, so every such name must be a key of ``dataset_info``.

    Args:
        dataset_info: maps a dataset name to the 8-tuple
            ``(n, d, a, batch size, num_runs, references, license, info)``;
            only ``n`` and ``d`` are used here.
        table_filename: path of the ``.tex`` file to write (overwritten).
        table_caption: LaTeX source of the table caption.
        table_label: LaTeX label of the table.

    Raises:
        KeyError: if a name in ``datasets_in_extended_abstract`` is missing
            from ``dataset_info``.
    """
    table_preamble = """\\begin{wraptable}{r}{5.5cm}
  \\caption{""" + table_caption + """ \\label{""" + table_label + """}}
  \\centering
  \\begin{tabular}"""

    table_format_str = "{ccc}\n"
    # Backslashes are written as "\\" throughout: "\m" is not a valid Python
    # escape sequence and only worked through a deprecated fallback.
    table_header = "    \\toprule \n    Dataset & $n$ & $d$ \\\\\n    \\midrule\n"

    rows = []
    for dataset in datasets_in_extended_abstract:
        n, d = dataset_info[dataset][:2]
        rows.append(f"    {dataset} & {n:,} & {d:,} " + "\\\\\n")
    table_content = "".join(rows)

    table_endamble = """    \\bottomrule
  \\end{tabular}
\\end{wraptable}"""

    table_str = table_preamble + table_format_str + table_header + table_content + table_endamble

    with open(table_filename, 'w') as fout:
        fout.write(table_str)


def create_dataset_full_info_table(dataset_info,
                                  table_filename,
                                  table_caption,
                                  table_label):
    """Write a LaTeX table describing every dataset in ``dataset_info``.

    Each dataset gets one row with its name, n, d, license, citation and
    description; consecutive rows are separated by a ``\\midrule``.

    Args:
        dataset_info: maps a dataset name to the 8-tuple
            ``(n, d, a, batch size, num_runs, references, license, info)``.
            ``references`` is a list of citation keys (may be empty, in which
            case the reference cell is left blank) and ``info`` is LaTeX
            source placed inside a nested ``tabular``.
        table_filename: path of the ``.tex`` file to write (overwritten).
        table_caption: LaTeX source of the table caption.
        table_label: LaTeX label of the table.
    """
    table_preamble = """\\begin{table}
  \\caption{""" + table_caption + """}
  \\label{""" + table_label + """}
  \\centering
  \\resizebox{\\textwidth}{!}{\\begin{tabular}"""

    table_format_str = "{cccccc}\n"
    # Backslashes are written as "\\" throughout: "\m" and "\c" are not valid
    # Python escape sequences and only worked through a deprecated fallback.
    table_header = ("    \\toprule \n"
                    "    \\textbf{Dataset} & $\\mathbf{n}$ & $\\mathbf{d}$ & \\textbf{License}"
                    " & \\textbf{Reference} & \\textbf{Description} \\\\\n"
                    "    \\midrule\n")

    rows = []
    for dataset, (n, d, _, _, _, refs, dataset_license, info) in dataset_info.items():
        # "{n: ,}" keeps the leading space of the original number formatting.
        row = f"    {dataset} & {n: ,} & {d: ,} & " + dataset_license + " & "
        if refs:
            row += "\\cite" + "{" + ','.join(refs) + "}"
        row += " & \\begin{tabular}{@{}c@{}}" + info + "\\end{tabular}\\\\\n"
        rows.append(row)
    # A rule between rows, but not after the last one (\bottomrule follows).
    table_content = "\\midrule \n".join(rows)

    table_endamble = """    \\bottomrule
  \\end{tabular}}
\\end{table}"""

    table_str = table_preamble + table_format_str + table_header + table_content + table_endamble

    with open(table_filename, 'w') as fout:
        fout.write(table_str)


def create_dataset_information_tables():
    """Write both dataset tables for the union of the sim and KDE datasets.

    Produces ``full_dataset_info_table.tex`` (full description of every
    dataset) followed by ``dataset_info_table.tex`` (sizes only).
    """
    full_table_datasets = sim_datasets | kde_datasets
    create_dataset_full_info_table(
        full_table_datasets,
        "full_dataset_info_table.tex",
        "Datasets used for experimental evaluation.",
        "tab:full-dataset-info",
    )

    size_table_datasets = sim_datasets | kde_datasets
    create_dataset_size_info_table(
        size_table_datasets,
        "dataset_info_table.tex",
        "Dataset size information.",
        "tab:dataset-sizes",
    )


def create_kde_results_table_ea():
    """Write the extended-abstract KDE results table to ``kde_results_table.tex``.

    One row is written per dataset that is in both ``kde_datasets`` and
    ``datasets_in_extended_abstract`` (``mnist`` is always excluded). For each
    algorithm in ``kde_algs_in_table_full`` the row holds the mean running
    time and the mean final ``rel_err``, each with its two-sigma error. The
    cells of the fastest displayed algorithm are shaded grey.

    The caption is the LaTeX macro ``\\kdecaption``, which must be defined by
    the including document.
    """
    table_filename = "kde_results_table.tex"
    # Backslashes are written as "\\" throughout: "\k", "\c", "\p" and "\m"
    # are not valid Python escape sequences and only worked through a
    # deprecated fallback.
    table_preamble = """\\begin{table} [t]
  \\caption{\\kdecaption \\label{tab:dynamic_kde}}
  \\centering
  \\resizebox{\\textwidth}{!}{
  \\begin{tabular}"""

    algs = kde_algs_in_table_full

    # The column spec and header depend only on the algorithm list, so build
    # them up front; the table is then still well-formed with zero data rows.
    table_format_str = "{c" + "cc" * len(algs) + "}\n    \\toprule\n"
    table_header = " & " + " & ".join(
        "\\multicolumn{2}{c}{" + f"{kde_alg_latex_names[alg]}" + "}" for alg in algs) + "\\\\ \n"
    # Algorithm i spans columns (2 + 2i)-(3 + 2i); column 1 is the dataset name.
    table_header += "    " + " ".join(
        "\\cmidrule(lr){" + f"{2 + 2 * i}-{3 + 2 * i}" + "}" for i in range(len(algs))) + "\n"
    table_header += "    dataset & " + " & ".join("Time (s) & Err" for _ in algs) + "\\\\ \n"
    table_header += "    \\midrule\n"

    table_content = ""
    for dataset in kde_datasets:
        if dataset not in datasets_in_extended_abstract or dataset == 'mnist':
            continue
        times, times_2sigma = get_kde_running_times_s(dataset)
        errors, errors_2sigma = get_kde_final_value(dataset, 'rel_err')

        # Highlight the fastest algorithm among those shown in the table,
        # ignoring NaN times (no readable runs), which would otherwise make
        # min() order-dependent.
        shown_times = [times[alg] for alg in algs if not np.isnan(times[alg])]
        best_time = min(shown_times) if shown_times else None

        table_content += "    " + dataset
        for alg in algs:
            cellcolorstr = "\\cellcolor{gray!25}" if times[alg] == best_time else ""
            table_content += " & " + cellcolorstr + f"${times[alg]: ,.1f}" + "{\\scriptstyle" + f" \\pm {times_2sigma[alg]: .1f}" + "}$"
            # NOTE: a NaN error (no readable runs) also takes the "0" branch.
            if errors[alg] > 0:
                table_content += " & " + cellcolorstr + f"${errors[alg]: .3f}" + "{\\scriptstyle" + f" \\pm {errors_2sigma[alg]: .3f}" + "}$"
            else:
                table_content += " & 0"
        table_content += " \\\\\n"

    table_endamble = """    \\bottomrule
  \\end{tabular}}
\\end{table}"""

    table_str = table_preamble + table_format_str + table_header + table_content + table_endamble

    with open(table_filename, 'w') as fout:
        fout.write(table_str)


def create_kde_results_table_app():
    """Write the appendix KDE results table to ``kde_results_table_app.tex``.

    One row is written per dataset in ``kde_datasets``. Algorithms in
    ``kde_algs_in_table_time_only`` get a single "Time (s)" column; those in
    ``kde_algs_in_table_full`` get "Time (s)" and "Err" columns (mean final
    ``rel_err``), each value with its two-sigma error. The cells of the
    fastest displayed algorithm are shaded grey.

    The caption is the LaTeX macro ``\\kdeappcaption``, which must be defined
    by the including document.
    """
    table_filename = "kde_results_table_app.tex"
    # Backslashes are written as "\\" throughout: "\k", "\c", "\p" and "\m"
    # are not valid Python escape sequences and only worked through a
    # deprecated fallback.
    table_preamble = """\\begin{table}
  \\caption{\\kdeappcaption \\label{tab:dynamic_kde_app}}
  \\centering
  \\resizebox{\\textwidth}{!}{
  \\begin{tabular}"""

    time_only_algs = kde_algs_in_table_time_only
    full_algs = kde_algs_in_table_full
    shown_algs = time_only_algs + full_algs
    # Column 1 is the dataset name; time-only algorithms take one column each.
    first_full_col = len(time_only_algs) + 2

    # The column spec and header depend only on the algorithm lists, so build
    # them up front; the table is then still well-formed with zero data rows.
    table_format_str = "{c" + "c" * len(time_only_algs) + "cc" * len(full_algs) + "}\n    \\toprule\n"

    table_header = "    & " + " & ".join(f"{kde_alg_latex_names[alg]}" for alg in time_only_algs)
    if time_only_algs:
        table_header += " & "
    table_header += " & ".join(
        "\\multicolumn{2}{c}{" + f"{kde_alg_latex_names[alg]}" + "}" for alg in full_algs) + "\\\\ \n"
    table_header += "    " + \
        " ".join("\\cmidrule(lr){" + f"{2 + i}-{2 + i}" + "}"
                 for i in range(len(time_only_algs))) + \
        " ".join("\\cmidrule(lr){" + f"{first_full_col + 2 * i}-{first_full_col + 1 + 2 * i}" + "}"
                 for i in range(len(full_algs))) + "\n"
    table_header += "    dataset & " + " & ".join("Time (s)" for _ in time_only_algs)
    if time_only_algs:
        table_header += " & "
    table_header += " & ".join("Time (s) & Err" for _ in full_algs) + "\\\\ \n"
    table_header += "    \\midrule\n"

    table_content = ""
    for dataset in kde_datasets:
        times, times_2sigma = get_kde_running_times_s(dataset)
        errors, errors_2sigma = get_kde_final_value(dataset, 'rel_err')

        # Highlight the fastest algorithm among those shown in the table,
        # ignoring NaN times (no readable runs), which would otherwise make
        # min() order-dependent.
        shown_times = [times[alg] for alg in shown_algs if not np.isnan(times[alg])]
        best_time = min(shown_times) if shown_times else None

        table_content += "    " + dataset
        for alg in shown_algs:
            cellcolorstr = "\\cellcolor{gray!25}" if times[alg] == best_time else ""
            table_content += " & " + cellcolorstr + f"${times[alg]: ,.1f}" + "{\\scriptstyle" + f" \\pm {times_2sigma[alg]: .1f}" + "}$"
            if alg in full_algs:
                # NOTE: a NaN error (no readable runs) also takes the "0" branch.
                if errors[alg] > 0:
                    table_content += " & " + cellcolorstr + f"${errors[alg]: .3f}" + "{\\scriptstyle" + f" \\pm {errors_2sigma[alg]: .3f}" + "}$"
                else:
                    table_content += " & 0"
        table_content += " \\\\\n"

    table_endamble = """    \\bottomrule
  \\end{tabular}}
\\end{table}"""

    table_str = table_preamble + table_format_str + table_header + table_content + table_endamble

    with open(table_filename, 'w') as fout:
        fout.write(table_str)


def create_kde_exact_results_table():
    """Write the running times of the exact KDE algorithm to ``kde_exact_results_table.tex``.

    One row is written per dataset in ``kde_datasets`` with the mean running
    time of ``DynamicExactKDE`` and its two-sigma error.

    The caption is the LaTeX macro ``\\kdeexactcaption``, which must be
    defined by the including document.
    """
    table_filename = "kde_exact_results_table.tex"
    # Backslashes are written as "\\" throughout: "\k" and "\p" are not valid
    # Python escape sequences and only worked through a deprecated fallback.
    table_preamble = """\\begin{table}
  \\caption{\\kdeexactcaption \\label{tab:dynamic_kde_exact}}
  \\centering
  \\begin{tabular}"""

    # Fixed layout; defined up front so the table is still well-formed when
    # there are no datasets to report.
    table_format_str = "{cc}"
    table_header = "\\toprule\n    Dataset & Running Time \\\\\n \\midrule \n"

    table_content = ""
    for dataset in kde_datasets:
        times, times_2sigma = get_kde_running_times_s(dataset)
        exact_time = times["DynamicExactKDE"]
        exact_error = times_2sigma["DynamicExactKDE"]

        table_content += "    " + dataset
        table_content += " & " + f"${exact_time: ,.1f}" + "{\\scriptstyle" + f" \\pm {exact_error: .1f}" + "}$"
        table_content += " \\\\\n"

    table_endamble = """    \\bottomrule
  \\end{tabular}
\\end{table}"""

    table_str = table_preamble + table_format_str + table_header + table_content + table_endamble

    with open(table_filename, 'w') as fout:
        fout.write(table_str)




def get_kde_values_and_error(dataset: str, alg: str, column_name: str, divide_ms=False,
                             rmse=False):
    """Return the per-row mean and two-sigma error of a column across all runs.

    Reads ``column_name`` from the results CSV of every configured run of
    ``alg`` on ``dataset``. With ``rmse`` set, the square root of each value
    is used (and ``divide_ms`` is ignored); otherwise, with ``divide_ms`` set,
    each value is divided by 1000.

    Returns:
        A pair ``(avg_values, two_sigma_errors)`` of arrays computed along
        the run axis: the mean, and twice the standard deviation.
    """
    per_run_series = []

    for run_number in range(1, get_kde_num_runs(dataset) + 1):
        csv_path = f"../results/{dataset}_kde/{kde_alg_filenames[alg]}.0{run_number}.csv"
        frame = pd.read_csv(csv_path, skipinitialspace=True)
        column = frame[column_name]
        if rmse:
            series = [math.sqrt(value) for value in column]
        elif divide_ms:
            series = [value / 1000 for value in column]
        else:
            series = [value for value in column]
        per_run_series.append(series)

    stacked = np.asarray(per_run_series)
    mean_per_row = np.mean(stacked, axis=0)
    spread_per_row = 2 * np.std(stacked, axis=0)
    return mean_per_row, spread_per_row


def get_sim_values_and_error(dataset: str, alg: str, column_name: str,
                             divide_ms=False):
    """Return the per-row mean and two-sigma error of a column across all runs.

    Reads ``column_name`` from the results CSV of every configured run of
    ``alg`` on ``dataset``; with ``divide_ms`` set, each value is divided by
    1000. A run that cannot be read is reported on stdout and skipped.

    Returns:
        A pair ``(avg_values, two_sigma_errors)`` of arrays computed along
        the run axis: the mean, and twice the standard deviation. Both are
        empty arrays when no run could be read.
    """
    all_values = []

    for run in range(get_sim_num_runs(dataset)):
        try:
            results_filename = f"../results/{dataset}_sg/{alg}.0{run+1}.csv"
            data = pd.read_csv(results_filename, skipinitialspace=True)
            all_values.append([x if not divide_ms else x / 1000 for x in data[column_name]])
        except Exception as err:
            # Best effort: only this run is skipped. Catching Exception (not
            # a bare except) lets Ctrl-C / SystemExit stop the script, and
            # the cause is printed alongside the message.
            print(f"Error reading results for {dataset}, algorithm {alg}, run {run+1}. "
                  f"({type(err).__name__}: {err})")

    if not all_values:
        # Without this guard numpy reduces the empty input to a NaN scalar
        # (with a RuntimeWarning), which has no len() and cannot be plotted.
        return np.array([]), np.array([])

    all_values = np.asarray(all_values)
    avg_values = np.mean(all_values, axis=0)
    two_sigma_errors = 2 * np.std(all_values, axis=0)
    return avg_values, two_sigma_errors


def create_one_kde_figure(dataset: str, ylabel: str, column_name: str, file_name: str, divide_ms=False,
                          rmse=False):
    """Plot one results column of every KDE algorithm against n and save it as a PDF.

    For each algorithm in ``kde_alg_names`` the mean of ``column_name`` over
    all runs is drawn as a line with a shaded band of plus/minus the two-sigma
    error. The x value of row ``i`` (1-based) is ``i`` times the dataset's
    batch size. An algorithm whose data cannot be loaded is reported on stdout
    and left out of the figure.

    Args:
        dataset: key of ``kde_datasets``.
        ylabel: label of the y axis.
        column_name: CSV column to plot.
        file_name: suffix of the output file ``{dataset}_{file_name}.pdf``.
        divide_ms: forwarded to ``get_kde_values_and_error``.
        rmse: forwarded to ``get_kde_values_and_error``.
    """
    fig = plt.figure(figsize=(4, 3))
    plt.xlabel("n")
    plt.ylabel(ylabel)
    plt.grid(True)

    for alg in kde_alg_names:
        try:
            values, errors = get_kde_values_and_error(dataset, alg, column_name, divide_ms=divide_ms,
                                                      rmse=rmse)
            ns = [i * get_kde_batch_size(dataset) for i in range(1, len(values) + 1)]
            plt.plot(ns,
                     values,
                     label=kde_alg_latex_names[alg],
                     linewidth=3, color=kde_alg_colors[alg])
            plt.fill_between(ns,
                             values-errors,
                             values+errors,
                             color=kde_alg_colors[alg],
                             alpha=0.3)
        except Exception as err:
            # Best effort: only this algorithm is left out. Catching
            # Exception (not a bare except) lets Ctrl-C / SystemExit stop the
            # script, and the cause is printed alongside the message.
            print(f"Couldn't read data for dataset {dataset} and algorithm {alg}. "
                  f"({type(err).__name__}: {err})")

    plt.gca().ticklabel_format(axis='x', style='sci', scilimits=(3,3))
    plt.legend()
    plt.savefig(f"{dataset}_{file_name}.pdf", format="pdf", bbox_inches="tight")
    plt.show()
    # pyplot keeps every figure alive until it is closed; this function is
    # called many times in a row, so release the figure explicitly.
    plt.close(fig)


def create_kde_dataset_figures(dataset: str):
    """Create the relative-error, total-time and update-time figures for one KDE dataset."""
    # (y-axis label, CSV column, output file suffix, divide values by 1000?)
    figure_specs = (
        ("Relative Error", "rel_err", "relative_error", False),
        ("Total time (s)", "total_time", "time", True),
        ("Update time (s)", "update_time", "update_time", True),
    )
    for axis_label, column, file_suffix, in_ms in figure_specs:
        create_one_kde_figure(dataset, axis_label, column, file_suffix, divide_ms=in_ms)


def create_sim_results_table_ea():
    """Write the extended-abstract similarity-graph table to ``sim_results_table.tex``.

    One row is written per dataset that is in both ``sim_datasets`` and
    ``datasets_in_extended_abstract``. For each algorithm in ``sim_alg_names``
    the row holds the mean running time and the mean final ``nmi``, each with
    its two-sigma error. The cells of the fastest algorithm are shaded grey.

    The caption is the LaTeX macro ``\\simcaption``, which must be defined by
    the including document.
    """
    table_filename = "sim_results_table.tex"
    # Backslashes are written as "\\" throughout: "\s", "\c", "\p" and "\m"
    # are not valid Python escape sequences and only worked through a
    # deprecated fallback.
    table_preamble = """\\begin{table}[t]
  \\caption{\\simcaption \\label{tab:dynamic_sg}}
  \\centering
  \\resizebox{\\textwidth}{!}{
  \\begin{tabular}"""

    # The running-time dict is keyed by sim_alg_names in order, so the
    # columns can be laid out from that list before any data is read. The
    # table is then still well-formed with zero data rows.
    algs = sim_alg_names
    table_format_str = "{c" + "cc" * len(algs) + "}\n    \\toprule\n"
    table_header = "    & " + " & ".join(
        "\\multicolumn{2}{c}{" + f"{sim_alg_latex_names[alg]}" + "}" for alg in algs) + "\\\\ \n"
    # Algorithm i spans columns (2 + 2i)-(3 + 2i); column 1 is the dataset name.
    table_header += "    " + " ".join(
        "\\cmidrule(lr){" + f"{2 + (2 * i)}-{3 + (2 * i)}" + "}" for i in range(len(algs))) + "\n"
    table_header += "    dataset & " + " & ".join(
        "Time (s) & NMI" for _ in algs) + "\\\\ \n"
    table_header += "    \\midrule\n"

    table_content = ""
    for dataset in sim_datasets:
        if dataset not in datasets_in_extended_abstract:
            continue
        times, time_errs = get_sim_running_times_s(dataset)
        # Only NMI is reported here; ARI is not read, which avoids parsing
        # every results file a third time.
        nmis, nmi_errs = get_sim_final_values(dataset, 'nmi')

        # Ignore NaN times (no readable runs), which would otherwise make
        # min() order-dependent.
        valid_times = [times[alg] for alg in algs if not np.isnan(times[alg])]
        best_time = min(valid_times) if valid_times else None

        table_content += "    " + dataset
        for alg in algs:
            cellcolorstr = "\\cellcolor{gray!25}" if times[alg] == best_time else ""
            table_content += " & " + cellcolorstr + "$" + f"{times[alg]: ,.1f}" + "{\\scriptstyle" + f" \\pm {time_errs[alg]: .1f}" + "}$"
            table_content += " & " + cellcolorstr + f"${nmis[alg]: .3f}" + "{\\scriptstyle" + f" \\pm {nmi_errs[alg]: .3f}" + "}$"
        table_content += " \\\\\n"

    table_endamble = """    \\bottomrule
  \\end{tabular}}
\\end{table}"""

    table_str = table_preamble + table_format_str + table_header + table_content + table_endamble

    with open(table_filename, 'w') as fout:
        fout.write(table_str)


def create_sim_results_table_app():
    """Write the appendix similarity-graph table to ``sim_results_table_app.tex``.

    One row is written per dataset that is in both ``sim_datasets`` and
    ``datasets_in_extended_abstract``. For each algorithm in ``sim_alg_names``
    the row holds the mean running time and the mean final ``ari``, each with
    its two-sigma error. The cells of the fastest algorithm are shaded grey.

    The caption is the LaTeX macro ``\\simcaptionapp``, which must be defined
    by the including document.
    """
    table_filename = "sim_results_table_app.tex"
    # Backslashes are written as "\\" throughout: "\s", "\c", "\p" and "\m"
    # are not valid Python escape sequences and only worked through a
    # deprecated fallback.
    table_preamble = """\\begin{table}
  \\caption{\\simcaptionapp \\label{tab:dynamic_sg_app}}
  \\centering
  \\begin{tabular}"""

    # The running-time dict is keyed by sim_alg_names in order, so the
    # columns can be laid out from that list before any data is read. The
    # table is then still well-formed with zero data rows.
    algs = sim_alg_names
    table_format_str = "{c" + "cc" * len(algs) + "}\n    \\toprule\n"
    table_header = "    & " + " & ".join(
        "\\multicolumn{2}{c}{" + f"{sim_alg_latex_names[alg]}" + "}" for alg in algs) + "\\\\ \n"
    # Algorithm i spans columns (2 + 2i)-(3 + 2i); column 1 is the dataset name.
    table_header += "    " + " ".join(
        "\\cmidrule(lr){" + f"{2 + (2 * i)}-{3 + (2 * i)}" + "}" for i in range(len(algs))) + "\n"
    table_header += "    dataset & " + " & ".join(
        "Time (s) & ARI" for _ in algs) + "\\\\ \n"
    table_header += "    \\midrule\n"

    table_content = ""
    for dataset in sim_datasets:
        if dataset not in datasets_in_extended_abstract:
            continue
        times, time_errs = get_sim_running_times_s(dataset)
        aris, ari_errs = get_sim_final_values(dataset, 'ari')

        # Ignore NaN times (no readable runs), which would otherwise make
        # min() order-dependent.
        valid_times = [times[alg] for alg in algs if not np.isnan(times[alg])]
        best_time = min(valid_times) if valid_times else None

        table_content += "    " + dataset
        for alg in algs:
            cellcolorstr = "\\cellcolor{gray!25}" if times[alg] == best_time else ""
            table_content += " & " + cellcolorstr + "$" + f"{times[alg]: ,.1f}" + "{\\scriptstyle" + f" \\pm {time_errs[alg]: .1f}" + "}$"
            table_content += " & " + cellcolorstr + f"${aris[alg]: .3f}" + "{\\scriptstyle" + f" \\pm {ari_errs[alg]: .3f}" + "}$"
        table_content += " \\\\\n"

    table_endamble = """    \\bottomrule
  \\end{tabular}
\\end{table}"""

    table_str = table_preamble + table_format_str + table_header + table_content + table_endamble

    with open(table_filename, 'w') as fout:
        fout.write(table_str)

def plot_one_sim_figure(dataset: str,
                        y_label: str,
                        data_col: str,
                        divide_ms=False):
    """Plot one results column of every similarity-graph algorithm and save it as a PDF.

    For each algorithm in ``sim_alg_names`` the mean of ``data_col`` over all
    runs is drawn as a line with a shaded band of plus/minus the two-sigma
    error. The x value of row ``i`` (1-based) is ``i`` times the dataset's
    batch size. An algorithm with no readable data is reported on stdout and
    left out of the figure. The figure is saved as
    ``sim_{dataset}_{data_col}.pdf``.

    Args:
        dataset: key of ``sim_datasets``.
        y_label: label of the y axis.
        data_col: CSV column to plot.
        divide_ms: forwarded to ``get_sim_values_and_error``.
    """
    fig = plt.figure(figsize=(4, 3))
    plt.xlabel("n")
    plt.ylabel(y_label)
    plt.grid(True)

    batch_size = get_sim_batch_size(dataset)
    for alg in sim_alg_names:
        values, errors = get_sim_values_and_error(dataset, alg, data_col, divide_ms=divide_ms)
        # With no readable run the loader yields a 0-d NaN (or an empty
        # array); neither can be plotted, and len() of a 0-d value raises.
        if np.ndim(values) == 0 or len(values) == 0:
            print(f"No data to plot for dataset {dataset} and algorithm {alg}.")
            continue
        ns = [i * batch_size for i in range(1, len(values) + 1)]
        plt.plot(ns,
                 values,
                 label=sim_alg_latex_names[alg],
                 linewidth=3, color=sim_alg_colors[alg])
        plt.fill_between(ns,
                         values-errors,
                         values+errors,
                         color=sim_alg_colors[alg],
                         alpha=0.3)

    plt.gca().ticklabel_format(axis='x', style='sci', scilimits=(3,3))
    plt.legend()
    plt.savefig(f"sim_{dataset}_{data_col}.pdf", format="pdf", bbox_inches="tight")
    plt.show()
    # pyplot keeps every figure alive until it is closed; this function is
    # called many times in a row, so release the figure explicitly.
    plt.close(fig)


def create_sim_dataset_figures(dataset: str):
    """Create the time, memory, ARI, NMI and update-time figures for one similarity-graph dataset."""
    # (y-axis label, CSV column, divide values by 1000?)
    figure_specs = (
        ("Total Time (s)", "total_time", True),
        ("Memory", "mem", False),
        ("ARI", "ari", False),
        ("NMI", "nmi", False),
        ("Update Time (s)", "update_time", True),
    )
    for axis_label, column, in_ms in figure_specs:
        plot_one_sim_figure(dataset, axis_label, column, divide_ms=in_ms)


def create_all_sim_figures():
    """Create the figures of every dataset in ``sim_datasets``."""
    for name in sim_datasets:
        create_sim_dataset_figures(name)


def create_all_kde_figures():
    """Create the figures of every dataset in ``kde_datasets``."""
    for name in kde_datasets:
        create_kde_dataset_figures(name)


def main():
    """Generate every figure and LaTeX table, similarity-graph outputs first."""
    steps = (
        create_all_sim_figures,
        create_sim_results_table_ea,
        create_sim_results_table_app,
        create_all_kde_figures,
        create_kde_results_table_ea,
        create_kde_results_table_app,
        create_kde_exact_results_table,
        create_dataset_information_tables,
    )
    for step in steps:
        step()


# Run the full pipeline only when executed as a script.
if __name__ == '__main__':
    main()
