import shutil

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib import rc

# Use LaTeX for text rendering when a `latex` executable is on PATH.
# Without LaTeX installed, usetex=True makes rendering fail, so in that case
# fall back to Matplotlib's built-in text renderer instead of requiring the
# user to edit this file.
rc("font", **{"family": "serif", "serif": ["Computer Modern"]})
rc("text", usetex=shutil.which("latex") is not None)

# Load log data from multiple runs.
# NOTE(review): read_csv is given explicit `names`, so each log file is
# parsed as having no header row -- confirm against the training script.
log_files = ["../E1/log.txt", "../E2/log.txt", "../E3/log.txt", "../E4/log.txt", "../E5/log.txt"]
columns = ["iteration", "train_loss", "train_accuracy", "test_loss", "test_accuracy", "W1_norm", "W2_norm"]

plt.figure(figsize=(8, 3))

all_data = []

for run_index, log_file in enumerate(log_files):
    # A stable sort keeps rows that share an iteration in their original file
    # order, so "the first row logged for an iteration" stays well defined.
    data = pd.read_csv(log_file, names=columns).sort_values(by="iteration", kind="stable")
    all_data.append(data)

    # Plot every run in the same translucent blue. Only the first run gets a
    # real label: matplotlib's legend ignores labels starting with "_", so an
    # automatic legend would show a single "Single Run" entry, not five.
    plt.plot(
        data["iteration"],
        data["test_accuracy"],
        linestyle="-",
        alpha=0.2,
        color="blue",
        label="Single Run" if run_index == 0 else "_nolegend_",
    )

# Compute the average test accuracy at each iteration across runs.
# NOTE: despite its name, `common_iterations` is the sorted *union* of the
# iterations logged by any run; where only some runs logged an iteration, the
# average is taken over just those runs. If a run logged the same iteration
# more than once, only its first row is used.
# One concat + groupby replaces a boolean-mask scan of every run per iteration.
_first_row_per_iteration = [df.drop_duplicates(subset="iteration", keep="first") for df in all_data]
_avg_by_iteration = (
    pd.concat(_first_row_per_iteration, ignore_index=True)
    .groupby("iteration", sort=True)["test_accuracy"]
    .mean()
)
common_iterations = _avg_by_iteration.index.tolist()
avg_test_accuracy = _avg_by_iteration.tolist()

# Plot the across-run average at full opacity in the same blue as the
# individual runs. The returned handle is reused for the legend so the legend
# entry always matches the plotted line's style and label.
(avg_line,) = plt.plot(
    common_iterations,
    avg_test_accuracy,
    linestyle="-",
    color="blue",
    linewidth=2,
    label="Average Test Accuracy",
)

# Log-scaled step axis; note that a step of 0 cannot be placed on a log axis.
plt.xscale("log")
plt.xlabel("Step", fontsize=20, labelpad=5)
plt.ylabel("Accuracy", fontsize=20, labelpad=15)
plt.yticks(fontsize=16)
plt.xticks(fontsize=16)

# Show only the average in the legend; the translucent per-run lines are
# deliberately left out.
plt.legend(handles=[avg_line], fontsize=16)
plt.grid()
plt.tight_layout()

# Save the plot to a PDF file
plt.savefig("test_accuracy_plot.pdf")
