
# %%
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt


# %%
df1 = pd.read_csv('./out/mlrepo12.csv')

# %%
# Metric reported in the LaTeX table below (one of the averaged columns).
metric = 'auc'

# Input transformation whose results are reported; uncomment an alternative
# to switch.
# tran_param = "{'space': 'clr'}"
# tran_param = "{'space': 'clr', 'std': True}"
tran_param = "{'space': 'prop'}"
# tran_param = "{'space': 'prop', 'std': True}"

# Optionally restrict to the first ten seeds before averaging.
# df1 = df1[df1['seed'] < 10]

# Average over seeds: rows sharing every one of these configuration keys are
# collapsed to the mean of their 'auc' and 'bacc' values; afterwards only the
# selected transformation is kept.
config_keys = ['data_idx', 'aug_params', 'tran_params', 'dr_params', 'head_params']
means = (
    df1.pivot_table(values=['auc', 'bacc'], index=config_keys, aggfunc='mean')
    .reset_index()
    .loc[lambda frame: frame['tran_params'] == tran_param]
)

# Inspect the MAML rows whose head string carries the 'aitch' augmentation.
print(means.loc[means['head_params'] == "{'model': 'maml', 'aug': 'aitch'}"])




# %%
# Augmentation settings to compare, keyed by a short label.  Each list holds
# `aug_params` strings exactly as they appear in the results CSV: entry 0 is
# the empty-params baseline ("{}") and entry 1 the augmented setting shown
# next to it in the table (only y_idx[0] and y_idx[1] are read below).
aug_sweeps = {
    'comb_factor': [
        "{}",
        "{'comb': 'rand', 'space': 'clr', 'factor': 2}",
        "{'comb': 'rand', 'space': 'clr', 'factor': 5}",
        "{'comb': 'rand', 'space': 'clr', 'factor': 10}",
        "{'comb': 'rand', 'space': 'clr', 'factor': 20}",
    ],
    'conv_factor': [
        "{}",
        "{'conv': 'rand', 'space': 'clr', 'factor': 2}",
        "{'conv': 'rand', 'space': 'clr', 'factor': 4}",
        "{'conv': 'rand', 'space': 'clr', 'factor': 6}",
        "{'conv': 'rand', 'space': 'clr', 'factor': 8}",
        "{'conv': 'rand', 'space': 'clr', 'factor': 10}",
    ],
    'conv_weight': [
        "{}",
        "{'conv': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10}",
        "{'conv': 'rand', 'space': 'clr', 'weight': 0.4, 'factor': 10}",
        "{'conv': 'rand', 'space': 'clr', 'weight': 0.6, 'factor': 10}",
        "{'conv': 'rand', 'space': 'clr', 'weight': 0.8, 'factor': 10}",
    ],
    'mult_weight': [
        "{}",
        "{'mult': True, 'space': '', 'weight': 0.2, 'factor': 10}",
        "{'mult': True, 'space': '', 'weight': 0.4, 'factor': 10}",
        "{'mult': True, 'space': '', 'weight': 0.6, 'factor': 10}",
        "{'mult': True, 'space': '', 'weight': 0.8, 'factor': 10}",
    ],
    'mult_conv_comb_w0.2': [
        "{}",
        "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10}",
    ],
    'conv_w0.5': [
        "{}",
        "{'conv': 'rand', 'space': 'clr', 'weight': 0.5, 'factor': 10}",
    ],
    'comb_w0.6': [
        "{}",
        "{'comb': 'rand', 'space': 'clr', 'weight': 0.6, 'factor': 10}",
    ],
    'mult_w0.6': [
        "{}",
        "{'mult': True, 'space': '', 'weight': 0.6, 'factor': 10}",
    ],
}

# Sweep values, presumably for a plot x-axis (earlier sweeps used
# [1, 2, 4, 6, 8, 10] and [0, 0.2, 0.4, 0.6, 0.8]).
# NOTE(review): `x` is not read by the table code in this script.
x = [0, 0.5]

