import torch
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import numpy as np

# Load the saved experiment results.
# weights_only=False lets torch.load unpickle arbitrary Python objects (not just
# tensors), which executes code embedded in the file — only point this at
# result files you produced yourself.
D = torch.load(f='Results.pt', weights_only=False)

# Values plotted on the y-axis below; presumably one entry per value in
# N_perturbed_list (TODO confirm against the script that wrote Results.pt).
RV_bars = D['RV_bars']
# IoU_bars = D['IoU_bars']  # second metric stored in the same file; currently unused


# Perturbation sizes used on the x-axis: each fraction of the image's pixels
# is converted to an integer pixel count r (rounded down).
# NOTE(review): 304 x 304 presumably matches the OCTA-500 input resolution the
# model was evaluated on — confirm against the evaluation code.
PERTURBED_FRACTIONS = np.array([0.01, 0.03, 0.06])
IMAGE_HEIGHT, IMAGE_WIDTH = 304, 304
# Same multiplication order as before ((fraction * H) * W), so the floored
# pixel counts are identical to the original hard-coded expression.
N_perturbed_list = np.floor(PERTURBED_FRACTIONS * IMAGE_HEIGHT * IMAGE_WIDTH).astype(int)


# --- Plotting ---
# Plot RV_bars (y) against N_perturbed_list (x) and export the figure both as
# a 300-dpi PNG and as a vector EPS.
plt.figure(figsize=(6, 4))
plt.plot(
    N_perturbed_list,
    RV_bars,
    marker='o',
    linestyle='-',
    color='tab:blue',
    linewidth=2,
    markersize=6,
    label='UNet2 for OCTA-500',
)

# One tick per plotted point; each label is the pixel count followed by " x 1".
xtick_labels = [f"{d} x 1" for d in N_perturbed_list]

plt.xlabel("perturbation dimension (# perturbed pixels ($r$))", fontsize=12)
plt.ylabel("Average Robustness Value ($\\overline{\\mathbf{RV}}$)", fontsize=12)
plt.legend(fontsize=10)
# Disable the additive offset so y tick labels show the actual values.
plt.gca().yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
plt.title("Average Robustness Value vs Perturbation Dimension", fontsize=13)
# The PostScript backend does not support transparency: an alpha=0.7 grid is
# rendered fully opaque in the EPS export, so it looked darker there than in
# the PNG. Use the opaque equivalent instead — the default grid colour
# (#b0b0b0, default rcParams) at 70% opacity over a white background is
# #c8c8c8 — so both exported files look the same.
plt.grid(True, linestyle='--', linewidth=0.5, color='#c8c8c8')
plt.xticks(N_perturbed_list, xtick_labels, fontsize=10)
plt.yticks(fontsize=10)
plt.tight_layout()
plt.savefig("Robustness_vs_PerturbationDim.png", dpi=300)  # high-res raster for the paper
plt.savefig("Robustness_vs_PerturbationDim.eps", format='eps')
plt.show()