#!/usr/bin/env python3

import argparse
import os
import re
import warnings
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# --------------------------------------------------------------------------------------
# Defaults
# --------------------------------------------------------------------------------------

# Environments plotted when the suite is "dexgym" and no --games are given
# on the command line (see main()).
DEFAULT_DEXGYM_GAMES = [
    "PenCatchOverarm-v0",
    "PenCatchUnderarm-v0",
    "BlockCatchOverarm-v0",
    "BlockCatchUnderarm-v0",
    "EggCatchOverarm-v0",
    "EggCatchUnderarm-v0",
]

# Algorithm tokens plotted when the suite is "dexgym" and no --algorithms are
# given on the command line (see main()). Each token is compared for equality
# against the algorithm part parsed out of a run directory name.
DEXGYM_ALGOS_DEFAULT = [
    "sac_lle",
    "sac",
    "sac_recon-recon",
    "sac_recon-reward",
    "sac_recon-next_state",
    "sac_joint_lle",
    "sac_spr",
    "sac_dbc"
]

# Maps an algorithm token to a (legend label, hex color) pair.
# Some tokens share the same pair (e.g. "sac_recon-recon" and "sac_recon");
# presumably these are alternative spellings of the same method -- TODO confirm.
# NOTE(review): plot_suite() reads only the label from this table and takes
# line colors from a colormap instead, so the color entries are currently
# unused there -- confirm whether that is intended.
ALGO_DISPLAY = {
    "sac_lle": ("SAC-LLE", "#1f77b4"),
    "sac": ("SAC", "#7f7f7f"),
    "sac_recon-recon": ("SAC-Recon", "#d62728"),
    "sac_recon": ("SAC-Recon", "#d62728"),
    "sac_recon-reward": ("SAC-Reward", "#ff9896"),
    "sac_reward": ("SAC-Reward", "#ff9896"),
    "sac_recon-next_state": ("SAC-Next", "#2ca02c"),
    "sac_next_state": ("SAC-Next", "#2ca02c"),
    "sac_joint_lle": ("SAC-LLE (Joint)", "#17becf"),
    "sac_spr": ("SAC-SPR", "#9467bd"),
    "sac_dbc": ("SAC-DBC", "#8c564b"),
}

# Preferred drawing/legend order for algorithm tokens. plot_suite() draws the
# requested algorithms that appear here first, in this order, and then appends
# any remaining requested algorithms in the order they were given.
# "sac_dbc" is listed last so that every algorithm with a display entry has an
# explicit position instead of relying on the append-at-end fallback.
ALGO_ORDER = [
    "sac_lle",
    "sac",
    "sac_recon-recon",
    "sac_recon",
    "sac_recon-reward",
    "sac_reward",
    "sac_recon-next_state",
    "sac_next_state",
    "sac_joint_lle",
    "sac_spr",
    "sac_dbc",
]

# --------------------------------------------------------------------------------------
# Utilities
# --------------------------------------------------------------------------------------

def load_scalar_from_events(path, tag="charts/episodic_return", max_steps=None):
    """Read one scalar series through a TensorBoard ``EventAccumulator``.

    Args:
        path: Location passed straight to ``EventAccumulator``.
        tag: Name of the scalar series to extract.
        max_steps: When not ``None``, only points with ``step <= max_steps``
            are kept.

    Returns:
        A ``(steps, values)`` tuple of float32 arrays, or ``None`` if
        constructing, reloading or querying the accumulator raised.
    """
    try:
        # size_guidance of 0 for scalars is understood to keep every point
        # (no downsampling) -- see the TensorBoard EventAccumulator docs.
        accumulator = EventAccumulator(path, size_guidance={"scalars": 0})
        accumulator.Reload()
        events = accumulator.Scalars(tag)
    except Exception:
        # Deliberately best-effort: an unreadable run is reported as "no data".
        return None

    steps = np.array([event.step for event in events], dtype=np.float32)
    values = np.array([event.value for event in events], dtype=np.float32)

    if max_steps is None:
        return steps, values

    keep = steps <= max_steps
    return steps[keep], values[keep]


