import os

import pandas as pd

from kal.active_strategies import STRATEGIES, NAME_MAPPINGS, KALS, NAME_MAPPINGS_ABLATION_STUDY, KAL_75, KAL_50, KAL_25

# Result-folder name -> column header used in the LaTeX table.
# Insertion order is the column order of the final table.
dataset_mappings = {
    "xor": "XOR",
    "iris": "Iris",
    "animals": "Animals",
    "cub200": "CUB200",
}
# Datasets to aggregate, in table-column order (derived from the mapping so
# every dataset is guaranteed to have a header).
datasets = list(dataset_mappings)


# Build one column of "mean ± std" LaTeX cells per dataset (KAL strategies
# only), then join all datasets side by side and emit the LaTeX table to
# stdout and to auc_table.txt.
dfs = []
for dataset in datasets:
    result_df = pd.read_pickle(os.path.join(dataset, "results_kal.pkl"))

    # Some result files store the metric under "Accuracy"; copy it to the
    # column name used below.
    if "Accuracy" in result_df:
        result_df['Test Accuracy'] = result_df['Accuracy']

    # Select the metric column *before* aggregating: calling .mean() on the
    # whole frame raises in pandas >= 2.0 if any column is non-numeric
    # (numeric_only defaults to False), and it averages unused columns.
    df_mean = result_df.groupby("Strategy")["Test Accuracy"].mean()

    # Spread across seeds: average within each (Strategy, Seed) first, then
    # take the std of those per-seed means for each strategy.
    df_std = (
        result_df.groupby(["Strategy", "Seed"])["Test Accuracy"]
        .mean()
        .groupby(level="Strategy")
        .std()
    )

    strategies = []
    aucs = []
    for strategy, mean in df_mean.items():
        # Only KAL variants are reported. To drop the ablation variants as
        # well, also skip KAL_25 / KAL_50 / KAL_75 here.
        if strategy not in KALS:
            continue
        # Look the std up by label rather than relying on positional order.
        std = df_std[strategy]
        strategies.append(NAME_MAPPINGS_ABLATION_STUDY[strategy])
        aucs.append(f"${mean:.2f}$ {{\\tiny $\\pm {std:.2f}$ }}")

    df_dataset = pd.DataFrame({
        "Strategy": strategies,
        dataset_mappings[dataset]: aucs,
    }).set_index("Strategy")
    dfs.append(df_dataset)

# Outer join: a strategy missing from some dataset still gets a row (NaN cell).
dfs = pd.concat(dfs, axis=1, join="outer")

# Cells are pre-formatted LaTeX strings, so escaping must stay disabled.
latex_table = dfs.to_latex(float_format="%.2f", escape=False)
print(latex_table)
with open("auc_table.txt", "w", encoding="utf-8") as f:
    f.write(latex_table)
