"""
make a [dataset, train_size] table of some value, sort by worst
"""


import pandas as pd
import _hw.utils as u
# Use three decimals for every float that pandas formats via this display option.
pd.set_option("display.float_format", '{:.3f}'.format)


def print_table(df: pd.DataFrame, predictor="lin_reg", predictor2="xgb", metric="kendalltau"):
    df1 = df[df["predictor"] == predictor]
    df2 = df[df["predictor"] == predictor2]

    # only required stats
    of_interest = ["dataset", "seed", "train_size", metric]
    df1 = df1[of_interest]

    # mean data
    data_mean = {}
    sizes = df1["train_size"].unique()

    datasets = df1["dataset"].unique()
    for ds in datasets:
        df12 = df1[df1["dataset"] == ds]
        df_mean = df12.groupby(by="train_size").mean().reset_index()
        data_mean[ds] = df_mean[metric]

        df22 = df2[df2["dataset"] == ds]
        df_mean = df22.groupby(by="train_size").mean().reset_index()
        data_mean[ds] = data_mean[ds].append(df_mean[metric].tail(1))

    df_mean = pd.DataFrame(data_mean).T
    sizes_ext = [s for s in sizes]
    sizes_ext.append(predictor2)
    df_mean.columns = sizes_ext
    df_mean = df_mean.sort_values(sizes_ext[-2], ascending=True)
    print(df_mean.to_latex())


if __name__ == '__main__':
    # Print one table per result set, each comparing lin_reg against xgb.
    for result_name in ("simple_hwnas", "results_transnas"):
        results = u.get_result_data(result_name)
        print_table(results, predictor="lin_reg", predictor2="xgb")
