
# %%
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt


###################################
# Embed fonts as TrueType (Type 42) rather than Type 3: Type 3 fonts break
# the ICML submission PDF.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
###################################

###################################
# Colour palette (hex values match ColorBrewer "Set1"); unused ones are
# left commented out.
blue = "#377eb8"
purple = "#984ea3"
orange = "#ff7f00"
brown = "#a65628"
pink = "#f781bf"
grey = "#999999"
# green = "#4daf4a"
# red = "#e41a1c"
# yellow = "#ffff33"
####################################

# Global text sizing for every figure produced by this script.
font = {'weight': 'normal', 'size': 13}
matplotlib.rcParams.update({
    'xtick.labelsize': 13,
    'ytick.labelsize': 13,
    'font.weight': font['weight'],
    'font.size': font['size'],
})
# plt.ylim(-45,15)


# %%
df1 = pd.read_csv('./out/mlrepo12.csv')

# %%
# Column of the results table that gets plotted.
metric = 'auc'
# Serialized transform parameters to keep; the string must match the CSV's
# `tran_params` values exactly.
# tran_param = "{'space': 'clr'}"
# tran_param = "{'space': 'clr', 'std': True}"
tran_param = "{'space': 'prop'}"
# tran_param = "{'space': 'prop', 'std': True}"

# df1 = df1[df1['seed'] < 10]

# Average over seeds: every column not used as a grouping key (presumably
# 'seed') is collapsed by taking the mean of 'auc' and 'bacc'.
seed_group_keys = ['data_idx', 'aug_params', 'tran_params', 'dr_params', 'head_params']
seed_avg = df1.pivot_table(values=['auc', 'bacc'], index=seed_group_keys, aggfunc='mean')
means = seed_avg.reset_index()
means = means.loc[means['tran_params'] == tran_param]
# Quick diagnostic: show the rows for one specific head configuration.
print(means.loc[means['head_params'] == "{'model': 'maml', 'aug': 'aitch'}"])
# %%
# Sweep to plot. `x` holds the x-axis positions and `y_idx` the matching
# `aug_params` strings, element for element (so their lengths must agree);
# "{}" is the no-augmentation baseline. The strings must match the CSV's
# serialized `aug_params` exactly. Only the ACTIVE sweep further down is
# executed; the alternatives are kept as commented-out toggles.

# Sweep over 'factor' with {'comb': 'rand', 'space': 'clr'}.
# NOTE(review): this list has 5 entries and was never paired with an `x` of
# matching length — define one before enabling it.
# y_idx = [
#     "{}",
#     "{'comb': 'rand', 'space': 'clr', 'factor': 2}",
#     "{'comb': 'rand', 'space': 'clr', 'factor': 5}",
#     "{'comb': 'rand', 'space': 'clr', 'factor': 10}",
#     "{'comb': 'rand', 'space': 'clr', 'factor': 20}",
# ]

# Sweep over 'factor' with {'conv': 'rand', 'space': 'clr'}.
# x = [1, 2, 4, 6, 8, 10]
# y_idx = [
#     "{}",
#     "{'conv': 'rand', 'space': 'clr', 'factor': 2}",
#     "{'conv': 'rand', 'space': 'clr', 'factor': 4}",
#     "{'conv': 'rand', 'space': 'clr', 'factor': 6}",
#     "{'conv': 'rand', 'space': 'clr', 'factor': 8}",
#     "{'conv': 'rand', 'space': 'clr', 'factor': 10}",
# ]

# Sweep over 'weight' at factor 10 with {'mult': True, 'space': ''}.
# y_idx = [
#     "{}",
#     "{'mult': True, 'space': '', 'weight': 0.2, 'factor': 10}",
#     "{'mult': True, 'space': '', 'weight': 0.4, 'factor': 10}",
#     "{'mult': True, 'space': '', 'weight': 0.6, 'factor': 10}",
#     "{'mult': True, 'space': '', 'weight': 0.8, 'factor': 10}",
# ]

# Two-point comparisons: baseline vs. a single configuration.
# x = [0, 0.5]
# y_idx = [
#     "{}",
#     "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10}",
# ]
# y_idx = [
#     "{}",
#     "{'conv': 'rand', 'space': 'clr', 'weight': 0.5, 'factor': 10}",
# ]

