import matplotlib.pyplot as plt
import pandas as pd
import os
from common import summary

# Load the summary for the "lenet_easy_mode" experiment (presumably a pandas
# DataFrame, given the column indexing below -- confirm against common.summary).
table = summary("lenet_easy_mode")
# Sparsity is the x-axis of every plot; errors='raise' fails loudly on any
# value that cannot be parsed as a number rather than plotting garbage.
table['Sparsity'] = pd.to_numeric(table['Sparsity'], errors='raise')

# Create the output directory. exist_ok=True tolerates an already-existing
# directory while still surfacing genuine failures (e.g. permission errors),
# which a blanket `except Exception: pass` would hide until savefig fails later.
os.makedirs('plots_out/lenet_easy_mode', exist_ok=True)

# Plot metrics for the first round of pruning (Magnitude & L0 pruning).
# For each metric and prune type, one figure shows the metric against
# sparsity both right after pruning and after fine-tuning.
# Maps each metric column suffix to its human-readable axis/title label.
METRIC_LABELS = {"acc": "Dev Acc", "eps": "Eps", "hidden_eps": "Hidden Eps"}
for metric, metric_label in METRIC_LABELS.items():
    # errors='coerce' turns unparseable entries into NaN; matplotlib leaves a
    # gap in the line at NaN points instead of the script aborting.
    for stage in ("post_prune", "post_finetune"):
        column = f'{stage}_{metric}'
        table[column] = pd.to_numeric(table[column], errors='coerce')
    for prune_type in ["Magnitude", "L0"]:
        # Sort by sparsity so each line connects points left-to-right rather
        # than in whatever order the summary rows happen to be in.
        rel = table[table["PruneType"] == prune_type].sort_values('Sparsity')

        fig, axes = plt.subplots()
        axes.set_xlabel("Sparsity")
        axes.set_ylabel(metric_label)
        if metric == "acc":
            axes.set_ylim((0.5, 1))
        elif "eps" in metric:
            axes.set_ylim((0, 10))
        axes.plot(rel['Sparsity'], rel[f'post_prune_{metric}'], label="Post-Prune")
        axes.plot(rel['Sparsity'], rel[f'post_finetune_{metric}'], label="Post-Finetune")
        axes.set_title(f"LeNet-300-100 {prune_type} Pruning {metric_label}")
        axes.legend()
        fig.savefig(f'plots_out/lenet_easy_mode/sparsity_vs_{metric}_{prune_type}_pruning.png')
        # Release the figure; otherwise pyplot keeps every figure alive for
        # the lifetime of the process.
        plt.close(fig)
