# plot_regression_args.py
import os
import argparse
import numpy as np
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

# =====================
# Args
# =====================
def _build_arg_parser():
    """Create the command-line parser for this plotting script."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--runs-dir", type=str, default="runs/regression/", help="Directory containing runs")
    cli.add_argument("--plot-dir", type=str, default=None, help="Directory to save plots (defaults to runs-dir)")
    cli.add_argument("--smooth-sigma", type=float, default=2, help="Gaussian smoothing sigma")
    cli.add_argument("--ci-window", type=int, default=10, help="Window size for local CI")
    cli.add_argument("--ci-scale", type=float, default=1.0, help="Scale for CI band (e.g. 1=±1 std)")
    return cli


parser = _build_arg_parser()
args = parser.parse_args()

# Expose the parsed options as module-level constants. --plot-dir falls back
# to the runs directory when it is omitted (or passed as an empty string).
RUNS_DIR = args.runs_dir
PLOT_DIR = args.plot_dir or RUNS_DIR
SMOOTH_SIGMA, CI_WINDOW, CI_SCALE = args.smooth_sigma, args.ci_window, args.ci_scale

# =====================
# Config
# =====================
# Datasets to plot; one subplot per entry, in this order.
DATASETS = ["friedman", "housing"]

# Experiment name (second "__"-separated field of a run directory name)
# -> legend label. Insertion order sets the plotting order.
METHODS = dict(
    standard="Uniform Small",
    largelr="Uniform Large",
    twotimescale="Non-Uniform",
)

# =====================
# Run-name parser
# =====================
def parse_run_name(run_name):
    """Extract ``(dataset, method)`` from a run directory name.

    The name is split on ``"__"``. It must have at least four fields, but
    only the first two (dataset, experiment name) are used. Both must appear
    in ``DATASETS`` / ``METHODS`` respectively.

    Returns:
        ``(dataset, method)`` on success, otherwise ``(None, None)``.
    """
    fields = run_name.split("__")
    if len(fields) >= 4:
        dataset, exp_name = fields[:2]
        if dataset in DATASETS and exp_name in METHODS:
            return dataset, exp_name
    return None, None

# =====================
# TensorBoard loader
# =====================
def load_mse(run_path):
    """Read the ``eval/mse`` scalar series from a TensorBoard run directory.

    ``size_guidance={"scalars": 0}`` asks the EventAccumulator to keep every
    scalar event rather than a sampled subset.

    Returns:
        ``(steps, values)`` as numpy arrays, or ``(None, None)`` when the run
        has no ``eval/mse`` scalar tag.
    """
    acc = event_accumulator.EventAccumulator(run_path, size_guidance={"scalars": 0})
    acc.Reload()
    tag = "eval/mse"
    if tag not in acc.Tags()["scalars"]:
        return None, None
    scalar_events = acc.Scalars(tag)
    step_arr = np.array([ev.step for ev in scalar_events])
    value_arr = np.array([ev.value for ev in scalar_events])
    return step_arr, value_arr

# =====================
# Gaussian smoothing
# =====================
def gaussian_smooth(x, sigma):
    """Smooth a 1-D signal with a truncated, normalized Gaussian kernel.

    The kernel spans ``±radius`` samples, where ``radius`` is ``4 * sigma``
    rounded to the nearest integer, and is normalized to sum to 1. The input
    is reflect-padded by ``radius`` on both ends. As a result, the output has
    the same length as ``x`` and the edges are not pulled toward zero, as
    they would be with zero padding.

    Args:
        x: 1-D array-like of samples.
        sigma: Kernel standard deviation, in samples. ``0`` disables smoothing.

    Returns:
        A float ndarray with the same length as ``x``.

    Raises:
        ValueError: if ``sigma`` is negative.
    """
    x = np.asarray(x, dtype=float)
    if sigma < 0:
        raise ValueError(f"sigma must be non-negative, got {sigma}")
    # sigma == 0 would divide by zero below and yield an all-NaN kernel, so
    # treat it as "no smoothing". An empty array cannot be reflect-padded.
    if sigma == 0 or x.size == 0:
        return x.copy()
    radius = int(4 * sigma + 0.5)
    t = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (t / sigma) ** 2)
    kernel /= kernel.sum()
    x_pad = np.pad(x, pad_width=radius, mode="reflect")
    # "valid" on the padded signal gives exactly len(x) outputs.
    return np.convolve(x_pad, kernel, mode="valid")

# =====================
# Rolling std (single-trajectory CI)
# =====================
def rolling_std(x, window):
    """Return the standard deviation of ``x`` over a centered sliding window.

    For each index ``i``, the population std (``np.std``, ddof=0) is taken
    over ``x[i - window : i + window + 1]``, clipped to the array bounds.
    ``window`` is therefore a half-width: up to ``2 * window + 1`` samples
    are used, and fewer near the edges.

    Args:
        x: 1-D array-like of samples.
        window: Non-negative half-width of the window, in samples.

    Returns:
        An ndarray with the same length as ``x``. Integer or bool input
        yields a float result instead of truncated integers.

    Raises:
        ValueError: if ``window`` is negative.
    """
    if window < 0:
        # A negative half-width would produce empty slices (NaN std).
        raise ValueError(f"window must be non-negative, got {window}")
    x = np.asarray(x)
    # np.zeros_like on an integer array would truncate the std values.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(float)
    n = len(x)
    out = np.zeros_like(x)
    for i in range(n):
        lo = max(0, i - window)
        hi = min(n, i + window + 1)
        out[i] = np.std(x[lo:hi])
    return out

# =====================
# Load all runs
# =====================
# dataset -> method -> list of (steps, mse) array pairs, one entry per run.
data = {dataset: {m: [] for m in METHODS} for dataset in DATASETS}

# os.listdir order is arbitrary. Sort it so the per-(dataset, method) run
# order, and therefore which run ends up first in each list, is reproducible.
for run_name in sorted(os.listdir(RUNS_DIR)):
    run_path = os.path.join(RUNS_DIR, run_name)
    if not os.path.isdir(run_path):
        continue
    dataset, method = parse_run_name(run_name)
    if method is None:
        # Not a recognized <dataset>__<method>__... run directory.
        continue
    steps, mse = load_mse(run_path)
    if steps is None:
        # Run has no eval/mse scalars logged.
        continue
    data[dataset][method].append((steps, mse))

# =====================
# Plot
# =====================
# One subplot per dataset. squeeze=False keeps `axes` an array even when
# DATASETS has a single entry (plt.subplots would otherwise return a bare Axes).
fig, axes = plt.subplots(
    1, len(DATASETS), figsize=(6 * len(DATASETS), 4), sharey=False, squeeze=False
)
axes = axes[0]

for ax, dataset in zip(axes, DATASETS):
    # Raw extremes of the smoothed curves. The margin is applied once after
    # the method loop so it does not compound with the number of methods.
    upper_lim = -np.inf
    lower_lim = np.inf
    for method, label in METHODS.items():
        runs = data[dataset][method]
        if not runs:
            continue

        # Single trajectory: only the first loaded run is plotted. The band is
        # this run's local (rolling) variability, not a spread across runs.
        steps, values = runs[0]
        values = values.astype(float)

        mean = gaussian_smooth(values, SMOOTH_SIGMA)
        std = rolling_std(values, CI_WINDOW)
        n = len(mean)

        (line,) = ax.plot(steps[:n], mean, label=label)
        # NOTE(review): mean - CI_SCALE*std can be <= 0, which a log axis
        # cannot display; consider clipping the lower band.
        ax.fill_between(
            steps[:n],
            mean - CI_SCALE * std[:n],
            mean + CI_SCALE * std[:n],
            color=line.get_color(),  # keep the band the same color as its curve
            alpha=0.25,
        )
        upper_lim = max(upper_lim, mean.max())
        lower_lim = min(lower_lim, mean.min())

    ax.set_title(dataset.upper(), fontsize=20)
    ax.set_xlabel("Epoch", fontsize=16)
    ax.set_yscale("log")
    # A dataset with no runs leaves the limits at ±inf, which set_ylim rejects.
    # In that case keep matplotlib's automatic limits.
    if np.isfinite(lower_lim) and np.isfinite(upper_lim):
        ax.set_ylim(bottom=lower_lim * 0.95, top=upper_lim * 1.05)
    ax.grid(True)

axes[-1].legend(fontsize=16)
axes[0].set_ylabel("Mean Squared Error (Log Scale)", fontsize=16)

plt.tight_layout()
os.makedirs(PLOT_DIR, exist_ok=True)
plot_file = os.path.join(PLOT_DIR, "regression.pdf")
plt.savefig(plot_file)
print(f"Plot saved to {plot_file}")
