import os

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import FuncFormatter

from common import summary


# Two rows of panels, one column per scale value plotted below:
#   top row    -> samples required vs. guaranteed accuracy
#   bottom row -> epsilon guarantee and observed error vs. guaranteed accuracy
fig, axes = plt.subplots(2, 4, figsize=(15, 10))
table = summary('easy_mode_scaling')

# Columns read by the plots. Coerce them to numbers; anything that cannot be
# parsed becomes NaN (errors='coerce') instead of raising.
# NOTE(review): presumably summary() can return these as strings -- confirm
# in common.summary.
NUMERIC_COLUMNS = [
    'final_samples',
    'final_epsilon',
    'final_sparsity',
    'post_prune_final_eps',
    'post_prune_final_eps_std',
    'post_prune_final_eps_max',
    'FinalAccGuarantee',
    'FinalScale',
]
for column in NUMERIC_COLUMNS:
    table[column] = pd.to_numeric(table[column], errors='coerce')

# One column of panels per scale value (titled alpha in the figure).
for col, scale in enumerate([1, 100, 1000, 10000]):
    # Rows for this scale, ordered by the x-axis value: ax.plot and
    # fill_between connect points in row order, so unsorted rows would
    # produce zig-zag lines.
    rel = table[table['FinalScale'] == scale].sort_values('FinalAccGuarantee')
    accuracy = rel['FinalAccGuarantee']

    # Top panel: samples required versus the guaranteed accuracy.
    ax = axes[0, col]
    ax.set_title(r'$\alpha$ = ' + f'{scale}')
    ax.plot(accuracy, rel['final_samples'], label="Samples Req", color='firebrick')
    ax.set_ylim((0, None))
    if scale == 1:
        # Label this panel's y-axis in thousands, e.g. 12000 -> "12K".
        ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '%1.0fK' % (x * 1e-3)))
    ax.legend()

    # Bottom panel: implied epsilon guarantee and observed approximation error
    # versus guaranteed accuracy, with a +/- one-std band around the average.
    ax = axes[1, col]
    ax.plot(accuracy, rel['final_epsilon'], label="Implied Epsilon Guarantee", color='green')
    mean_err = rel['post_prune_final_eps']
    std_err = rel['post_prune_final_eps_std']
    (avg_line,) = ax.plot(accuracy, mean_err, label="Avg Observed Error", color='darkmagenta')
    # Clamp the band at zero (presumably because the error is non-negative);
    # NaN entries are left as NaN, so they appear as gaps.
    lower = (mean_err - std_err).clip(lower=0)
    upper = mean_err + std_err
    ax.fill_between(accuracy, lower, upper, alpha=0.2, color=avg_line.get_color())
    # NOTE: 'post_prune_final_eps_max' (max observed error) is coerced to numeric
    # but not plotted. To add it, plot it on `ax`, not on the `axes` array.
    ax.legend()

fig.tight_layout()
# savefig does not create missing directories; make sure the output folder exists.
os.makedirs('plots_out', exist_ok=True)
fig.savefig('plots_out/easy_mode_scaling.png')
