import os
import numpy as np
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

# Global typography and resolution defaults applied to every figure this
# script creates. Keys are standard matplotlib rcParams names.
plt.rcParams.update(
    dict(
        [
            ("font.size", 11),
            ("axes.labelsize", 12),
            ("axes.titlesize", 13),
            ("legend.fontsize", 10),
            ("figure.dpi", 100),
        ]
    )
)


def load_single_log(log_path, tag="charts/episodic_return"):
    """Load one scalar series from a TensorBoard event log.

    Args:
        log_path: Path passed to ``EventAccumulator`` (an event file here).
        tag: Name of the scalar tag to extract.

    Returns:
        ``(steps, values)`` as NumPy arrays ordered by increasing step, or
        ``(None, None)`` if the log contains no scalar with that tag.
    """
    # The default size guidance keeps only a reservoir sample of 10,000
    # scalars per tag; 0 means "keep every event", so long runs are not
    # silently subsampled.
    ea = event_accumulator.EventAccumulator(
        log_path, size_guidance={event_accumulator.SCALARS: 0}
    )
    ea.Reload()
    if tag not in ea.Tags().get("scalars", []):
        return None, None
    events = ea.Scalars(tag)
    steps = np.array([e.step for e in events])
    values = np.array([e.value for e in events])
    # Callers interpolate over steps with np.interp, which requires an
    # increasing x-axis; sort defensively in case events are out of order.
    order = np.argsort(steps, kind="stable")
    return steps[order], values[order]


def load_all_seeds(env_dir, tag="charts/episodic_return", min_points=100):
    """Load the ``tag`` series from every event file in ``env_dir``.

    Each entry whose name starts with ``events`` is treated as one seed.
    Seeds that lack the tag, or that logged ``min_points`` or fewer points
    (e.g. crashed or barely-started runs), are skipped.

    Args:
        env_dir: Directory holding the event files for one method/environment.
        tag: Scalar tag to extract from each file.
        min_points: A seed is kept only if it has strictly more points than this.

    Returns:
        List of ``(steps, values)`` tuples in sorted file-name order.
    """
    all_data = []
    event_files = sorted(f for f in os.listdir(env_dir) if f.startswith("events"))

    for ef in event_files:
        steps, values = load_single_log(os.path.join(env_dir, ef), tag)
        if steps is not None and len(steps) > min_points:
            all_data.append((steps, values))

    return all_data


def interpolate_to_common_steps(all_data, num_points=500):
    """Resample several runs onto one shared step grid and aggregate them.

    The grid spans ``[0, horizon]``, where ``horizon`` is the final step of
    the *shortest* run, so no run is extrapolated past its end. Grid points
    before a run's first logged step are clamped to that run's first value
    (``np.interp`` behavior).

    Args:
        all_data: Sequence of ``(steps, values)`` array pairs, one per run.
        num_points: Number of evenly spaced grid points.

    Returns:
        ``(common_steps, mean, std)`` arrays of length ``num_points``, where
        ``std`` is the population standard deviation (ddof=0) across runs;
        or ``(None, None, None)`` when ``all_data`` is empty.
    """
    if not all_data:
        return None, None, None

    runs = []
    for steps, values in all_data:
        steps = np.asarray(steps)
        values = np.asarray(values, dtype=float)
        # np.interp does not validate that xp is increasing and returns
        # nonsense otherwise, so sort each run by step first.
        order = np.argsort(steps, kind="stable")
        runs.append((steps[order], values[order]))

    horizon = min(s[-1] for s, _ in runs)
    common_steps = np.linspace(0, horizon, num_points)

    interpolated = np.array([np.interp(common_steps, s, v) for s, v in runs])
    mean = np.mean(interpolated, axis=0)
    std = np.std(interpolated, axis=0)

    return common_steps, mean, std


