import os

import matplotlib.pyplot as plt
import pandas as pd
from common import summary


# Columns of the summary table that are plotted and therefore must be numeric.
# Values pd.to_numeric cannot parse become NaN (errors='coerce'), which
# matplotlib renders as gaps in the lines.
_NUMERIC_COLUMNS = [
    # Columns used by the L0 ("Easy Case") panel.
    'final_epsilon',
    'post_prune_final_randomized_eps',
    'post_prune_final_randomized_eps_std',
    'post_prune_final_randomized_eps_max',
    'FinalAccGuarantee',
    # Columns used by the DataIndCoreset panels.
    'acc_epsilon',
    'post_prune_randomized_eps',
    'post_prune_randomized_eps_std',
    'post_prune_randomized_eps_max',
    'AccGuarantee',
]


def _plot_panel(ax, rel, x_col, eps_col, obs_prefix, title):
    """Draw one "implied guarantee vs. observed error" panel on *ax*.

    Args:
        ax: matplotlib Axes to draw on.
        rel: subset of the summary table for a single configuration.
        x_col: column used for the x-axis.
        eps_col: column holding the implied epsilon guarantee.
        obs_prefix: name of the mean observed-error column; the columns
            ``obs_prefix + '_std'`` and ``obs_prefix + '_max'`` are also read.
        title: subplot title.
    """
    # Line plots join points in row order, so sort by x to get monotone
    # curves and a well-formed shaded band regardless of input row order.
    rel = rel.sort_values(x_col)
    x = rel[x_col]
    mean = rel[obs_prefix]
    std = rel[f'{obs_prefix}_std']

    ax.plot(x, rel[eps_col], label="Implied Epsilon Guarantee", color='green')
    (mean_line,) = ax.plot(x, mean, label="Avg Observed Error", color='darkmagenta')

    # Shade mean +/- one std; the lower edge is clamped at 0 (observed error is
    # presumably non-negative). clip() leaves NaN values as NaN.
    lower = (mean - std).clip(lower=0)
    upper = mean + std
    ax.fill_between(x, lower, upper, alpha=0.2, color=mean_line.get_color())

    # No explicit colour: uses the next colour from the Axes' property cycle.
    ax.plot(x, rel[f'{obs_prefix}_max'], label="Max Error")
    ax.legend()
    ax.set_title(title)


fig, axes = plt.subplots(1, 3, figsize=(15, 5))
table = summary('lenet_randomized_compare')

for column in _NUMERIC_COLUMNS:
    table[column] = pd.to_numeric(table[column], errors='coerce')

is_coreset = table['PruneType'] == 'DataIndCoreset'

_plot_panel(
    axes[0],
    table[table['PruneType'] == 'L0'],
    x_col='FinalAccGuarantee',
    eps_col='final_epsilon',
    obs_prefix='post_prune_final_randomized_eps',
    title="Easy Case",
)
# NOTE(review): exact float equality on 'L1'; this assumes the column holds
# exactly 0.0001 / 0.0 as floats. Confirm against the summary output,
# otherwise these panels may silently be empty.
_plot_panel(
    axes[1],
    table[is_coreset & (table['L1'] == 0.0001)],
    x_col='AccGuarantee',
    eps_col='acc_epsilon',
    obs_prefix='post_prune_randomized_eps',
    title="Moderate Case",
)
_plot_panel(
    axes[2],
    table[is_coreset & (table['L1'] == 0.0)],
    x_col='AccGuarantee',
    eps_col='acc_epsilon',
    obs_prefix='post_prune_randomized_eps',
    title="Difficult Case",
)

fig.tight_layout()
# savefig does not create missing directories.
os.makedirs('plots_out', exist_ok=True)
fig.savefig('plots_out/lenet_randomized_compare.png')
