from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import FuncFormatter

from common import summary


# 6 rows x 3 columns: one column per difficulty (easy / moderate / difficult),
# two consecutive rows per prune type. Columns share x, rows share y.
fig, axes = plt.subplots(6, 3, figsize=(15, 15), sharex='col', sharey='row')
table = summary('lenet_easy_mode')

# Coerce the easy-case columns used below to numbers; entries that cannot be
# parsed become NaN (errors='coerce') instead of raising.
for column in ('final_samples', 'final_epsilon', 'final_sparsity',
               'post_prune_final_eps', 'post_prune_final_eps_std',
               'post_prune_final_eps_max', 'FinalAccGuarantee'):
    table[column] = pd.to_numeric(table[column], errors='coerce')

# Row-group captions, placed in figure coordinates to the left of the grid.
fig.text(0.01, 0.9, "Data Ind Coreset\n(Layer Size = 100)", fontsize=12)
fig.text(0.01, 0.55, "Data Dep Coreset\n(Layer Size = 1K)", fontsize=12)
fig.text(0.01, 0.25, "Data Dep\nDeterministic\n(Layer Size = 1K)", fontsize=12)

# Left column: the easy case. Each prune type fills two consecutive rows.
row = 0
axes[0, 0].set_title('Easy Case')
for final_prune_type in ['DataIndCoreset', 'DataDepCoreset', 'DataDepDet']:
    # Runs for this prune type, ordered by the x variable so each line is drawn
    # left-to-right rather than in whatever row order summary() returned.
    rel = table[table['FinalPruneType'] == final_prune_type].sort_values(
        'FinalAccGuarantee', kind='stable')

    # Upper panel: samples required (left axis) and sparsity achieved (right
    # axis) versus the guaranteed accuracy.
    ax = axes[row, 0]
    row += 1

    samples_color = 'firebrick'
    samples_line = ax.plot(rel['FinalAccGuarantee'], rel['final_samples'],
                           label="Samples Req", color=samples_color)
    # Colour the left tick labels like their line, as the other columns do
    # (with sharey='row' only this column's left labels are visible).
    ax.tick_params(axis='y', labelcolor=samples_color)
    if final_prune_type != 'DataDepDet':
        # Label sample counts in thousands.
        ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '%1.0fK' % (x * 1e-3)))
    # NOTE(review): set_ylim switches off y-autoscaling for every axis sharing
    # this row (sharey='row'), so these limits also bound the moderate and
    # difficult columns -- confirm their curves are not clipped, especially in
    # the (0, None) case, whose top is taken from the easy-case data alone.
    if final_prune_type == 'DataDepCoreset':
        ax.set_ylim((0, 2000000))
    elif final_prune_type == 'DataDepDet':
        ax.set_ylim((0, 1010))
    else:
        ax.set_ylim((0, None))

    sparsity_color = 'tab:blue'
    ax2 = ax.twinx()
    sparsity_line = ax2.plot(rel['FinalAccGuarantee'], rel['final_sparsity'],
                             color=sparsity_color, label="Sparsity Achieved")
    ax2.tick_params(axis='y', labelcolor=sparsity_color)
    ax2.set_ylim((0, 1))
    ax2.set_yticklabels([])  # the sparsity scale is not labelled in this column

    # A single legend covering the lines of both y axes.
    lines = samples_line + sparsity_line
    ax.legend(lines, [line.get_label() for line in lines])

    # Lower panel: implied epsilon guarantee and mean observed error
    # (+/- one standard deviation, floored at 0) versus guaranteed accuracy.
    ax = axes[row, 0]
    row += 1
    if row == axes.shape[0]:  # only the bottom row gets the x-axis label
        ax.set_xlabel("Guaranteed Remaining Accuracy")
    ax.plot(rel['FinalAccGuarantee'], rel['final_epsilon'],
            label="Implied Epsilon Guarantee", color='green')
    plot = ax.plot(rel['FinalAccGuarantee'], rel['post_prune_final_eps'],
                   label="Avg Observed Error", color='darkmagenta')
    lower = (rel['post_prune_final_eps'] - rel['post_prune_final_eps_std']).clip(lower=0)
    upper = rel['post_prune_final_eps'] + rel['post_prune_final_eps_std']
    ax.fill_between(rel['FinalAccGuarantee'], lower, upper,
                    alpha=0.2, color=plot[0].get_color())
    # ax.plot(rel['FinalAccGuarantee'], rel['post_prune_final_eps_max'], label="Max Error")
    ax.legend()

