import torch
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import numpy as np


# Load the saved timing results exactly once (the file was previously loaded twice).
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python objects,
# which can execute code — only use this on a trusted, locally produced Times.pt.
D = torch.load('Times.pt', weights_only=False)

# Perturbation budgets: fractions of a 304 * 304 element count, floored to ints.
# Presumably 304 x 304 is the OCTA-500 input image size — TODO confirm.
IMAGE_SIDE = 304
PERTURBED_FRACTIONS = np.array([0.01, 0.03, 0.06])
# Evaluated left-to-right exactly as before so the floored values are unchanged.
N_perturbed_list = np.floor(PERTURBED_FRACTIONS * IMAGE_SIDE * IMAGE_SIDE).astype(int)

# Average verification runtimes, one per entry of N_perturbed_list.
# Assumes D['Tbs'] is ordered to match PERTURBED_FRACTIONS — verify against the
# script that wrote Times.pt.
avg_times = D['Tbs']

# --- Plotting ---
# Line plot of average verification runtime against the number of perturbed
# elements; saves a 300-dpi PNG and an EPS, then shows the figure.
plt.figure(figsize=(6, 4))
plt.plot(N_perturbed_list, avg_times, marker='o', linestyle='-', color='tab:blue',
         linewidth=2, markersize=6, label='UNet2 for OCTA-500')

# Tick labels read "<r> x 1", i.e. the perturbation dimension written as a column shape.
xtick_labels = [f"{d} x 1" for d in N_perturbed_list]

plt.xlabel("Perturbation Dimension ($r$)", fontsize=12)
plt.ylabel("Average Verification Runtime (seconds)", fontsize=12)
# NOTE(review): fixed y-range chosen for the paper; points outside [300, 1400] s
# would be clipped silently — adjust if the data changes.
plt.ylim(300, 1400)
plt.legend(fontsize=10)
# Show plain tick values instead of an offset like "+1e3" on the y-axis.
plt.gca().yaxis.set_major_formatter(ScalarFormatter(useOffset=False))
plt.title("Average Verification Runtime vs Perturbation Dimension", fontsize=13)
plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.7)
plt.xticks(N_perturbed_list, xtick_labels, fontsize=10)
plt.yticks(fontsize=10)
plt.tight_layout()
# Both outputs share the same base name (the PNG was previously misspelled "Runtim_").
plt.savefig("Runtime_vs_PerturbationDim.png", dpi=300)  # Save high-res for paper
plt.savefig("Runtime_vs_PerturbationDim.eps", format='eps')
plt.show()
