import argparse
import os
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

# Directory containing this script, and its parent, which is used as the
# project root.
current_path = Path(os.path.abspath(__file__)).parent
project_root = current_path.parent
# Put the project root on the import path so the `compression_autoencoder`
# import below can resolve (presumably needed when this file is run directly
# as a script rather than as an installed package).
sys.path.append(project_root.as_posix())

from compression_autoencoder.utils.misc import (
    resolve_source_dir,
)


def prep_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this plotting script.

    Returns:
        An ``ArgumentParser`` exposing one mandatory string option,
        ``--source_dir``, that names the directory holding the baseline
        results.
    """
    source_dir_help = "Location of directory that contains the baseline results"
    arg_parser = argparse.ArgumentParser(
        description="Plot data in baseline_results"
    )
    arg_parser.add_argument(
        "--source_dir",
        required=True,
        type=str,
        help=source_dir_help,
    )
    return arg_parser


def _load_seed_evaluations(seed_dirs: list) -> tuple:
    """Load each seed's ``evaluations.npz`` and average rewards per evaluation.

    Args:
        seed_dirs: Seed directories (``pathlib.Path``). Each must contain an
            ``evaluations.npz`` file with ``timesteps`` and ``results`` arrays.

    Returns:
        ``(timesteps, all_means)``. ``timesteps`` is the evaluation-timestep
        array shared by every seed. ``all_means`` is a
        ``(num_seeds, num_evals)`` array whose row ``i`` is
        ``results.mean(axis=1)`` for ``seed_dirs[i]``.

    Raises:
        ValueError: If the ``timesteps`` arrays differ between seeds.
    """
    all_timesteps = []
    all_results = []
    for seed_dir in seed_dirs:
        # ``np.load`` on an .npz keeps the archive open until it is closed;
        # the context manager releases the handle once both arrays are read.
        with np.load(seed_dir / "evaluations.npz") as data:
            all_timesteps.append(data["timesteps"])
            all_results.append(data["results"])

    reference_timesteps = all_timesteps[0]
    for t in all_timesteps[1:]:
        if not np.array_equal(reference_timesteps, t):
            raise ValueError("Timesteps do not match across seeds.")

    # Axis 1 is assumed to index the episodes run at each evaluation point
    # (TODO confirm against whatever writes evaluations.npz).
    all_means = np.array([r.mean(axis=1) for r in all_results])
    return reference_timesteps, all_means


def _plot_per_seed(
    timesteps: np.ndarray, all_means: np.ndarray, labels: list, dest_file: Path
) -> None:
    """Draw one mean-reward curve per seed and save the figure to ``dest_file``."""
    fig, ax = plt.subplots()
    for seed_means, label in zip(all_means, labels):
        ax.plot(timesteps, seed_means, label=label)
    ax.set_xlabel("Timesteps")
    ax.set_ylabel("Mean Reward")
    ax.set_title("Evaluation Reward per Seed")
    ax.grid(True)
    ax.legend()
    fig.savefig(dest_file)
    # pyplot keeps every open figure alive until it is explicitly closed.
    plt.close(fig)


def _plot_mean_std(
    timesteps: np.ndarray, all_means: np.ndarray, dest_file: Path
) -> None:
    """Plot the across-seed mean reward with a +/- one std band and save it."""
    mean_across_seeds = all_means.mean(axis=0)
    # numpy's default ddof=0, i.e. the population standard deviation.
    std_across_seeds = all_means.std(axis=0)

    fig, ax = plt.subplots()
    ax.plot(timesteps, mean_across_seeds, label="Mean Reward")
    ax.fill_between(
        timesteps,
        mean_across_seeds - std_across_seeds,
        mean_across_seeds + std_across_seeds,
        color="b",
        alpha=0.2,
        label="Std. Dev.",
    )
    ax.set_xlabel("Timesteps")
    ax.set_ylabel("Mean Reward")
    ax.set_title("Mean Evaluation Reward over Timesteps (All Seeds)")
    ax.grid(True)
    ax.legend()
    fig.savefig(dest_file)
    plt.close(fig)


def main() -> None:
    """Plot per-seed and aggregated evaluation rewards for a baseline run.

    Reads ``<source_dir>/seed_*/evaluations.npz`` and writes
    ``single_seed_evaluation_rewards.png`` and
    ``mean_std_evaluation_reward.png`` into ``<source_dir>/plots``.

    Raises:
        FileNotFoundError: If ``source_dir`` contains no ``seed_*`` directories.
        ValueError: If the evaluation timesteps differ between seeds.
    """
    args = prep_arg_parser().parse_args()

    source_dir = resolve_source_dir(args.source_dir, project_root, current_path)
    # Only directories count as seeds; stray files matching the pattern are ignored.
    seed_dirs = [p for p in sorted(source_dir.glob("seed_*")) if p.is_dir()]
    if not seed_dirs:
        raise FileNotFoundError(f"No 'seed_*' directories found in {source_dir}")

    timesteps, all_means = _load_seed_evaluations(seed_dirs)

    # Label each curve with the id from its directory name (seed_3 -> "Seed 3")
    # so the legend matches the data on disk regardless of sort order.
    prefix = "seed_"
    labels = [f"Seed {d.name[len(prefix):]}" for d in seed_dirs]

    dest_dir = source_dir / "plots"
    dest_dir.mkdir(parents=True, exist_ok=True)

    # Plot 1: single runs.
    _plot_per_seed(
        timesteps, all_means, labels, dest_dir / "single_seed_evaluation_rewards.png"
    )
    # Plot 2: mean and std across seeds.
    _plot_mean_std(timesteps, all_means, dest_dir / "mean_std_evaluation_reward.png")


# Run the plotting entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