def smooth(y, radius=10):
    """Smooth a 1-D series with a centered moving average of ``radius`` samples.

    Near the edges the window is clipped to the available samples, and each
    output is divided by the number of samples that actually contributed, so
    the ends are not biased toward zero.

    Args:
        y: 1-D sequence of values.
        radius: Window length in samples. Values <= 1 disable smoothing.

    Returns:
        ``y`` unchanged when it is empty or ``radius <= 1``; otherwise a
        float32 array with exactly ``len(y)`` elements.
    """
    if len(y) == 0 or radius <= 1:
        return y

    y = np.asarray(y, dtype=np.float32)
    window = np.ones(radius, dtype=np.float32)

    # np.convolve(mode="same") returns max(len(y), radius) elements, so a
    # series shorter than the window would come back longer than its input
    # and no longer line up with its step axis. Take the centered len(y)
    # slice of the full convolution instead; for radius <= len(y) this is
    # the same slice that mode="same" selects.
    start = (radius - 1) // 2
    stop = start + len(y)
    sums = np.convolve(y, window, mode="full")[start:stop]
    counts = np.convolve(np.ones_like(y), window, mode="full")[start:stop]
    return sums / counts


def million_formatter():
    """Return a tick formatter that labels axis values in units of one million.

    Labels use general ("g") formatting rather than rounding to an integer:
    whole millions still print as "0", "1", "2", ..., while sub-million ticks
    (e.g. 110000 on a 5.5e5-step axis) print as "0.11" instead of collapsing
    to "0".

    Returns:
        A ``FuncFormatter`` wrapping the conversion.
    """

    def _in_millions(value, _pos):
        # The tick position argument is required by FuncFormatter but unused.
        return f"{value / 1e6:g}"

    return FuncFormatter(_in_millions)


def aggregate_runs(run_paths, max_steps, smooth_radius):
    """Combine several runs into a mean curve with its standard error.

    For each run directory the most recently modified file whose name
    contains "events.out.tfevents" is loaded. Runs without such a file, runs
    that fail to load, and runs with no data points are skipped.

    NOTE(review): runs are aligned by sample index, not by step value. All
    curves are truncated to the shortest run and the x-axis is taken from the
    first usable run, which assumes runs log at comparable steps -- confirm.

    Args:
        run_paths: Iterable of run directories to search for event files.
        max_steps: Forwarded to ``load_scalar_from_events`` to drop later points.
        smooth_radius: Window length forwarded to ``smooth``.

    Returns:
        ``(steps, mean, std_error)`` arrays of equal length, or ``None`` when
        no run yielded any data.
    """
    curves = []
    for path in run_paths:
        print(f"  Loading run: {path}")
        tb_files = [
            os.path.join(root, name)
            for root, _, files in os.walk(path)
            for name in files
            if "events.out.tfevents" in name
        ]
        if not tb_files:
            continue
        latest = max(tb_files, key=os.path.getmtime)
        data = load_scalar_from_events(latest, max_steps=max_steps)
        if data is None:
            continue
        steps, rewards = data
        if len(steps) == 0:
            # An empty run would force min_len to 0 below and erase the
            # curves of every other run, so leave it out.
            continue
        curves.append((steps, rewards))
    if not curves:
        return None

    min_len = min(len(run_steps) for run_steps, _ in curves)
    steps = curves[0][0][:min_len]
    stacked = np.stack(
        [smooth(run_rewards[:min_len], radius=smooth_radius) for _, run_rewards in curves],
        axis=0,
    )
    mean = stacked.mean(axis=0)
    # Population standard deviation divided by sqrt(number of runs).
    std_error = stacked.std(axis=0) / np.sqrt(len(curves))
    return steps, mean, std_error


# --------------------------------------------------------------------------------------
# Plotting
# --------------------------------------------------------------------------------------


def select_run_paths(log_dir, game, algo, suite):
    """
    Return the most recent run directory per seed for the specified game/algorithm.

    An entry of ``log_dir`` is considered when its name contains ``game``,
    ``"_" + suite`` and the marker ``"__v1_"``. The text between ``"__v1_"``
    and the suite token (with trailing underscores stripped) must equal
    ``algo`` exactly.

    The seed is taken from the first ``seed<digits>`` occurrence in the name;
    entries without one are keyed by their full name, so each is kept. When
    several entries share a seed, the one with the newest mtime wins.

    Args:
        log_dir: Directory whose entries are scanned.
        game: Environment name that must appear in the entry name.
        algo: Algorithm token to match exactly.
        suite: Suite name; matched as ``"_" + suite`` in the entry name.

    Returns:
        List of paths, numeric seeds first in ascending order, followed by
        entries without a numeric seed in name order. Empty if nothing matches.
    """
    suite_token = f"_{suite}"
    seed_pattern = re.compile(r"seed(\d+)")

    latest_per_seed = {}
    for name in os.listdir(log_dir):
        if game not in name or suite_token not in name or "__v1_" not in name:
            continue
        # Both markers are known to be present, so neither split can come up short.
        algo_token = name.split("__v1_", 1)[1].split(suite_token, 1)[0].rstrip("_")
        if algo_token != algo:
            continue

        path = os.path.join(log_dir, name)
        seed_match = seed_pattern.search(name)
        seed = seed_match.group(1) if seed_match else name
        mtime = os.path.getmtime(path)
        prev = latest_per_seed.get(seed)
        if prev is None or mtime > prev[0]:
            latest_per_seed[seed] = (mtime, path)

    def seed_key(item):
        # Use a uniform tuple so numeric and non-numeric seeds can be compared;
        # returning a bare int for some items and a str for others makes
        # sorted() raise TypeError.
        seed = item[0]
        try:
            return (0, int(seed), "")
        except ValueError:
            return (1, 0, seed)

    return [entry[1][1] for entry in sorted(latest_per_seed.items(), key=seed_key)]


