import os
import re
import numpy as np
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

# Directory whose sub-directories are individual TensorBoard runs.
RUNS_DIR = "runs/ALE/"

# Environments plotted, one subplot each, left to right.
ENVS = ["Breakout-v5", "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"]

# Method key (as produced by parse_run_name) -> legend label.
METHODS = dict(
    [
        ("DQL", "Uniform Small"),
        ("DQL-largelr", "Uniform Large"),
        ("DRQL", "Non-Uniform"),
    ]
)

# Confidence-band settings used by rolling_ci_single_trajectory.
CI_WINDOW = 50  # trailing window length, in array samples; tune as needed
CI_ALPHA = 0.05  # two-sided significance level: 0.05 -> 95% band
Z95 = 1.96  # rounded standard-normal quantile for a two-sided 95% band

def rolling_ci_single_trajectory(y, window=CI_WINDOW, alpha=CI_ALPHA):
    """
    Rolling mean and normal-approximation confidence band for one trajectory.

    For each index t the statistics use the trailing window
    ``y[max(0, t - window + 1) : t + 1]``, which is shorter at the start
    of the series. The band is ``mean +/- z * s / sqrt(n)``. Here ``s`` is
    the sample standard deviation (ddof=1) of the window, ``n`` is its
    length, and ``z`` is the two-sided standard-normal quantile for
    ``alpha``. A one-sample window gets zero width.

    Args:
        y: 1D sequence of values, shape (T,).
        window: trailing window length in samples; must be >= 1.
        alpha: two-sided significance level in (0, 1); 0.05 -> 95% band.

    Returns:
        mean, lower, upper: float arrays, all shape (T,).

    Raises:
        ValueError: if ``window < 1`` or ``alpha`` is outside (0, 1).
    """
    from statistics import NormalDist  # stdlib; avoids a scipy dependency

    if window < 1:
        # window <= 0 would yield empty slices and NaN statistics.
        raise ValueError(f"window must be >= 1, got {window}")
    if not 0.0 < alpha < 1.0:
        raise ValueError(f"alpha must be in (0, 1), got {alpha}")

    # Keep the historical rounded constant for the default 95% band so
    # existing plots are unchanged. Any other alpha gets its exact quantile;
    # previously it silently fell back to 1.96.
    z = Z95 if alpha == 0.05 else NormalDist().inv_cdf(1.0 - alpha / 2.0)

    y = np.asarray(y, dtype=float)
    T = len(y)
    mean = np.zeros(T)
    lower = np.zeros(T)
    upper = np.zeros(T)

    for t in range(T):
        segment = y[max(0, t - window + 1): t + 1]  # always >= 1 element
        n = len(segment)
        m = segment.mean()
        s = segment.std(ddof=1) if n > 1 else 0.0
        half_width = z * s / np.sqrt(n)
        mean[t] = m
        lower[t] = m - half_width
        upper[t] = m + half_width

    return mean, lower, upper


def parse_run_name(run_name):
    """
    Split a run directory name into (environment id, method key).

    The environment id is everything before the first ``"__"``, or the
    whole name if there is none. The method key comes from substring
    markers checked in priority order:
    - any name containing ``"DRQL"`` is ``"DRQL"``;
    - otherwise ``"largelr"`` marks ``"DQL-largelr"``;
    - anything else is ``"DQL"``.

    Examples:
        Breakout-v5__DQL_Breakout-v5__1__1765440156   -> DQL
        Breakout-v5__DQL_Breakout-v5-largelr__1__...  -> DQL-largelr
        Breakout-v5__DRQL_Breakout-v5__1__...         -> DRQL
    """
    env_id, _sep, _rest = run_name.partition("__")

    for marker, method_key in (("DRQL", "DRQL"), ("largelr", "DQL-largelr")):
        if marker in run_name:
            return env_id, method_key
    return env_id, "DQL"


def load_episodic_returns(run_path):
    """
    Read the ``charts/episodic_return`` scalar series from a TensorBoard run.

    A scalar size_guidance of 0 asks the accumulator to keep every scalar
    event instead of a reservoir sample.

    Args:
        run_path: run directory handed to ``EventAccumulator``.

    Returns:
        (steps, returns) as NumPy arrays, in the order the accumulator
        yields the events. Returns (None, None) when the run has no such tag.
    """
    tag = "charts/episodic_return"
    accumulator = event_accumulator.EventAccumulator(
        run_path, size_guidance={"scalars": 0}
    )
    accumulator.Reload()

    scalar_tags = accumulator.Tags()["scalars"]
    if tag not in scalar_tags:
        return None, None

    records = accumulator.Scalars(tag)
    steps = np.array([rec.step for rec in records])
    returns = np.array([rec.value for rec in records])
    return steps, returns




