from common import summary
import matplotlib.pyplot as plt
import math

# Results table produced by the project helper `common.summary`.
# NOTE(review): the meaning of the 'total_sensitivity' argument is defined in
# `common`, not here -- presumably it selects an extra quantity to include in
# the summary; confirm against `common.summary`.
table = summary(
    'total_sensitivity',
)

def samples(total_sensitivity):
    """Return a sample count that scales linearly with ``total_sensitivity``.

    The value is ``(6 + 2*0.5) * total_sensitivity * 2 * ln(4 / 0.5) / 0.5**2``,
    i.e. about ``116.45 * total_sensitivity``. The result is not rounded,
    so a float input yields a float.

    NOTE(review): this looks like a sensitivity-sampling sample-size bound in
    which the ``0.5`` constants stand for accuracy/failure parameters
    (epsilon/delta). Which constant plays which role is not stated here;
    confirm before parameterising them.
    """
    coefficient = 6 + 2 * 0.5
    log_factor = math.log(4 / 0.5)
    # Multiply in the same left-to-right order as the closed-form expression
    # so the floating-point result is bit-for-bit unchanged.
    numerator = coefficient * total_sensitivity * 2 * log_factor
    return numerator / 0.5**2

def plot_sparsity_vs_metric(table, prune_type, metric,
                            out_dir='plots_out/lenet_easy_mode'):
    """Plot post-prune and post-finetune ``metric`` against sparsity and save it.

    Args:
        table: Results table indexed like a pandas DataFrame (assumed from the
            boolean-mask selection below -- TODO confirm). It must have the
            columns ``PruneType``, ``Sparsity``, ``post_prune_<metric>`` and
            ``post_finetune_<metric>``.
        prune_type: Value of the ``PruneType`` column whose rows are plotted.
        metric: Metric suffix used to pick the columns. ``"acc"`` fixes the
            y-limits to (0.5, 1). Any metric containing ``"eps"`` fixes them
            to (0, 10).
        out_dir: Directory the PNG is written to; it is created if missing.

    Returns:
        Path of the written PNG,
        ``<out_dir>/sparsity_vs_<metric>_<prune_type>_pruning.png``.
    """
    import os

    rel = table[table["PruneType"] == prune_type]
    # The y-axis and title used to always read "Dev Acc", which mislabels
    # non-accuracy metrics (e.g. the "eps" ones). Keep that wording only for
    # "acc" and otherwise show the metric name.
    ylabel = "Dev Acc" if metric == "acc" else metric

    fig, axes = plt.subplots()
    axes.set_xlabel("Sparsity")
    axes.set_ylabel(ylabel)
    if metric == "acc":
        axes.set_ylim((0.5, 1))
    if "eps" in metric:
        axes.set_ylim((0, 10))
    axes.plot(rel['Sparsity'], rel[f'post_prune_{metric}'], label="Post-Prune")
    axes.plot(rel['Sparsity'], rel[f'post_finetune_{metric}'], label="Post-Finetune")
    axes.set_title(f"LeNet-300-100 Pruning {ylabel}")
    axes.legend()

    # savefig does not create missing directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, f'sparsity_vs_{metric}_{prune_type}_pruning.png')
    fig.savefig(out_path)
    # Release the figure so repeated calls (e.g. one per metric/prune type)
    # do not accumulate open matplotlib figures.
    plt.close(fig)
    return out_path


# NOTE(review): `prune_type` and `metric` are not defined anywhere in this
# file. Running it directly therefore raises NameError unless the caller
# injects them into the module namespace. Presumably they were loop variables
# in an earlier version -- TODO confirm the intended values (and consider
# looping over them here).
plot_sparsity_vs_metric(table, prune_type, metric)