def plot_suite(log_dir, games, algorithms, max_steps, suite, smooth_radius):
    """Build a figure with one subplot per game and one curve per algorithm.

    Each curve is the mean over the selected runs with a shaded band of
    +/- one standard error. A single shared legend is placed below the axes.

    Side effects: applies a Matplotlib style and updates ``plt.rcParams``
    globally, and prints progress messages.

    Args:
        log_dir: Directory containing the run directories.
        games: Environment names; one subplot is created per entry.
        algorithms: Algorithm tokens to draw.
        max_steps: Upper x-axis limit, also used to truncate loaded data.
        suite: Suite name used when matching run directory names.
        smooth_radius: Moving-average window passed to the aggregation step.

    Returns:
        The created Matplotlib figure.

    Raises:
        ValueError: If ``games`` is empty.
    """
    # Try both names under which the seaborn style is looked up; fall back to
    # Matplotlib defaults (with a warning) if neither is available.
    for style_name in ("seaborn-v0_8", "seaborn"):
        try:
            plt.style.use(style_name)
            break
        except OSError:
            continue
    else:
        warnings.warn("Seaborn Matplotlib styles not available; falling back to defaults.")

    plt.rcParams.update({
        "font.size": 18,
        "axes.titlesize": 24,
        "axes.labelsize": 22,
        "xtick.labelsize": 20,
        "ytick.labelsize": 20,
        "legend.fontsize": 20,
    })

    num_games = len(games)
    if num_games == 0:
        raise ValueError("No games provided for plotting.")

    if num_games <= 3:
        nrows, ncols = 1, num_games
    else:
        nrows = int(np.sqrt(num_games))
        ncols = int(np.ceil(num_games / nrows))
        # Trade one column for one extra row when that still fits every game
        # and leaves fewer empty cells.
        if (nrows + 1) * (ncols - 1) >= num_games and (nrows + 1) * (ncols - 1) < nrows * ncols:
            nrows += 1
            ncols = max(1, ncols - 1)

    fig_width = 8 * ncols
    fig_height = 5 * nrows + 1.5
    fig = plt.figure(figsize=(fig_width, fig_height), dpi=300)
    axes = [fig.add_subplot(nrows, ncols, i + 1) for i in range(num_games)]

    colormap = plt.get_cmap("tab20")
    formatter = million_formatter()

    # Known algorithms first in ALGO_ORDER order, then any others as requested.
    ordered_algorithms = [algo for algo in ALGO_ORDER if algo in algorithms]
    ordered_algorithms += [algo for algo in algorithms if algo not in ordered_algorithms]
    algo_colors = {
        algo: colormap(idx % colormap.N)
        for idx, algo in enumerate(ordered_algorithms)
    }

    all_handles = []
    all_labels = []

    for idx, game in enumerate(games):
        print(f"Processing game {idx + 1}/{num_games}: {game}")
        ax = axes[idx]
        ax.set_title(game.replace("-", " "), pad=20)
        ax.set_xlabel("Environment Steps", labelpad=10)
        ax.set_ylabel("Average Return", labelpad=10)
        ax.grid(True, linestyle="--", alpha=0.7)
        ax.set_xlim(0, max_steps)
        ax.set_xticks(np.linspace(0, max_steps, 6))
        ax.xaxis.set_major_formatter(formatter)

        for algo in ordered_algorithms:
            print(f"    Algorithm: {algo}")
            run_paths = select_run_paths(log_dir, game, algo, suite)
            data = aggregate_runs(run_paths, max_steps, smooth_radius)
            if data is None:
                print("      No valid runs found.")
                continue
            steps, mean, std_error = data
            # Only the label is taken from ALGO_DISPLAY; colors come from the colormap.
            label, _ = ALGO_DISPLAY.get(algo, (algo, None))
            color = algo_colors.get(algo, colormap(0))

            line, = ax.plot(
                steps,
                mean,
                color=color,
                linewidth=2.0,
                label=label,
            )
            ax.fill_between(
                steps,
                mean - std_error,
                mean + std_error,
                color=color,
                alpha=0.2,
            )

            # Register each label the first time it is drawn on ANY subplot.
            # Restricting this to the first subplot would drop algorithms that
            # happen to have no data for the first game from the legend.
            if label not in all_labels:
                all_handles.append(line)
                all_labels.append(label)

    legend_ncol = 1
    base_margin = 0.08
    row_adjust = max(nrows - 1, 0) * 0.015
    bottom_margin = max(0.05, base_margin - row_adjust)
    if all_labels:
        legend_ncol = len(all_labels)
        bottom_margin = max(bottom_margin, 0.05 + 0.002 * len(all_labels) - row_adjust)

    fig.subplots_adjust(
        wspace=0.45,
        hspace=0.45,
        bottom=bottom_margin,
        top=0.92,
        left=0.1,
        right=0.97,
    )

    if all_handles:
        # Negative y anchor with loc="upper center" hangs the legend below the figure area.
        legend_anchor = -(bottom_margin - 0.015)
        fig.legend(
            all_handles,
            all_labels,
            loc="upper center",
            ncol=legend_ncol,
            bbox_to_anchor=(0.5, legend_anchor),
            frameon=True,
            framealpha=0.9,
            edgecolor="black",
            borderaxespad=0.2,
        )

    return fig