# Middle and right columns: the moderate and difficult cases. Their summaries
# use different column names from the easy case (e.g. 'AccGuarantee',
# 'required_samples', 'PruneType').
for col, plan in enumerate(['lenet_mod_mode', 'lenet_hard_mode'], start=1):
    table = summary(plan)
    # Coerce the columns used below to numbers; unparseable entries become NaN.
    for column in ('required_samples', 'AccGuarantee', 'post_prune_acc',
                   'achieved_sparsity', 'post_prune_eps', 'post_prune_eps_std',
                   'acc_epsilon'):
        table[column] = pd.to_numeric(table[column], errors='coerce')

    row = 0
    axes[0, col].set_title('Moderate Case' if plan == 'lenet_mod_mode' else "Difficult Case")
    for prune_type in ['DataIndCoreset', 'DataDepCoreset', 'DataDepDet']:
        # Runs for this prune type, ordered by the x variable so each line is
        # drawn left-to-right rather than in summary() row order.
        rel = table[table['PruneType'] == prune_type].sort_values(
            'AccGuarantee', kind='stable')

        # Upper panel: samples required (left axis) and sparsity achieved
        # (right axis) versus the guaranteed accuracy.
        ax = axes[row, col]
        row += 1

        samples_color = 'firebrick'
        samples_line = ax.plot(rel['AccGuarantee'], rel['required_samples'],
                               color=samples_color, label="Samples Req")
        if prune_type in ('DataIndCoreset', 'DataDepCoreset'):
            # Label sample counts in thousands.
            ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: '%1.0fK' % (x * 1e-3)))
        ax.tick_params(axis='y', labelcolor=samples_color)

        sparsity_color = 'tab:blue'
        ax2 = ax.twinx()
        sparsity_line = ax2.plot(rel['AccGuarantee'], rel['achieved_sparsity'],
                                 color=sparsity_color, label="Sparsity Achieved")
        ax2.tick_params(axis='y', labelcolor=sparsity_color)
        ax2.set_ylim((0, 1))
        # Only the right-most column labels the sparsity axis.
        if col != axes.shape[1] - 1:
            ax2.set_yticklabels([])

        # A single legend covering the lines of both y axes.
        lines = samples_line + sparsity_line
        ax.legend(lines, [line.get_label() for line in lines])

        # Lower panel: implied epsilon guarantee and mean observed error
        # (+/- one standard deviation, floored at 0) versus guaranteed accuracy.
        ax = axes[row, col]
        row += 1
        if row == axes.shape[0]:  # only the bottom row gets the x-axis label
            ax.set_xlabel("Guaranteed Remaining Accuracy")
        ax.plot(rel['AccGuarantee'], rel['acc_epsilon'],
                label="Implied Eps Guarantee", color="green")
        plot = ax.plot(rel['AccGuarantee'], rel['post_prune_eps'],
                       label="Avg Observed Err", color="darkmagenta")
        lower = (rel['post_prune_eps'] - rel['post_prune_eps_std']).clip(lower=0)
        upper = rel['post_prune_eps'] + rel['post_prune_eps_std']
        ax.fill_between(rel['AccGuarantee'], lower, upper,
                        alpha=0.2, color=plot[0].get_color())
        # ax.plot(rel['AccGuarantee'], rel['post_prune_eps_max'], label="Max Error")
        ax.legend()

fig.tight_layout()
# Leave room on the left for the row-group captions drawn at x=0.01.
fig.subplots_adjust(left=0.15)
# savefig does not create missing directories, so ensure the output one exists.
out_path = Path('plots_out') / 'lenet_easy_med_hard.png'
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path)
