import os
import re
import warnings

import matplotlib.pyplot as plt

log_file_path = "path/file.log"

# Parallel lists: epochs[i] is the epoch number for test_auc[i].
epochs = []
test_auc = []

# Read the whole log up front. The values we extract are ASCII, so a stray
# undecodable byte elsewhere in the log should not abort the plot.
with open(log_file_path, "r", encoding="utf-8", errors="replace") as file:
    log_content = file.read()

# The n-th "'epoch': N" occurrence is paired with the n-th "test: ... 'auc': X"
# occurrence. The non-greedy ".*?" anchors the AUC match on the first 'auc' key
# after "test:" on that line, not the last one (which could belong to another
# split logged later on the same line). "." does not cross newlines.
# NOTE(review): pairing by position assumes each epoch is logged exactly once
# per test-AUC entry — confirm against the logger's format.
epoch_pattern = re.compile(r"'epoch': (\d+)")
test_auc_pattern = re.compile(r"test:.*?'auc': ([0-9.]+)")

epoch_values = epoch_pattern.findall(log_content)
auc_values = test_auc_pattern.findall(log_content)

# zip() silently drops the surplus of the longer list; make a mismatch visible
# because it usually means the epoch/AUC pairing is off.
if len(epoch_values) != len(auc_values):
    warnings.warn(
        f"Found {len(epoch_values)} epoch entries but {len(auc_values)} test "
        f"AUC entries in {log_file_path!r}; only the first "
        f"{min(len(epoch_values), len(auc_values))} pairs will be plotted."
    )

for epoch_text, auc_text in zip(epoch_values, auc_values):
    epochs.append(int(epoch_text))
    test_auc.append(float(auc_text))

# Fail with an actionable message instead of max()'s generic
# "arg is an empty sequence" error when the log yielded no test AUC entries.
if not test_auc:
    raise ValueError(f"No test AUC entries found in {log_file_path!r}")

# Pick the (epoch, auc) pair with the highest AUC. max() returns the first
# maximal element, so ties resolve to the earliest epoch, as list.index() did.
max_epoch, max_auc = max(zip(epochs, test_auc), key=lambda pair: pair[1])

# Line plot of test AUC per epoch.
plt.figure(figsize=(10, 6))
plt.plot(epochs, test_auc, label="Test AUC", marker="o")
plt.xlabel("Epoch")
plt.ylabel("Test AUC")
plt.title("Test AUC vs Epoch")
plt.legend()
plt.grid(True)
# Leave headroom above the highest point so the annotation stays inside the
# axes instead of running into the title.
plt.margins(y=0.15)

# Label the best point. The text is offset by a fixed number of points rather
# than +0.02 in data units, whose on-screen size depends on the y-axis range.
plt.annotate(
    f"Max AUC: {max_auc:.4f}",
    xy=(max_epoch, max_auc),
    xytext=(0, 30),
    textcoords="offset points",
    ha="center",
    arrowprops=dict(facecolor="black", shrink=0.05),
    fontsize=12,
    color="red",
)

# Save to the current working directory as "<log name without ext>_auc_plot.png".
output_filename = (
    os.path.splitext(os.path.basename(log_file_path))[0] + "_auc_plot.png"
)
plt.savefig(output_filename)

plt.show()
