import numpy as np
from matplotlib import pyplot as plt

# Experiment configuration.
num_seeds = 1  # number of temp/seed_<i>/ run directories to aggregate
all_timesteps = []  # one timesteps array per seed
all_results = []  # one results array per seed
PLOT_CODE = True  # also plot latent-code trajectories from latent_codes.npz

if num_seeds < 1:
    # Everything below indexes seed 0; fail with a clear message instead of
    # an IndexError further down.
    raise ValueError("num_seeds must be at least 1.")

# Load data from each seed. np.load on an .npz returns an NpzFile that keeps
# the archive open, so use it as a context manager. Indexing it reads the
# member fully into memory, so the arrays remain valid after it is closed.
for i in range(num_seeds):
    path = f"temp/seed_{i}/evaluations.npz"
    with np.load(path) as data:
        timesteps = data["timesteps"]
        results = data["results"]
    all_timesteps.append(timesteps)
    all_results.append(results)

# Check all timesteps are the same, so the per-seed curves can share one x-axis.
for t in all_timesteps[1:]:
    if not np.array_equal(all_timesteps[0], t):
        raise ValueError("Timesteps do not match across seeds.")

timesteps = all_timesteps[0]
# NOTE(review): results is presumably (num_evals, episodes_per_eval); axis=1
# is averaged away, leaving one value per evaluation — confirm with the writer.
all_means = [r.mean(axis=1) for r in all_results]
all_means = np.array(all_means)  # shape: (num_seeds, num_evals)

# Plot 1: overlay every seed's mean-evaluation-reward curve on one figure,
# all sharing the common timesteps x-axis.
plt.figure()
for seed_idx, seed_curve in zip(range(num_seeds), all_means):
    plt.plot(timesteps, seed_curve, label=f"Seed {seed_idx}")

ax = plt.gca()
ax.set_xlabel("Timesteps")
ax.set_ylabel("Mean Reward")
ax.set_title("Evaluation Reward per Seed")
ax.grid(True)
ax.legend()
plt.savefig("single_seed_evaluation_rewards.png")

# Plot 2: mean across seeds with a +/- 1 standard-deviation band.
# np.std uses ddof=0 (population std) by default; with a single seed the
# band therefore has zero width.
mean_across_seeds = all_means.mean(axis=0)
std_across_seeds = all_means.std(axis=0)

plt.figure()
(mean_line,) = plt.plot(timesteps, mean_across_seeds, label="Mean Reward")
plt.fill_between(
    timesteps,
    mean_across_seeds - std_across_seeds,
    mean_across_seeds + std_across_seeds,
    # Reuse the mean line's colour so the band always matches the line,
    # whatever colour cycle is active.
    color=mean_line.get_color(),
    alpha=0.2,
    label="Std. Dev.",
)
plt.xlabel("Timesteps")
plt.ylabel("Mean Reward")
plt.title("Mean Evaluation Reward over Timesteps (All Seeds)")
plt.grid(True)
plt.legend()
plt.savefig("mean_std_evaluation_reward.png")

if PLOT_CODE:
    # Plot 3: trajectories in latent space of the centers stored in each
    # seed's latent_codes.npz (first two latent dimensions only).
    plt.figure()
    for i in range(num_seeds):
        path = f"temp/seed_{i}/latent_codes.npz"
        # Close the NpzFile handle once the array has been read into memory.
        with np.load(path) as data:
            codes = np.asarray(data["latent_codes"])
        # The indexing below needs a 3-D array with at least 2 entries on the
        # last axis. NOTE(review): presumably laid out as
        # (num_points, num_centers, latent_dim) — confirm against the writer.
        if codes.ndim != 3 or codes.shape[2] < 2:
            raise ValueError(
                f"Expected 3-D latent codes with at least 2 latent dims in "
                f"{path}, got shape {codes.shape}."
            )
        # One line per index on axis 1, traced along axis 0. This is the same
        # geometry and colour order a single 2-D plot() call would produce. Each
        # line gets its own label instead of repeating one label per center.
        for k in range(codes.shape[1]):
            plt.plot(
                codes[:, k, 0],
                codes[:, k, 1],
                marker="o",
                label=f"Seed {i}, center {k}",
                alpha=0.7,
            )
    plt.xlabel("Center Dim 1")
    plt.ylabel("Center Dim 2")
    plt.title("Center Trajectories in Latent Space")
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.gca().set_aspect("equal", adjustable="box")
    plt.tight_layout()
    plt.savefig("center_trajectories.png")
