import os

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.ticker import FuncFormatter

from common import summary


# Plot the LeNet-300-100 neuron-pruning results (dropout run) from
# summary('lenet_dropout'). The script draws:
#   * left axis:  required_samples vs. AccGuarantee
#   * right axis: achieved_sparsity vs. AccGuarantee, limited to [0, 1]
# The figure is saved to plots_out/lenet_dropout.png.

# Summary columns coerced to numbers; unparsable entries become NaN, and
# matplotlib leaves NaN points out of the line.
NUMERIC_COLUMNS = (
    'required_samples',
    'AccGuarantee',
    'post_prune_acc',
    'achieved_sparsity',
    'post_prune_eps',
    'post_prune_eps_std',
    'acc_epsilon',
)

fig, axes = plt.subplots()
# NOTE(review): assumes summary() returns a pandas DataFrame with the columns
# above -- confirm against common.summary.
table = summary('lenet_dropout')

for column in NUMERIC_COLUMNS:
    table[column] = pd.to_numeric(table[column], errors='coerce')

# plot() joins points in row order, so sort by the x value to avoid a
# zig-zag line when the summary rows are not already ordered. The stable
# sort keeps the original relative order among rows with equal x.
table = table.sort_values('AccGuarantee', kind='mergesort')

axes.set_title('LeNet-300-100 Neuron Pruning With Dropout=0.5')
axes.set_xlabel('Accuracy Guarantee')

color = 'firebrick'
samples_line = axes.plot(table['AccGuarantee'], table['required_samples'], color=color, label="Samples Req")
# Show sample counts in thousands. Use '%g' rather than a fixed zero-decimal
# format so that e.g. 1500 renders as "1.5K", not a rounded "2K".
axes.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f'{x / 1e3:g}K'))
axes.set_ylabel('Required Samples', color=color)
axes.tick_params(axis='y', labelcolor=color)

# Second y-axis sharing the same x-axis for the sparsity curve.
axes2 = axes.twinx()
color = 'tab:blue'
sparsity_line = axes2.plot(table['AccGuarantee'], table['achieved_sparsity'], color=color, label="Sparsity Achieved")
axes2.set_ylabel('Sparsity Achieved', color=color)
axes2.tick_params(axis='y', labelcolor=color)
axes2.set_ylim((0, 1))

# Combine the handles from both axes into one legend; otherwise each axis
# would only list its own line.
lines = samples_line + sparsity_line
axes.legend(lines, [line.get_label() for line in lines])

# Keep the right-hand axis label from being clipped in the saved image.
fig.tight_layout()
# savefig does not create missing directories.
os.makedirs('plots_out', exist_ok=True)
fig.savefig('plots_out/lenet_dropout.png')
plt.close(fig)