# Sweep over 'factor' with {'conv': 'rand', 'space': 'prop'}.
# y_idx = [
#     "{}",
#     "{'conv': 'rand', 'space': 'prop', 'factor': 2}",
#     "{'conv': 'rand', 'space': 'prop', 'factor': 4}",
#     "{'conv': 'rand', 'space': 'prop', 'factor': 6}",
#     "{'conv': 'rand', 'space': 'prop', 'factor': 8}",
#     "{'conv': 'rand', 'space': 'prop', 'factor': 10}",
# ]

# Sweep over 'factor' with {'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop'}.
# y_idx = [
#     "{}",
#     "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 1.33}",
#     "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 2}",
#     "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 2.66}",
#     "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 3.33}",
#     "{'mult': True, 'conv': 'rand', 'comb': 'rand', 'space': 'prop', 'factor': 4}",
# ]

# Sweep over 'factor' with {'mult': True, 'space': ''}.
# y_idx = [
#     "{}",
#     "{'mult': True, 'space': '', 'factor': 2}",
#     "{'mult': True, 'space': '', 'factor': 4}",
#     "{'mult': True, 'space': '', 'factor': 6}",
#     "{'mult': True, 'space': '', 'factor': 8}",
#     "{'mult': True, 'space': '', 'factor': 10}",
# ]

# Sweep over 'factor' with {'conv': 'rand', 'comb': 'rand', 'space': 'clr'}.
# y_idx = [
#     "{}",
#     "{'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 2}",
#     "{'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 5}",
#     "{'conv': 'rand', 'comb': 'rand', 'space': 'clr', 'factor': 10}",
# ]

# ACTIVE sweep: 'weight' at factor 10 with {'conv': 'rand', 'space': 'clr'}.
x = [0, 0.2, 0.4, 0.6, 0.8]
y_idx = [
    "{}",
    "{'conv': 'rand', 'space': 'clr', 'weight': 0.2, 'factor': 10}",
    "{'conv': 'rand', 'space': 'clr', 'weight': 0.4, 'factor': 10}",
    "{'conv': 'rand', 'space': 'clr', 'weight': 0.6, 'factor': 10}",
    "{'conv': 'rand', 'space': 'clr', 'weight': 0.8, 'factor': 10}",
]

# Model heads compared in each panel (serialized `head_params` strings).
head_params = [
    # "{'model': 'mlp'}",
    "{'model': 'rf'}",
    "{'model': 'deepcoda'}",
    # "{'model': 'nn'}",
    "{'model': 'xgb'}",
    # "{'model': 'ridge'}",
    "{'model': 'metann'}",
]

# Set size of subplots
# plt.rcParams['figure.figsize'] = [12, 6]


# Loop over datasets and draw one subplot per dataset. Datasets 2 and 4 are
# skipped, which leaves exactly 10 panels for the 2x5 grid.
fig, axs = plt.subplots(2, 5, figsize=(12, 6))
axs = axs.flatten()
skipped_datasets = (2, 4)
# Minimum vertical span of a panel: narrower data ranges are padded
# symmetrically up to this span so tiny differences are not blown up.
min_yspan = 0.1
plot_idx = 0
for data_idx in range(12):
    if data_idx in skipped_datasets:
        continue
    ax = axs[plot_idx]
    tmp_df = means[means['data_idx'] == data_idx].set_index("aug_params")
    # Fetch every head's curve before drawing anything, so a missing
    # (dataset, head, aug_params) combination leaves the panel empty instead
    # of half-drawn. `.loc` raises KeyError for labels absent from the index;
    # any other error is a real bug and is allowed to propagate.
    try:
        curves = [
            (head, tmp_df[tmp_df['head_params'] == head].loc[y_idx][metric])
            for head in head_params
        ]
    except KeyError:
        print("Data idx", data_idx, "not found")
    else:
        for head, y in curves:
            ax.plot(x, y, label=head)
        ymax = max(series.max() for _, series in curves)
        ymin = min(series.min() for _, series in curves)
        yrange = ymax - ymin
        if yrange < min_yspan:
            pad = (min_yspan - yrange) / 2
            ax.set_ylim(ymin - pad, ymax + pad)
        ax.set_title(f"Dataset {plot_idx + 1}")
    plot_idx += 1

axs[0].set_ylabel("Test " + metric)
axs[5].set_ylabel("Test " + metric)
# NOTE(review): this label fits sweeps over the augmentation 'factor'; when
# `x`/`y_idx` sweep another parameter (e.g. 'weight'), it should change too.
axs[7].set_xlabel("Data augmentation factor")