def smooth(values, weight=0.9):
    """Exponential moving average seeded with the first value.

    ``out[0] = values[0]`` and
    ``out[i] = weight * out[i - 1] + (1 - weight) * values[i]``.
    Larger ``weight`` gives a smoother curve.

    Args:
        values: 1-D sequence of numbers.
        weight: Weight on the previous smoothed value.

    Returns:
        A float ndarray of the same length as ``values`` (empty for empty input).
    """
    values = np.asarray(values, dtype=float)
    # Allocate a float buffer: np.zeros_like on an integer input would give
    # an integer array and truncate every smoothed value.
    smoothed = np.empty_like(values)
    if values.size == 0:
        return smoothed
    smoothed[0] = values[0]
    for i in range(1, len(values)):
        smoothed[i] = weight * smoothed[i - 1] + (1 - weight) * values[i]
    return smoothed


def plot_single_env(ax, base_dir, env_name, title, methods=None):
    """Draw smoothed mean ± std learning curves for each method on one axes.

    For each ``(method_dir, label, color)`` in ``methods``, seeds are read from
    ``base_dir/method_dir/env_name``. Methods whose directory is missing or
    that have no usable seeds are skipped.

    Args:
        ax: Matplotlib axes to draw on.
        base_dir: Root directory containing one sub-directory per method.
        env_name: Environment sub-directory name inside each method directory.
        title: Axes title.
        methods: Optional list of ``(method_dir, label, color)`` tuples;
            defaults to AMS-TD3 (``td3_logs``) and AMS-SAC (``sac_logs``).
    """
    if methods is None:
        methods = [
            ("td3_logs", "AMS-TD3", "#2196F3"),
            ("sac_logs", "AMS-SAC", "#4CAF50"),
        ]

    for method_dir, label, color in methods:
        env_dir = os.path.join(base_dir, method_dir, env_name)
        # isdir rather than exists: a stray file here would make listdir raise.
        if not os.path.isdir(env_dir):
            continue

        all_data = load_all_seeds(env_dir)
        if not all_data:
            continue

        steps, mean, std = interpolate_to_common_steps(all_data)
        if steps is None:
            continue

        # Smooth the spread with the same EMA as the mean so the shaded band
        # is consistent with (and as smooth as) the plotted line.
        mean_smooth = smooth(mean)
        std_smooth = smooth(std)

        ax.plot(steps, mean_smooth, color=color, label=label, linewidth=2)
        ax.fill_between(
            steps,
            mean_smooth - std_smooth,
            mean_smooth + std_smooth,
            alpha=0.2,
            color=color,
        )

    ax.set_xlabel("Environment Steps")
    ax.set_ylabel("Episodic Return")
    ax.set_title(title)
    # Only add a legend when something was drawn; otherwise matplotlib warns
    # that no labelled artists were found.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(loc="lower right")
    ax.grid(True, alpha=0.3)
    ax.set_xlim(left=0)


def main():
    """Build the 2x2 grid of learning curves and write it to ``figures/``.

    One panel per environment; the figure is saved as both PDF and PNG
    (PNG at 150 dpi), then closed.
    """
    base_dir = "."
    env_specs = (
        ("quadruped-walk", "Quadruped-Walk (12D)"),
        ("quadruped-run", "Quadruped-Run (12D)"),
        ("dog_walk", "Dog-Walk (38D)"),
        ("dog_run", "Dog-Run (38D)"),
    )

    os.makedirs("figures", exist_ok=True)

    fig, grid = plt.subplots(2, 2, figsize=(10, 8))
    for panel, (env_name, title) in zip(grid.flatten(), env_specs):
        plot_single_env(panel, base_dir, env_name, title)

    fig.tight_layout()
    pdf_path = "figures/all_results.pdf"
    png_path = "figures/all_results.png"
    fig.savefig(pdf_path, bbox_inches="tight")
    fig.savefig(png_path, bbox_inches="tight", dpi=150)
    for saved in (pdf_path, png_path):
        print(f"Saved: {saved}")
    plt.close(fig)

    print("Done!")


# Generate the figures only when run as a script, not when imported.
if __name__ == "__main__":
    main()