import os
import numpy as np
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator


# =====================
# Config
# =====================
# Directory containing one TensorBoard log sub-directory per run; the output
# PDF is written here as well.
RUNS_DIR = "runs/classification/best/"

# Dataset prefixes accepted in run names; this is also the left-to-right
# panel order of the figure.
DATASETS = [
    "mnist",
    "fashion",
    "cifar",
]

# Experiment-name field of a run name -> legend label.
# Insertion order is the order in which methods are drawn.
METHODS = dict(
    standard="Uniform Small",
    largelr="Uniform Large",
    twotimescale="Non-Uniform",
)

# Standard deviation (in logged points) of the Gaussian smoothing kernel.
SMOOTH_SIGMA = 2


# =====================
# Run-name parser
# =====================
def parse_run_name(run_name):
    """Split a run directory name into its (dataset, method) pair.

    Run names are expected to look like
    ``dataset__exp_name__seedX__timestamp``: fields separated by ``"__"``,
    with at least four of them. Only the first two fields are inspected.

    Returns:
        ``(dataset, exp_name)`` when the name has at least four fields, the
        dataset appears in ``DATASETS`` and the experiment name is a key of
        ``METHODS``; otherwise ``(None, None)``.
    """
    fields = run_name.split("__")
    if len(fields) >= 4:
        dataset, exp_name = fields[:2]
        if dataset in DATASETS and exp_name in METHODS:
            return dataset, exp_name
    return None, None


# =====================
# TensorBoard loader
# =====================
def load_accuracy(run_path):
    """Read the ``eval/accuracy`` scalar series from one TensorBoard run.

    Args:
        run_path: Run directory handed to ``EventAccumulator``.

    Returns:
        ``(steps, values)`` as NumPy arrays, in the order returned by
        ``EventAccumulator.Scalars``, or ``(None, None)`` when the run has
        no ``eval/accuracy`` scalar tag.
    """
    tag = "eval/accuracy"
    # A size guidance of 0 asks the accumulator to keep every scalar event
    # rather than a down-sampled reservoir (per TensorBoard's docs).
    accumulator = event_accumulator.EventAccumulator(
        run_path, size_guidance={"scalars": 0}
    )
    accumulator.Reload()

    if tag not in accumulator.Tags()["scalars"]:
        return None, None

    scalar_events = accumulator.Scalars(tag)
    return (
        np.array([ev.step for ev in scalar_events]),
        np.array([ev.value for ev in scalar_events]),
    )


# =====================
# Gaussian smoothing
# =====================
def gaussian_smooth(x, sigma):
    """Smooth a 1-D series with a normalized, truncated Gaussian kernel.

    The kernel spans ``int(4 * sigma + 0.5)`` samples on each side of its
    centre and is normalized to sum to one. The input is reflect-padded by
    the kernel radius before a "valid" convolution. As a result, the output
    has the same length as ``x`` and the ends are not pulled toward zero.

    Args:
        x: 1-D array-like of numbers; converted to ``float``.
        sigma: Kernel standard deviation in samples. ``0`` means no smoothing.

    Returns:
        A new 1-D float array with ``len(x)`` elements.

    Raises:
        ValueError: If ``sigma`` is negative.
    """
    x = np.asarray(x, dtype=float)
    if sigma < 0:
        raise ValueError(f"sigma must be non-negative, got {sigma}")
    # Nothing to smooth. An empty series cannot be reflect-padded, and
    # sigma == 0 would divide by zero when building the kernel (NaN output).
    # Return a copy so callers never alias the input array.
    if x.size == 0 or sigma == 0:
        return x.copy()

    radius = int(4 * sigma + 0.5)
    t = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (t / sigma) ** 2)
    kernel /= kernel.sum()

    # Padding by exactly `radius` makes the "valid" convolution length
    # (len(x) + 2*radius) - (2*radius + 1) + 1 == len(x).
    x_pad = np.pad(x, pad_width=radius, mode="reflect")
    return np.convolve(x_pad, kernel, mode="valid")


# =====================
# Load all runs
# =====================
# data[dataset][method] -> list of (steps, accuracy) array pairs, one entry
# per run directory that parsed cleanly and logged "eval/accuracy".
data = {
    dataset: {m: [] for m in METHODS}
    for dataset in DATASETS
}

# Iterate in sorted order. os.listdir returns entries in arbitrary,
# filesystem-dependent order, and the plotting step uses runs[0] as the
# single trajectory. An unsorted walk could therefore plot a different run
# on a different machine.
for run_name in sorted(os.listdir(RUNS_DIR)):
    run_path = os.path.join(RUNS_DIR, run_name)
    if not os.path.isdir(run_path):
        continue

    # Skip directories whose name does not match
    # dataset__method__seed__timestamp for a known dataset and method.
    dataset, method = parse_run_name(run_name)
    if method is None:
        continue

    # Skip runs that never logged the accuracy scalar.
    steps, acc = load_accuracy(run_path)
    if steps is None:
        continue

    data[dataset][method].append((steps, acc))


# =====================
# Plot (1×3) with single-trajectory CI
# =====================
CI_SCALE = 1.0   # band half-width in units of the local std; tighten if needed (e.g. 0.5)

# One fixed color per method, so the same method looks the same in every
# panel even when a panel is missing some methods. Per-axes color cycling
# would shift colors in such panels, and the single shared legend would then
# mislabel them.
METHOD_COLORS = {method: f"C{i}" for i, method in enumerate(METHODS)}

fig, axes = plt.subplots(1, 3, figsize=(12, 4), sharey=False)

# legend label -> first line drawn with that label, from any panel.
legend_handles = {}

for ax, dataset in zip(axes, DATASETS):
    for method, label in METHODS.items():
        runs = data[dataset][method]
        if not runs:
            continue

        color = METHOD_COLORS[method]

        # ---- pick ONE trajectory (first run in load order)
        steps, y = runs[0]
        y = y.astype(float)

        # ---- smoothed curve
        mean = gaussian_smooth(y, SMOOTH_SIGMA)

        # ---- local variability band: Gaussian-weighted RMS of the residuals
        # around the smoothed curve. It describes the noise of this single
        # run, not a confidence interval across seeds.
        residual = y - mean
        var = gaussian_smooth(residual ** 2, SMOOTH_SIGMA)
        std = np.sqrt(np.maximum(var, 1e-12))  # floor keeps the band width > 0

        lower = mean - CI_SCALE * std
        upper = mean + CI_SCALE * std

        (line,) = ax.plot(steps, mean, label=label, color=color)
        ax.fill_between(
            steps,
            lower,
            upper,
            color=color,
            alpha=0.2,
            linewidth=0,
        )
        legend_handles.setdefault(label, line)

    ax.set_title(dataset.upper(), fontsize=20)
    ax.set_xlabel("Epoch", fontsize=16)
    ax.grid(True)

axes[0].set_ylabel("Classification Accuracy", fontsize=16)

# A single legend on the last panel that lists every method drawn in any
# panel, in METHODS order. It is skipped when nothing was plotted, since
# legend() with no labeled artists only emits a warning.
legend_labels = [lbl for lbl in METHODS.values() if lbl in legend_handles]
if legend_labels:
    axes[-1].legend(
        [legend_handles[lbl] for lbl in legend_labels],
        legend_labels,
        fontsize=16,
    )

plt.tight_layout()

plot_file = os.path.join(RUNS_DIR, "classification.pdf")
plt.savefig(plot_file)
plt.close(fig)