# Custom legend handles (only used by the commented-out legend call below).
from matplotlib.lines import Line2D
custom_lines = [
    Line2D([0], [0], lw=2, color='blue'),
    Line2D([0], [0], lw=2, color='orange'),
]

# plt.legend(custom_lines, ["MetaNN", "clr-RF"], loc='lower center')
# Legend goes on pyplot's current axes, labelled with the raw head strings.
plt.legend()

fig.tight_layout()

# save figure
fig.savefig("./out/mlrepo12.pdf")

plt.show()

# Deliberate hard stop: when the file is executed top to bottom, execution ends
# here, right after the grid figure has been saved and shown. The cells below
# are only reached when running cells individually (via the `# %%` markers).
raise ValueError("stop")



# Average over datasets: 'data_idx' is not a grouping key, so each
# configuration's per-dataset scores collapse into a single mean row.
dataset_avg_keys = ['aug_params', 'tran_params', 'dr_params', 'head_params']
means = means.pivot_table(
    values=['auc', 'bacc'],
    index=dataset_avg_keys,
    aggfunc='mean',  # pivot_table's default, spelled out for clarity
).reset_index()

# The 40 configurations with the highest AUC (ascending order, best last).
print(means.sort_values('auc').iloc[-40:])

# %%

# Keep only the selected input transform. `tran_params` holds the serialized
# parameter dict (e.g. "{'space': 'prop'}"), not a bare space name, so
# comparing against "clr" always produced an empty frame. `means` was already
# restricted to `tran_param` right after averaging over seeds, so this reuses
# that value and acts as a consistency check.
trans = tran_param
means = means[means['tran_params'] == trans]
if means.empty:
    raise ValueError(f"No results for tran_params == {trans!r}")

# print(means.sort_values('auc'))


# %%


# Dataset-averaged curve for the MLP head, labelled 'NN'.
# A dedicated name keeps the module-level `head_params` list (iterated by the
# per-dataset grid plot) intact when cells are re-run out of order.
head_param = "{'model': 'mlp'}"
means_head = means[means['head_params'] == head_param].set_index("aug_params")
y = means_head.loc[y_idx][metric]
plt.plot(x, y, label='NN')

# %%

# head_params = "{'model': 'mlp', 'early': True}"
# means_head = means[means['head_params'] == head_params].set_index("aug_params")
# y = means_head.loc[y_idx][metric]
# plt.plot(x, y)

# %%

# Dataset-averaged curve for the random-forest head ('rf'), labelled 'RF'.
# A dedicated name keeps the module-level `head_params` list (iterated by the
# per-dataset grid plot) intact when cells are re-run out of order.
head_param = "{'model': 'rf'}"
means_head = means[means['head_params'] == head_param].set_index("aug_params")
y = means_head.loc[y_idx][metric]
plt.plot(x, y, label='RF')
# %%

# head_params = "{'model': 'svm'}"
# means_head = means[means['head_params'] == head_params].set_index("aug_params")
# y = means_head.loc[y_idx][metric]
# plt.plot(x, y, label='SVM')

# %%

# Dataset-averaged curve for the ridge head, labelled 'CLR-ridge'.
# A dedicated name keeps the module-level `head_params` list (iterated by the
# per-dataset grid plot) intact when cells are re-run out of order.
head_param = "{'model': 'ridge'}"
means_head = means[means['head_params'] == head_param].set_index("aug_params")
y = means_head.loc[y_idx][metric]
plt.plot(x, y, label='CLR-ridge')

# %%
# Human-readable names for the metric columns of the results table.
metric_names = {'auc': 'AUC', 'bacc': 'Balanced Accuracy'}

plt.legend()
# NOTE(review): this label fits sweeps over the augmentation 'factor'; when
# `x` holds another swept parameter (e.g. 'weight'), update it — confirm.
plt.xlabel("Augmentation factor (*n)")
# The y-label and file name follow `metric`, so they always describe the
# column that was actually plotted (for metric == 'bacc' this yields the
# same "Average test Balanced Accuracy" label and "bacc.pdf" file).
plt.ylabel(f"Average test {metric_names.get(metric, metric)}")
plt.savefig(f"{metric}.pdf")
plt.show()