def gaussian_smooth(x, sigma):
    """
    Smooth a 1D array by convolving it with a normalized Gaussian kernel.

    The kernel is truncated at ``4 * sigma`` samples (rounded to the nearest
    index). The input is extended at both ends with
    ``np.pad(mode="reflect")``, which mirrors without repeating the edge
    sample. As a result the output has the same length as the input and is
    not pulled toward zero at the edges.

    Args:
        x: 1D array-like (trajectory).
        sigma: standard deviation of the Gaussian, in index units.
            ``0`` means no smoothing.

    Returns:
        Smoothed float array, same length as ``x``.

    Raises:
        ValueError: if ``sigma`` is negative.
    """
    x = np.asarray(x, dtype=float)
    if sigma < 0:
        raise ValueError(f"sigma must be non-negative, got {sigma}")
    # sigma == 0 would divide by zero in the kernel, and np.pad cannot
    # reflect-pad an empty array; both cases are the identity.
    if sigma == 0 or x.size == 0:
        return x.copy()

    radius = int(4 * sigma + 0.5)
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()

    # Pad by exactly `radius` so a 'valid' convolution returns len(x) samples.
    x_pad = np.pad(x, pad_width=radius, mode="reflect")
    return np.convolve(x_pad, kernel, mode="valid")

# --------------------------------------------------
# Load all runs
# --------------------------------------------------
# data[env][method] -> list of (steps, returns) array pairs, one per run.
data = {
    env: {m: [] for m in METHODS} for env in ENVS
}

# os.listdir order is arbitrary; sort it so runs are collected in a
# deterministic order and the figure is reproducible.
for run_name in sorted(os.listdir(RUNS_DIR)):
    run_path = os.path.join(RUNS_DIR, run_name)
    if not os.path.isdir(run_path):
        continue

    env, method = parse_run_name(run_name)
    if env not in data or method not in data[env]:
        # Environment or method label not part of this figure.
        continue

    steps, returns = load_episodic_returns(run_path)
    # Skip runs without the tag and runs with no logged episodes. An empty
    # run would otherwise produce zero-length curves downstream.
    if steps is None or len(steps) == 0:
        continue

    data[env][method].append((steps, returns))


# --------------------------------------------------
# Plot
# --------------------------------------------------
def _align_runs(runs):
    """
    Resample several runs onto one shared environment-step grid.

    Episodes end at different environment steps in different runs, so
    averaging "the k-th episode" across runs mixes returns from different
    points in training. Instead, each run is linearly interpolated
    (``np.interp``) onto a common grid spanning the step range that every
    run covers. Returns logged at the same step within a run are averaged
    first, which also sorts the steps as ``np.interp`` requires.

    Args:
        runs: non-empty list of (steps, returns) 1D array pairs.

    Returns:
        (grid, values), where grid has shape (n,) and values has shape
        (len(runs), n). n is the episode count of the shortest run, so
        index-based window/sigma settings keep a similar scale. Returns
        (None, None) when the runs share no step range to align on.
    """
    n_points = min(len(s) for s, _ in runs)
    cleaned = []
    for s, r in runs:
        uniq, inv = np.unique(np.asarray(s), return_inverse=True)
        sums = np.bincount(inv.ravel(), weights=np.asarray(r, dtype=float))
        cleaned.append((uniq, sums / np.bincount(inv.ravel())))

    start = max(s[0] for s, _ in cleaned)
    stop = min(s[-1] for s, _ in cleaned)
    if n_points < 2 or stop <= start:
        return None, None

    grid = np.linspace(start, stop, n_points)
    values = np.stack([np.interp(grid, s, r) for s, r in cleaned])
    return grid, values


# squeeze=False keeps a 2D Axes array even when there is a single env.
fig, axes = plt.subplots(1, len(ENVS), figsize=(22, 4), sharey=False, squeeze=False)
axes = axes[0]

for ax, env in zip(axes, ENVS):
    for method, method_name in METHODS.items():
        runs = data[env][method]
        if not runs:
            continue

        steps, values = _align_runs(runs)
        if steps is None:
            continue

        # Mean across runs at each grid step (single mean trajectory).
        mean = values.mean(axis=0)

        # NOTE: the band is a rolling CI of the mean trajectory over time,
        # i.e. variability within CI_WINDOW grid points. It is not a
        # confidence interval across seeds.
        mean_roll, lo_roll, hi_roll = rolling_ci_single_trajectory(
            mean, window=CI_WINDOW, alpha=CI_ALPHA
        )

        # Smooth the mean and the band (sigma is in grid-index units).
        mean_sm = gaussian_smooth(mean_roll, sigma=20)
        lo_sm = gaussian_smooth(lo_roll, sigma=20)
        hi_sm = gaussian_smooth(hi_roll, sigma=20)

        # Give the band its line's colour explicitly so the two always
        # match, however matplotlib advances its colour cycles.
        (line,) = ax.plot(steps, mean_sm, label=method_name)
        ax.fill_between(steps, lo_sm, hi_sm, alpha=0.25, color=line.get_color())

    ax.set_title(env, fontsize=16)
    ax.set_xlabel("Environment Steps", fontsize=14)
    ax.grid(True)

axes[-1].legend(loc="best", fontsize=14)
# NOTE(review): the curve is a rolling/Gaussian-smoothed mean, not a
# cumulative sum; confirm the label wording is intended.
axes[0].set_ylabel("Cumulative Mean Episodic Return", fontsize=14)

plt.tight_layout()
# plt.show()
plot_file = os.path.join(RUNS_DIR, "Qlearning.pdf")
plt.savefig(plot_file)