import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from common import summary
import pandas as pd

# Global matplotlib style for every figure produced by this script.
matplotlib.rcParams.update({'font.size': 14})

# Experiment summary table for the LeNet fine-tuning-gap run.
# NOTE(review): `summary` comes from the project-local `common` module; its
# columns are presumably not guaranteed to be numeric, hence the coercion below.
table = summary('lenet_finetuning_gap')

# Every column that the plotting code reads or does arithmetic on.
# `errors='coerce'` turns unparsable entries into NaN so matplotlib skips
# them instead of raising. `post_finetune_eps` is included because the
# approximation-error panel plots it and adds/subtracts its std column.
NUMERIC_COLUMNS = [
    'Sparsity',
    'post_prune_acc',
    'post_prune_eps',
    'post_prune_eps_std',
    'acc_implied',
    'epsilon_implied',
    'post_finetune_acc',
    'post_finetune_eps',
    'post_finetune_eps_std',
]
for column in NUMERIC_COLUMNS:
    table[column] = pd.to_numeric(table[column], errors='coerce')

def _plot_mean_with_band(ax, x, mean, std, label, color):
    """Plot ``mean`` against ``x`` with a translucent ``mean ± std`` band.

    The lower edge of the band is clipped at 0 so it never dips below zero.
    """
    ax.plot(x, mean, label=label, color=color)
    lower = (mean - std).clip(lower=0)
    upper = mean + std
    ax.fill_between(x, lower, upper, alpha=0.2, color=color)


fig, (axes1, axes2) = plt.subplots(1, 2, figsize=(10, 5))
fig.suptitle("Deterministic Weight Pruning LeNet-300-100")

sparsity = table['Sparsity']

# --- Left panel: accuracy -------------------------------------------------
# The guaranteed accuracy is `acc_implied` scaled by the first row's
# post-prune accuracy. NOTE(review): assumes row 0 is the reference run
# (e.g. the unpruned / lowest-sparsity model) — confirm against `summary`.
guaranteed_acc = table['acc_implied'] * table.iloc[0]['post_prune_acc']

plot1 = axes1.plot(sparsity, guaranteed_acc, label="Guaranteed Accuracy")
plot2 = axes1.plot(sparsity, table['post_prune_acc'], label="Post Prune Accuracy")
plot3 = axes1.plot(sparsity, table['post_finetune_acc'], label="Post Finetune Accuracy")

# Shade the gap between the guaranteed curve exactly as drawn above and the
# observed post-prune accuracy (using the unscaled `acc_implied` here would
# detach the shaded region from the plotted line).
axes1.fill_between(sparsity, guaranteed_acc, table['post_prune_acc'],
                   alpha=0.2, color=plot1[0].get_color())
# Shade the improvement obtained by fine-tuning after pruning.
axes1.fill_between(sparsity, table['post_prune_acc'], table['post_finetune_acc'],
                   alpha=0.2, color=plot2[0].get_color())

# Annotation positions are in data coordinates and were hand-placed.
axes1.annotate("Loose Bounds", xy=(0.3, 0.7))
axes1.annotate("Fine-tuning Gap", xy=(0.9, 0.9), xytext=(0.5, 0.5),
               arrowprops=dict(arrowstyle="->", connectionstyle="arc3"))

axes1.set_xlabel('Sparsity')
axes1.set_ylabel('Accuracy')
axes1.legend(loc='lower right')
axes1.set_title("Accuracy")

# --- Right panel: approximation error -------------------------------------
# Reuse the accuracy-panel colors so each stage keeps one color across panels.
_plot_mean_with_band(axes2, sparsity, table['post_prune_eps'], table['post_prune_eps_std'],
                     label="Post Prune Approx Err", color=plot2[0].get_color())
_plot_mean_with_band(axes2, sparsity, table['post_finetune_eps'], table['post_finetune_eps_std'],
                     label="Post Finetune Approx Err", color=plot3[0].get_color())

axes2.annotate("Fine-tuning Increases \n Approx Error", xy=(0.9, 40), xytext=(0.2, 40),
               arrowprops=dict(arrowstyle="->", connectionstyle="arc3"))

axes2.set_xlabel('Sparsity')
axes2.set_ylabel('Approximation Error')
axes2.set_title("Approximation Error")
axes2.legend()

# Leave the top 5% of the figure free for the suptitle.
fig.tight_layout(rect=[0, 0, 1, 0.95])
fig.savefig('plots_out/lenet_finetuning_gap.png')
