import os

import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import FuncFormatter

from common import summary

matplotlib.rcParams.update({'font.size': 14})

# Load the sparsity-first LeNet results via the project's `common.summary`
# helper; the result is indexed like a pandas DataFrame below.
table = summary('lenet_sparsity_first')

# Columns the plots below read as numbers. errors='coerce' turns any entry
# pd.to_numeric cannot parse into NaN instead of raising.
NUMERIC_COLUMNS = (
    'Sparsity',
    'post_prune_acc',
    'post_prune_eps',
    'post_prune_eps_std',
    'acc_implied',
    'epsilon_implied',
)
for column in NUMERIC_COLUMNS:
    table[column] = pd.to_numeric(table[column], errors='coerce')

# Display title for each pruning method; iteration order is the left-to-right
# column order of the figure.
PRUNE_TITLES = {
    'DataIndCoreset': "Neuron Pruning",
    'DataDepDet': "Deterministic Weight Pruning",
}

# 2x2 grid: top row plots accuracy, bottom row plots approximation error,
# one column per pruning method.
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
fig.suptitle('Sparsity-First Pruning')

for col, (prune_type, title) in enumerate(PRUNE_TITLES.items()):
    acc_ax = axes[0, col]
    err_ax = axes[1, col]
    acc_ax.set_title(title)

    rel = table[table['PruneType'] == prune_type]
    if rel.empty:
        # No rows for this method: leave its panels empty instead of
        # failing on rel.iloc[0] below.
        continue

    # Guaranteed accuracy is scaled by the accuracy in the first row for this
    # method, in the order `summary` returned the rows (taken before sorting
    # so the reference row is unchanged).
    # NOTE(review): presumably this row is the reference/least-pruned run — confirm.
    base_acc = rel.iloc[0]['post_prune_acc']

    # Sort by sparsity so the lines and the shaded band are drawn left to
    # right rather than zig-zagging when rows arrive out of order.
    rel = rel.sort_values('Sparsity', kind='stable')
    sparsity = rel['Sparsity']

    acc_ax.plot(sparsity, rel['acc_implied'] * base_acc, label="Guaranteed Accuracy")
    acc_ax.plot(sparsity, rel['post_prune_acc'], label="Post Prune Accuracy")
    acc_ax.set_xlabel('Sparsity')
    acc_ax.set_ylabel('Accuracy')
    acc_ax.legend()

    err_ax.plot(sparsity, rel['epsilon_implied'], label="Guaranteed Error", color='green')
    err_ax.plot(sparsity, rel['post_prune_eps'], label="Observed Err", color='darkmagenta')
    # Band of observed error +/- its std column; the lower edge is clipped at
    # 0 so the shading never extends below zero.
    lower = (rel['post_prune_eps'] - rel['post_prune_eps_std']).clip(lower=0)
    upper = rel['post_prune_eps'] + rel['post_prune_eps_std']
    err_ax.fill_between(sparsity, lower, upper, alpha=0.2, color='darkmagenta')
    err_ax.set_xlabel('Sparsity')
    err_ax.set_ylabel('Approximation Error')
    err_ax.legend()

OUT_PATH = 'plots_out/lenet_sparsity_first.png'
# savefig does not create missing directories, so ensure the target exists.
os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
fig.savefig(OUT_PATH)
# Release the figure's memory; this script never shows it interactively.
plt.close(fig)