# --------------------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------------------

def main():
    """Parse command-line options, build the suite figure, and save it to disk.

    Raises:
        FileNotFoundError: If the resolved log directory does not exist.
        ValueError: If no games are specified and the suite has no defaults.
    """
    cli = argparse.ArgumentParser(description="Plot DexGym or Robosuite training curves.")
    cli.add_argument("--suite", choices=["dexgym", "robosuite"], default="dexgym")
    cli.add_argument("--log_dir", default=None, help="Path to runs directory.")
    cli.add_argument("--games", nargs="*", default=None, help="Environments to plot.")
    cli.add_argument("--algorithms", nargs="*", default=None, help="Algorithms to plot.")
    cli.add_argument("--max_timesteps", type=int, default=None, help="Limit timesteps for x-axis.")
    cli.add_argument("--smooth_radius", type=int, default=300, help="Moving-average radius for smoothing rewards.")
    cli.add_argument("--output_dir", default="plots", help="Directory to store generated plots.")
    cli.add_argument("--filename", default=None, help="Optional custom filename for the saved plot.")
    args = cli.parse_args()

    # Per-suite fallbacks, used for every option the user left unset.
    if args.suite == "dexgym":
        fallback = {
            "algos": DEXGYM_ALGOS_DEFAULT,
            "games": DEFAULT_DEXGYM_GAMES,
            "dir": "runs-dexgym/",
            "steps": int(5e6),
        }
    else:
        fallback = {
            "algos": [
                "sac",
                "sac_recon-recon",
                "sac_recon-reward",
                "sac_recon-next_state",
            ],
            # No built-in environment list for this suite; --games is required.
            "games": [],
            "dir": "runs-robosuite/",
            "steps": int(5.5e5),
        }

    log_dir = args.log_dir if args.log_dir else fallback["dir"]
    games = args.games if args.games else fallback["games"]
    algorithms = args.algorithms if args.algorithms else fallback["algos"]
    max_steps = args.max_timesteps if args.max_timesteps else fallback["steps"]

    if not os.path.isdir(log_dir):
        raise FileNotFoundError(f"Log directory '{log_dir}' does not exist.")

    if not games:
        raise ValueError("No games specified. Provide --games for robosuite plots.")

    print(f"Building plot for suite '{args.suite}' with {len(games)} games and {len(algorithms)} algorithms.")
    fig = plot_suite(log_dir, games, algorithms, max_steps, args.suite, args.smooth_radius)

    target_dir = Path(args.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    output_path = target_dir / (args.filename or f"{args.suite}.pdf")
    fig.savefig(output_path, bbox_inches="tight")
    print(f"Saved plot to {output_path}")


if __name__ == "__main__":
    main()
