import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np

# Experiment selection: together these pick the run directory whose
# eigenvalues.csv files are plotted below.
dataset = 'cifar10'
model = 'wrn_28_4'
prune = 'harp_prune'
# '/path/to' looks like a placeholder root -- point it at the real results dir.
base_dir = Path('/path/to') / dataset / model
# Sparsity levels in percent, one subplot each, drawn left to right.
sparsities = [80, 90, 95, 99]

def warp_epoch(e):
    """Map a raw epoch number onto the warped x-axis used by the plots.

    Epochs up to and including 10 are squeezed linearly into [1, 2]
    (epoch 1 -> 1, epoch 10 -> 2). Every later epoch advances by one full
    axis unit (epoch 11 -> 3, epoch 19 -> 11). Both pieces evaluate to 2
    at e = 10, so the mapping is continuous.
    """
    early = e <= 10
    # Early epochs share one axis unit; later epochs get one unit apiece.
    return 1 + (e - 1) / 9 if early else 3 + (e - 11)

# Small-multiples figure: one panel per sparsity level, each comparing the
# per-epoch mean top eigenvalue of the 'adv' run ("Original") against the
# 'awp' run ("S2AP").
if not sparsities:
    raise ValueError('sparsities must contain at least one sparsity level')

n_panels = len(sparsities)
# Width scales with the panel count (3in per panel, i.e. 12x3 for 4 panels).
# squeeze=False keeps `axes` 2-D even for a single panel, so axes[0] always works.
fig, axes = plt.subplots(1, n_panels, figsize=(3 * n_panels, 3), sharey=True,
                         squeeze=False)

# Raw epochs labelled on every panel; positions go through warp_epoch so the
# ticks land on the same warped axis as the data.
tick_es = [0, 10, 13, 16, 19]
tick_xs = [warp_epoch(e) for e in tick_es]
tick_labels = [str(e) for e in tick_es]


def _mean_top_eigen(method, spars):
    """Return the per-epoch mean of `top_eigenvalue` for one run.

    Reads ``base_dir/<method>/<spars>/<prune>/eigenvalues.csv`` and averages
    the ``top_eigenvalue`` column over all rows that share an ``epoch`` value.
    """
    csv_path = base_dir / f'{method}/{spars}/{prune}/eigenvalues.csv'
    return pd.read_csv(csv_path).groupby('epoch')['top_eigenvalue'].mean()


for ax, spars in zip(axes[0], sparsities):
    adv = _mean_top_eigen('adv', spars)
    awp = _mean_top_eigen('awp', spars)

    line1, = ax.plot([warp_epoch(e) for e in adv.index], adv.values,
                     '-o', markersize=4)
    line2, = ax.plot([warp_epoch(e) for e in awp.index], awp.values,
                     '--s', markersize=4)

    ax.text(0.5, 0.92, f'Sparsity={spars}%', transform=ax.transAxes,
            ha='center', va='top', fontsize=14)
    ax.grid(alpha=0.3)

    # Only y is shared, so x ticks must be set on every subplot.
    ax.set_xticks(tick_xs)
    ax.set_xticklabels(tick_labels, fontsize=12)

# Shared axis labels and one figure-level legend. The handles come from the
# last panel; every panel uses the same line styles.
fig.text(0.5, 0.01, 'Pruning Epochs', ha='center', fontsize=15)
fig.text(0.01, 0.5, r'$\lambda_{\max}$', va='center', rotation='vertical', fontsize=17)
fig.legend([line1, line2], ['Original', 'S2AP'],
           loc='upper center', ncol=2, framealpha=0.3,
           bbox_to_anchor=(0.5, 1.05), fontsize=13)

fig.tight_layout(rect=[0.03, 0.03, 1, 0.95])

# Create the output directory first; savefig does not, and would raise
# FileNotFoundError on a fresh checkout.
out_path = Path('plots') / f"eigen_small_multiples_1x{n_panels}_{model}_{dataset}_{prune}.pdf"
out_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(out_path, dpi=300, bbox_inches='tight', pad_inches=0.01)
plt.show()
