import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from matplotlib import rc

# Typeset all figure text with LaTeX in a Computer Modern serif face.
# Requires a working LaTeX installation; comment these two calls out otherwise.
rc("font", family="serif", serif=["Computer Modern"])
rc("text", usetex=True)

# Training logs to overlay, one file per run
log_files = [
    "../real-train/1.txt",
    "../real-train/seed_42.txt",
    "../real-train/seed_43.txt",
    "../real-train/seed_44.txt",
    "../real-train/seed_45.txt",
]

# Column names assigned, in order, to the tab-separated fields of each log row
# (the files carry their own header line, which the loader skips)
columns = [
    "iteration",
    "train_loss",
    "train_accuracy",
    "test_loss",
    "test_accuracy",
    "W_norm",
    "cos_sim",
]

# Figure with two side-by-side panels: accuracy (left) and cosine similarity (right)
fig, axes = plt.subplots(1, 2, figsize=(15, 3.2), sharex=True)
ax_acc, ax_cos = axes

# One DataFrame per run, kept so the cross-run averages can be computed afterwards
all_data = []

for log_file in log_files:
    # Read inside a context manager so the file handle is closed as soon as the
    # rows are parsed instead of being left for the garbage collector.
    with open(log_file, "r") as log_handle:
        next(log_handle, None)  # Skip header (no-op on an empty file)
        # Whitespace-only lines (e.g. a trailing newline at the end of the log)
        # are skipped; they would otherwise become one-field rows that do not
        # match `columns` and make the DataFrame constructor raise.
        rows = [ln.strip().split("\t") for ln in log_handle if ln.strip()]

    # Fields are read as strings; cast them to numeric types and order by step
    df = pd.DataFrame(rows, columns=columns).astype({
        "iteration": int,
        "train_loss": float,
        "train_accuracy": float,
        "test_loss": float,
        "test_accuracy": float,
        "W_norm": float,
        "cos_sim": float,
    }).sort_values(by="iteration")

    all_data.append(df)

    # --- Accuracy panel: light per-run traces ---
    ax_acc.plot(df["iteration"], df["test_accuracy"], linestyle="-", alpha=0.2, color="blue")
    ax_acc.plot(df["iteration"], df["train_accuracy"], linestyle="-", alpha=0.2, color="red")

    # --- Cosine panel: light per-run traces ---
    ax_cos.plot(df["iteration"], df["cos_sim"], linestyle="-", alpha=0.2, color="green")

# Average each metric across runs at every step logged by at least one run.
# NOTE: despite its name, `common_iterations` is the UNION of the runs' steps;
# a step missing from some runs is averaged over the runs that do contain it.
metric_columns = ["test_accuracy", "train_accuracy", "cos_sim"]

# step -> one list of per-run values for each entry of `metric_columns`.
# Filled in a single pass over each run rather than re-scanning every
# DataFrame once per step and per metric.
values_by_step = {}
for run in all_data:
    # If a run logged the same step more than once, only its first row counts.
    first_rows = run.drop_duplicates(subset="iteration", keep="first")
    for step, *values in zip(first_rows["iteration"], *(first_rows[col] for col in metric_columns)):
        buckets = values_by_step.setdefault(step, [[] for _ in metric_columns])
        for bucket, value in zip(buckets, values):
            bucket.append(value)

common_iterations = sorted(values_by_step)

# np.mean propagates NaN: a NaN in any contributing run yields a NaN average at that step.
avg_test_acc = [np.mean(values_by_step[step][0]) for step in common_iterations]
avg_train_acc = [np.mean(values_by_step[step][1]) for step in common_iterations]
avg_cos = [np.mean(values_by_step[step][2]) for step in common_iterations]

def _style_panel(ax, ylabel, x_max):
    """Apply the formatting both panels share: log-scaled "Step" x-axis from 1 to
    x_max, the given y-label, tick label size, grid, and an upper-left legend."""
    ax.set_xscale("log")
    ax.set_xlabel("Step", fontsize=20, labelpad=5)
    ax.set_ylabel(ylabel, fontsize=20, labelpad=15)
    ax.set_xlim(1, x_max)
    ax.tick_params(axis="both", labelsize=16)
    ax.grid(True)
    ax.legend(fontsize=16, loc="upper left")


# --- Accuracy panel: bold cross-run means drawn over the per-run traces ---
for mean_curve, colour, legend_name in (
    (avg_test_acc, "blue", "Test"),
    (avg_train_acc, "red", "Train"),
):
    ax_acc.plot(common_iterations, mean_curve, linestyle="-", color=colour, linewidth=2, label=legend_name)

_style_panel(ax_acc, "Accuracy", 100_000)

# --- Cosine panel: bold cross-run mean ---
cos_legend = r"cos$(\Delta \theta_t, \tilde{g}_t)$"
ax_cos.plot(common_iterations, avg_cos, linestyle="-", color="green", linewidth=2, label=cos_legend)

# NOTE(review): the panels are created with sharex=True, so the x-limits set here
# also apply to the accuracy panel and replace its (1, 100_000) range. Confirm
# whether the two panels are really meant to have different x-ranges.
_style_panel(ax_cos, r"Cosine similarity", 200_000)

plt.tight_layout(pad=0, w_pad=4.0)

# Write the finished figure out as a PDF
plt.savefig("real_train_acc_prec_and_cos_plot.pdf")
