import torch
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from matplotlib.patches import Patch

# -------------------------------------------------------------------
# 1.  Configuration
# -------------------------------------------------------------------
# Experiment identity: together these select the directory the masks are read from
# and appear in the name of the saved figure.
dataset = 'cifar10'      # dataset name
model = 'resnet18'       # model name
prune = 'harp_prune'     # pruning method

# Sweep settings: one curve is drawn per sparsity level.
sparsities = [80, 90, 95, 99]    # sparsity percentages
k = [0.2, 0.1, 0.05, 0.01]       # corresponding k values (same order as sparsities)
num_epochs = 20                  # total epochs saved

# Root folder holding the saved masks for this dataset/model pair.
base_dir = Path(f"/path/to/trained_models/appendix/{dataset}/{model}")


# -------------------------------------------------------------------
# 2.  Helpers
# -------------------------------------------------------------------
def load_masks(folder: Path, epochs: int):
    """Load the per-epoch masks stored in ``folder`` as boolean CPU tensors.

    Reads ``masks_epoch_1.pth`` … ``masks_epoch_{epochs}.pth`` in order.

    Args:
        folder: Directory containing the ``masks_epoch_{e}.pth`` files.
        epochs: Number of consecutive epoch files to read, starting at 1.

    Returns:
        A list of ``epochs`` boolean tensors; index ``i`` holds the mask of
        epoch ``i + 1``.

    Raises:
        FileNotFoundError: If one of the expected mask files is missing.
    """
    # map_location="cpu" lets masks saved from a GPU run be read on a machine
    # without CUDA; without it torch.load first tries to restore the original
    # device and fails there.
    # NOTE: torch.load unpickles its input -- only point this at trusted files.
    return [
        torch.load(folder / f"masks_epoch_{e}.pth", map_location="cpu").bool()
        for e in range(1, epochs + 1)
    ]

def normalized_hamming(a: torch.Tensor, b: torch.Tensor) -> float:
    """Return the share of positions at which two boolean masks disagree.

    The masks are XOR-ed element-wise and the result is averaged, so 0.0 means
    identical masks and 1.0 means every entry differs. Both masks are expected
    to have the same shape.
    """
    differing = torch.bitwise_xor(a, b)
    fraction = differing.to(torch.float32).mean()
    return float(fraction.item())

# -------------------------------------------------------------------
# 3.  Compute and Plot Δ Hamming Curves with Annotation
# -------------------------------------------------------------------
fig, ax = plt.subplots(figsize=(6, 4))

# X-axis: the epoch at which each compared mask was saved (2..num_epochs).
# masks[t] holds masks_epoch_{t+1}, so the first comparison against the
# initial mask belongs to epoch 2, not epoch 1.
epochs = np.arange(2, num_epochs + 1)

for idx, p in enumerate(sparsities):
    # Load the per-epoch masks of the ADV and AWP runs at this sparsity
    adv_masks = load_masks(base_dir / f"adv/{p}/{prune}/latest_exp", num_epochs)
    awp_masks = load_masks(base_dir / f"awp/{p}/{prune}/latest_exp", num_epochs)

    # Normalized Hamming distance of every later mask to the initial mask (epoch 1)
    nh_adv = np.array([normalized_hamming(m, adv_masks[0]) for m in adv_masks[1:]])
    nh_awp = np.array([normalized_hamming(m, awp_masks[0]) for m in awp_masks[1:]])

    # Δ Hamming = ADV - AWP; positive values mean the ADV masks moved further
    # away from their initial mask than the AWP masks did.
    delta = nh_adv - nh_awp

    # Plot the Δ curve, one default-cycle color per sparsity level
    color = f"C{idx}"
    ax.plot(epochs, delta, color=color, linestyle='--', linewidth=2,
            label=f'Sparsity={p}%')

# Reference lines: horizontal at Δ = 0, vertical (dotted) at epoch 5
ax.axhline(0, color='black', linewidth=0.85, alpha=0.5)
ax.axvline(5, color='black', linestyle=':', linewidth=0.85)


# Shade the positive (green) and negative (red) half-planes
ymin, ymax = ax.get_ylim()
ax.set_ylim(ymin, ymax)  # fix limits before shading
# The patches are legend-only proxies for the two shaded regions.
green_patch = Patch(facecolor='green', alpha=0.3, label='S2AP more stable')
red_patch   = Patch(facecolor='red',   alpha=0.3, label='Orig. more stable')
# Let the green band reach at least 6e-3, the largest y tick set below.
ax.axhspan(0, max(ymax, 6e-3), facecolor='green', alpha=0.1)
ax.axhspan(ymin, 0, facecolor='red',   alpha=0.1)


# Labels, ticks, grid, legend
ax.set_xlabel("Pruning Epochs", fontsize=20)
ax.set_xticks([5, 10, 15])
ax.set_yticks([0, 2e-3, 4e-3, 6e-3])
ax.tick_params(axis='y', labelrotation=45, labelsize=15)
ax.tick_params(axis='x', labelsize=15)

ax.set_ylabel("Hamming Diff.", fontsize=20)
ax.grid(True, linewidth=0.4, alpha=0.6, which="both")

# Combine curve and patch legends
handles, labels = ax.get_legend_handles_labels()
handles.extend([green_patch, red_patch])
ax.legend(handles=handles, loc='upper right', fontsize=15)

plt.tight_layout(rect=[0,0,1,1])

# savefig does not create missing directories, so make sure "plots/" exists
# instead of failing after all masks have already been loaded and processed.
out_dir = Path("plots")
out_dir.mkdir(parents=True, exist_ok=True)
plt.savefig(out_dir / f"hamming_stability_all_sparsities_{model}_{dataset}_{prune}.pdf", dpi=300, bbox_inches='tight', pad_inches=0.01)
