
# %%
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# %%
# Raw results table, as written to ./out/mlrepo12.csv.
df1 = pd.read_csv('./out/mlrepo12.csv')

# Average over seeds: pivot_table's default aggregation is the mean, so every
# column that is not part of the index below is pooled within each
# (dataset, configuration) group. reset_index() then turns the grouping keys
# back into ordinary columns so they can be filtered on.
means = (
    df1
    .pivot_table(
        index=['data_idx', 'aug_params', 'tran_params', 'dr_params', 'head_params'],
        values=['auc', 'bacc', 'acc_bl'],
    )
    .reset_index()
)

# %%

# Analysis settings shared by the cells below.
metric = 'auc'                      # column compared between the two runs
trans = "clr"                       # value of 'tran_params' to keep
head_params = "{'model': 'mlp'}"    # value of 'head_params' to keep (string-encoded dict)

# Restrict the averaged table to the selected transformation.
means = means.loc[means['tran_params'].eq(trans)]

# %%

# Baseline rows: empty augmentation settings, empty 'dr_params', and the
# selected head. The parameter columns hold string-encoded dicts, so an
# "empty" setting is the literal text "{}".
# (Alternative head that was tried here: "{'model': 'ridge'}".)
rows = (
    means['aug_params'].eq("{}")
    & means['dr_params'].eq('{}')
    & means['head_params'].eq(head_params)
)
df2 = means[rows]

# %%

# Augmented rows: same 'dr_params' and head as the baseline selection, but
# with one specific augmentation setting (matched as an exact string).
# Other 'aug_params' values that were swapped in here previously:
#   '{}'
#   "{'conv': 'half', 'space': 'clr'}"
#   "{'conv': 'rand', 'space': 'clr', 'factor': 10}"
#   "{'comb': 'rand', 'space': 'clr', 'factor': 10}"
rows = (
    means['aug_params'].eq(
        "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10}"
    )
    & means['dr_params'].eq('{}')
    & means['head_params'].eq(head_params)
)
df3 = means[rows]

# %%

# Pair each dataset's baseline row (df2) with its augmented row (df3) by
# 'data_idx' instead of by row position. A positional subtraction silently
# compares different datasets (or fails on a shape mismatch) as soon as one
# dataset is missing from either frame.
# The inner join keeps only datasets present in both frames, in df2's order;
# every non-key column exists in both frames and therefore gets a suffix.
# validate= raises if a 'data_idx' occurs more than once on either side.
paired = df2.merge(
    df3,
    on='data_idx',
    suffixes=('_base', '_aug'),
    validate='one_to_one',
)

# Per-dataset change in the chosen metric: augmented minus baseline.
tmp = paired[metric + '_aug'].to_numpy() - paired[metric + '_base'].to_numpy()

# %%

# x: 'acc_bl' of the baseline row of each paired dataset, y: metric change.
# Taking x from the paired frame guarantees both axes describe the same rows.
plt.scatter(paired['acc_bl_base'], tmp)
plt.axhline(0)  # reference line: augmentation made no difference

plt.xlabel('Class imbalance')
plt.ylabel('Difference in ' + metric + ' from data augmentation')

# Titles used for other augmentation settings:
#   "Augmentation with convex combinations"
#   "Augmentation with subcompositions"
plt.title("Aitchison convex combinations")
# Save before show(): showing the figure first can leave an empty canvas to save.
plt.savefig('fig.pdf')
plt.show()

# %%

# Summary across datasets: mean of each metric per
# (aug_params, dr_params, head_params) combination. pivot_table's default
# aggregation is the mean, and all columns outside the index are pooled.
df4 = (
    df1
    .pivot_table(
        index=['aug_params', 'dr_params', 'head_params'],
        values=['acc', 'bacc', 'auc'],
    )
    .reset_index()
)
# Bare expression: shown as the cell output in an interactive session.
# The sorted copy is not stored, so df4 itself keeps its original order.
df4.sort_values('auc')

# df4 = df3.join(df2, left_on='data_idx', right_on='data_idx')


# %%

# %%