# Active comparison: select exactly one sweep by key rather than reassigning
# `y_idx` repeatedly.  Copied so the table code cannot mutate the catalogue.
y_idx = list(aug_sweeps['mult_w0.6'])

# Models reported in the table, one (baseline, augmented) column pair each,
# in output order; uncomment entries to include more models.
head_params = [
    # "{'model': 'mlp'}",
    "{'model': 'rf'}",
    # "{'model': 'nn'}",
    "{'model': 'xgb'}",
    "{'model': 'maml'}",
    # "{'model': 'ridge'}",
    "{'model': 'deepcoda'}",
    "{'model': 'metann'}",
]

# Report MAML trained with the 'aitch' augmentation as the augmented MAML
# result: relabel those rows so the table lookup below finds them under head
# "{'model': 'maml'}" with aug_params == y_idx[1].
# The mask is computed once, before either column changes.  `.loc` writes into
# `means` itself; chained indexing (`means[col][mask] = ...`) may only write a
# temporary copy (SettingWithCopyWarning) and has no effect under pandas
# Copy-on-Write.
maml_aitch = means['head_params'] == "{'model': 'maml', 'aug': 'aitch'}"
means.loc[maml_aitch, 'aug_params'] = y_idx[1]
means.loc[maml_aitch, 'head_params'] = "{'model': 'maml'}"

# Round first so the bolding below compares values as displayed (a tie at two
# decimals bolds both cells).
means = means.round(2)

# The table has one row per dataset, data_idx 0 .. n_datasets - 1.
n_datasets = 12


def _metric_value(frame, aug, dataset, head):
    """Return the first `metric` value in *frame* whose aug_params equals *aug*.

    Raises LookupError naming the missing (dataset, head, aug) combination
    instead of the opaque IndexError from indexing an empty selection.
    """
    values = frame.loc[frame['aug_params'] == aug, metric].to_numpy()
    if values.size == 0:
        raise LookupError(
            f"no {metric!r} result for data_idx={dataset}, head={head}, aug={aug}"
        )
    # NOTE(review): if several rows match (e.g. differing dr_params), the
    # first one is used.
    return values[0]


def _latex_cell(value, bold):
    """Format *value* with two decimals, wrapped in LaTeX bold when *bold*."""
    text = "{:.2f}".format(value)
    return '\\textbf{' + text + '}' if bold else text


# `mat` collects the raw values per dataset, alternating baseline / augmented
# for each head, for the mean row that follows.
mat = []
for i in range(n_datasets):
    dataset_rows = means[means['data_idx'] == i]
    row = []
    cells = [str(i + 1)]
    for head in head_params:
        head_rows = dataset_rows[dataset_rows['head_params'] == head]
        val = _metric_value(head_rows, y_idx[0], i, head)
        val_aug = _metric_value(head_rows, y_idx[1], i, head)
        row.extend((val, val_aug))
        # Bold the better of the pair (both on a tie).
        best = max(val, val_aug)
        cells.append(_latex_cell(val, val == best))
        cells.append(_latex_cell(val_aug, val_aug == best))
    mat.append(row)
    print(' & '.join(cells) + '\\\\')

out = '\\midrule'
print(out)

# Column-wise mean over datasets; columns alternate baseline / augmented for
# each reported head.
mat = np.array(mat)
col_means = [np.mean(mat[:, j]) for j in range(mat.shape[1])]

# Bold the better mean of each (baseline, augmented) pair, both on a tie,
# compared at the displayed two-decimal precision.  This is the same rule the
# per-dataset rows use; unconditionally bolding the augmented column would
# contradict it whenever the baseline mean is higher.
cells = ['Mean']
for base_mean, aug_mean in zip(col_means[0::2], col_means[1::2]):
    best = max(round(base_mean, 2), round(aug_mean, 2))
    for col_mean in (base_mean, aug_mean):
        text = "{:.2f}".format(col_mean)
        cells.append('\\textbf{' + text + '}' if round(col_mean, 2) == best else text)

out = ' & '.join(cells) + '\\\\'
print(out)
