import argparse
import os
import sys
from pathlib import Path
from typing import Any

import numpy as np
import torch
import yaml
from matplotlib import cm
from matplotlib import pyplot as plt
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import DummyVecEnv

# Directory containing this script, and its parent (the project root).
current_path = Path(os.path.abspath(__file__)).parent
project_root = current_path.parent
# Append the project root to the import path so the project-local imports below
# (``compression_autoencoder`` and ``scripts``) can be resolved. This must run
# before those imports.
sys.path.append(project_root.as_posix())

import pickle
import re

from compression_autoencoder.autoencoders.rl_neural_continuous_autoencoder import (
    RLNeuralContinuousAutoencoder,
)
from compression_autoencoder.optim.pgpe_learner import PGPELearner
from compression_autoencoder.policies.policy import Policy
from compression_autoencoder.utils.history import History
from compression_autoencoder.utils.misc import (
    get_seed_sequence,
    resolve_source_dir,
    set_seeds,
)
from scripts.constants import INPUT_SCALERS, WRAPPER_CLASSES

# Shared configuration read from ``parameters.yml`` next to this script: CLI
# defaults (``"defaults"``), the default environment (``"env"``) and policy
# layer shapes (``"policy"``).
with (current_path / "parameters.yml").open() as _params_file:
    stored_params = yaml.safe_load(_params_file)


def prep_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        description="Run PGPELearner with specified parameters."
    )
    defaults = stored_params["defaults"]
    parser.add_argument(
        "--source_dir",
        type=str,
        help="Location of directory that contains the trained autoencoder or policy files",
    )
    parser.add_argument(
        "--num_generations",
        type=int,
        default=defaults["num_generations"],
        help="Number of generations to run",
    )
    parser.add_argument(
        "--env",
        type=str,
        # choices=["MountainCarContinuous-v0", "Hopper-v5", "Reacher-v5"],
        choices=["MountainCarContinuous-v0", "Reacher-v5"],
        default=stored_params["env"],
        help="Environment to run",
    )
    parser.add_argument(
        "--num_envs",
        type=int,
        default=defaults["num_envs"],
        help="Number of parallel environments",
    )
    parser.add_argument(
        "--popsize",
        type=int,
        default=defaults["popsize"],
        help="Population size for PGPE",
    )
    parser.add_argument(
        "--center_learning_rate",
        type=float,
        default=defaults["center_learning_rate"],
        help="Learning rate for PGPE",
    )
    parser.add_argument(
        "--stdev_learning_rate",
        type=float,
        default=defaults["stdev_learning_rate"],
        help="Learning rate's std dev for PGPE",
    )
    parser.add_argument(
        "--stdev_init",
        type=float,
        default=defaults["stdev_init"],
        help="Standard deviation initial value for PGPE",
    )
    parser.add_argument(
        "--seed", type=int, default=defaults["seed"], help="Seed for RNG"
    )
    parser.add_argument(
        "--target_space",
        type=str,
        default=defaults["target_space"],
        choices=["latent", "parameter"],
        help="Target space for the policy, either latent, or parameter",
    )
    parser.add_argument(
        "--num_runs",
        type=int,
        default=defaults["num_runs"],
        help="Number of runs to perform",
    )
    parser.add_argument(
        "--reward_type",
        type=str,
        default="standard",
        help="Type of reward to use in the environment",
    )
    parser.add_argument(
        "--update_type",
        type=str,
        default=defaults["update_type"],
        choices=["reinforce", "natural"],
        help="Type of update to use, either reinforce or natural",
    )
    parser.add_argument(
        "--eval_freq_steps",
        type=int,
        default=defaults["eval_freq_steps"],
        help="Frequency of evaluation in environment steps",
    )
    return parser


def plot_pgpe_results(
    histories: list[dict[str, Any]],
    dest_dir: Path,
    target_space: str,
) -> None:
    """
    Generate and save visualizations for PGPE optimization results.

    The number of runs is ``len(histories)``. For latent-space runs, the
    latent dimensionality is inferred from the first recorded center of the
    first run that has any centers. Center trajectories are only drawn for
    latent spaces with at most 3 dimensions.

    Args:
        histories: A list of history dictionaries, one for each run. Each is
            read for ``"val_reward"`` and, for latent runs, ``"center"``.
        dest_dir: Directory where the plot images are saved (created if missing).
        target_space: The optimization space, ``"latent"`` or ``"parameter"``.
    """
    num_runs = len(histories)
    if num_runs == 0:
        print("Warning: No histories provided to plot.")
        return

    latent_dim: int | None = None
    if target_space == "latent":
        for history in histories:
            centers = history.get("center")
            # ``len`` instead of truthiness: a NumPy array of centers would raise
            # "truth value of an array is ambiguous".
            if centers is not None and len(centers) > 0:
                latent_dim = np.asarray(centers[0]).shape[0]
                break

    plt.style.use("seaborn-v0_8-whitegrid")
    dest_dir.mkdir(parents=True, exist_ok=True)
    print(f"Saving visualizations for {num_runs} run(s) to {dest_dir}...")

    # Plot 1: Learning Curves
    _plot_learning_curves(histories, num_runs, dest_dir)

    # Plot 2: Center Trajectories (only for latent space <= 3D)
    if target_space == "latent" and latent_dim is not None:
        _plot_center_trajectories(histories, num_runs, latent_dim, dest_dir)


def _plot_learning_curves(
    histories: list[dict[str, Any]], num_runs: int, dest_dir: Path
) -> None:
    """Plot validation-reward learning curves and save them as a PNG.

    Each ``history["val_reward"]`` is indexed as a 2-column array: column 0 is
    the x-axis (environment interactions) and column 1 the validation reward.
    Runs without any recorded evaluation are skipped. With a single usable run
    its curve is drawn directly. With several, every run is drawn faintly, and
    the mean reward plus a 95% confidence band (normal approximation) is
    overlaid on the x-range covered by all runs.

    Args:
        histories: One history dictionary per run.
        num_runs: Number of leading histories to consider.
        dest_dir: Directory receiving ``pgpe_learning_curves.png``.
    """
    # Skip empty runs everywhere (not only when computing the common x-range):
    # an empty history cannot be indexed as an (N, 2) array.
    curves = [
        np.array(h["val_reward"])
        for h in histories[:num_runs]
        if len(h["val_reward"]) > 0
    ]
    if not curves:
        print("Warning: No 'val_reward' data found to plot.")
        return
    num_curves = len(curves)

    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        if num_curves == 1:
            rewards = curves[0]
            ax.plot(rewards[:, 0], rewards[:, 1], color="dodgerblue", linewidth=2)
        else:
            # Interpolate every run onto a grid covering only the x-range
            # observed by all runs, so the mean is never an extrapolation.
            common_x = np.linspace(
                max(c[0, 0] for c in curves), min(c[-1, 0] for c in curves), 100
            )
            interpolated = []
            for curve in curves:
                interpolated.append(np.interp(common_x, curve[:, 0], curve[:, 1]))
                ax.plot(
                    curve[:, 0],
                    curve[:, 1],
                    color="steelblue",
                    alpha=0.2,
                    linewidth=1,
                )

            rewards_arr = np.array(interpolated)
            mean_rewards = rewards_arr.mean(axis=0)
            # Sample std (ddof=1) gives the standard error of the mean;
            # ddof=0 would understate the band for small numbers of runs.
            sem = rewards_arr.std(axis=0, ddof=1) / np.sqrt(num_curves)
            conf_interval = 1.96 * sem

            ax.plot(
                common_x, mean_rewards, color="coral", linewidth=2.5, label="Mean Reward"
            )
            ax.fill_between(
                common_x,
                mean_rewards - conf_interval,
                mean_rewards + conf_interval,
                color="coral",
                alpha=0.2,
                label="95% Confidence Interval",
            )
            ax.legend(fontsize=12)

        ax.set_xlabel("Environment Interactions", fontsize=14)
        ax.set_ylabel("Validation Reward", fontsize=14)
        ax.set_title(
            f"PGPE Learning Curves (across {num_curves} run(s))", fontsize=16
        )
        ax.grid(True, linestyle="--", alpha=0.6)
        fig.tight_layout()
        fig.savefig(dest_dir / "pgpe_learning_curves.png", dpi=300)
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)


def _plot_center_trajectories(
    histories: list[dict[str, Any]], num_runs: int, latent_dim: int, dest_dir: Path
) -> None:
    """Plot the trajectories of the distribution center in latent space.

    For a 1D latent space each run's center is plotted against the generation
    index. For 2D/3D spaces the centers are drawn as a connected scatter
    coloured by generation, with the first center marked "Start" and the last
    marked "End". The result is saved as ``pgpe_center_trajectories.png``.
    Nothing is drawn for more than 3 dimensions.

    Args:
        histories: One history dictionary per run; ``"center"`` is read.
        num_runs: Number of leading histories to consider.
        latent_dim: Dimensionality of each recorded center.
        dest_dir: Directory receiving the PNG.
    """
    if latent_dim > 3:
        print(
            f"Skipping trajectory plot: Dimensionality ({latent_dim}D) is greater than 3."
        )
        return
    if latent_dim < 1:
        return

    # Runs without recorded centers cannot be indexed (centers[0] would raise).
    runs = [
        (i, np.asarray(histories[i]["center"]))
        for i in range(num_runs)
        if len(histories[i]["center"]) > 0
    ]
    if not runs:
        print("Warning: No 'center' data found to plot.")
        return

    if latent_dim == 1:
        fig, ax = plt.subplots(figsize=(12, 7))
        colors = cm.viridis(np.linspace(0, 1, num_runs))
        for i, centers in runs:
            ax.plot(centers[:, 0], color=colors[i], alpha=0.8, label=f"Run {i + 1}")
        ax.set_xlabel("Generation", fontsize=14)
        ax.set_ylabel("Center Value (Latent Dimension 1)", fontsize=14)
        ax.legend(fontsize=12)
    else:
        # One colour scale shared by every run and by the colorbar, so a colour
        # always denotes the same generation index. Without an explicit norm
        # the colorbar would show 0..1 and each scatter would rescale itself.
        last_generation = max(len(centers) for _, centers in runs) - 1
        norm = plt.Normalize(vmin=0, vmax=max(last_generation, 1))

        if latent_dim == 2:
            fig, ax = plt.subplots(figsize=(10, 10))
        else:
            fig = plt.figure(figsize=(12, 12))
            ax = fig.add_subplot(projection="3d")

        for k, (_, centers) in enumerate(runs):
            coords = [centers[:, d] for d in range(latent_dim)]
            # Grey connecting line underneath the generation-coloured points.
            ax.plot(*coords, color="gray", alpha=0.5, linewidth=1, zorder=1)
            ax.scatter(
                *coords,
                c=np.arange(len(centers)),
                cmap="viridis",
                norm=norm,
                s=20,
                alpha=0.8,
                zorder=2,
            )
            # Only the first plotted run contributes Start/End legend entries.
            ax.plot(
                *centers[0, :latent_dim],
                "o",
                color="lime",
                markersize=8,
                label="Start" if k == 0 else "_nolegend_",
                zorder=3,
            )
            ax.plot(
                *centers[-1, :latent_dim],
                "x",
                color="red",
                markersize=10,
                mew=2,
                label="End" if k == 0 else "_nolegend_",
                zorder=3,
            )

        mappable = cm.ScalarMappable(norm=norm, cmap="viridis")
        if latent_dim == 2:
            ax.set_xlabel("Latent Dimension 1", fontsize=14)
            ax.set_ylabel("Latent Dimension 2", fontsize=14)
            ax.set_aspect("equal", adjustable="box")
            fig.colorbar(mappable, ax=ax, label="Generation")
        else:
            ax.set_xlabel("Latent Dim 1", fontsize=12)
            ax.set_ylabel("Latent Dim 2", fontsize=12)
            ax.set_zlabel("Latent Dim 3", fontsize=12)
            fig.colorbar(mappable, ax=ax, label="Generation", shrink=0.6)
        ax.legend()

    try:
        ax.set_title(f"Center Trajectories in {latent_dim}D Latent Space", fontsize=16)
        ax.grid(True, linestyle="--", alpha=0.6)
        fig.tight_layout()
        fig.savefig(dest_dir / "pgpe_center_trajectories.png", dpi=300)
    finally:
        plt.close(fig)


def init_from_target_space(args, source_dir, device):  # noqa: ANN001, ANN201
    precompute_parameters = None
    model = None
    should_skip = False

    # optimization in latent space
    if args.target_space == "latent":
        stats_path = source_dir / "latent" / "viz_stats.pkl"
        with open(stats_path, "rb") as f:
            viz_stats = pickle.load(f)

        with open(source_dir / "training_args.yml") as f:
            training_args = yaml.load(f, Loader=yaml.SafeLoader)

        selected_dir = resolve_source_dir(
            training_args["source_dir"], project_root, current_path
        )
        with open(selected_dir / "selection_args.yml") as f:
            selection_args = yaml.load(f, Loader=yaml.SafeLoader)

        generation_dir = resolve_source_dir(
            selection_args["source_dir"], project_root, current_path
        )
        with open(generation_dir / "args.yml") as f:
            generation_args = yaml.load(f, Loader=yaml.FullLoader)

        policy_args = stored_params["policy"]
        layer_shapes = [
            tuple(item)
            for item in policy_args["layer_shapes"][generation_args["policy_shape"]]
        ]
        activation_func = generation_args["activation_func"]
        last_activation_func = generation_args["last_activation_func"]

        sample_policy = Policy(
            layer_shapes=layer_shapes,
            activation_func=activation_func,
            last_activation_func=last_activation_func,
            input_scaler=INPUT_SCALERS[args.env],
            device=device,
        )

        model = RLNeuralContinuousAutoencoder(
            sample_policy=sample_policy,
            latent_dim=training_args["latent_dim"],
            encoder_layers_shapes=training_args["encoder_layers_shapes"],
            decoder_layers_shapes=training_args["decoder_layers_shapes"],
            activation_func=training_args["activation_func"],
            input_scaler=INPUT_SCALERS[f"Autoencoder{args.env}"],
            device=device,
        )
        model.load(source_dir / "autoencoder.pth")

        num_params = training_args["latent_dim"]

        def precompute_parameters(x: np.ndarray) -> torch.Tensor:
            with torch.no_grad():
                return model.decode(torch.from_numpy(x).float().to(device))

        def get_policy():
            return sample_policy

    # optimization in parameter space
    if args.target_space == "parameter":
        policy_file = source_dir / "policy.yml"
        assert policy_file.exists(), (
            f'Expected policy file at {source_dir}, make sure it\'s named "policy.yml"'
        )

        with open(policy_file) as f:
            policy_args = yaml.load(f, Loader=yaml.FullLoader)

        layer_shapes = [tuple(item) for item in policy_args["layer_shapes"]]
        activation_func = policy_args["activation_func"]
        last_activation_func = policy_args["last_activation_func"]

        sample_policy = Policy(
            layer_shapes=layer_shapes,
            activation_func=activation_func,
            last_activation_func=last_activation_func,
            input_scaler=INPUT_SCALERS[args.env],
            device=device,
        )

        def get_policy():  # Needed to randomly set the starting weights for each run
            return Policy(
                layer_shapes=layer_shapes,
                activation_func=activation_func,
                last_activation_func=last_activation_func,
                input_scaler=INPUT_SCALERS[args.env],
                device=device,
            )

        num_params = sample_policy.count_params()[0]

    return get_policy, num_params, precompute_parameters, should_skip


def _results_dir(
    args: argparse.Namespace, source_dir: Path, dim_x: int | None
) -> Path:
    """Return the directory that receives a run's histories, args and plots.

    Raises:
        ValueError: If ``args.env`` or ``args.target_space`` has no known layout.
    """
    env_dirs = {
        "MountainCarContinuous-v0": "mountain_car",
        "Reacher-v5": "reacher",
    }
    try:
        env_dir = env_dirs[args.env]
    except KeyError:
        raise ValueError(
            f"No results directory name configured for env {args.env!r}"
        ) from None

    if args.target_space == "latent":
        return (
            project_root
            / f"optimization_results_{env_dir}_lr_{args.center_learning_rate}_stdev_{args.stdev_init}_pop_{args.popsize}"
            / f"{source_dir.name}_dim_{dim_x}"
            / f"{args.reward_type}"
        )
    if args.target_space == "parameter":
        return (
            project_root
            / "optimization_results_new"
            / env_dir
            / source_dir.name
            / f"{args.reward_type}"
        )
    raise ValueError(f"Unknown target_space {args.target_space!r}")


def main() -> None:
    """Run ``num_runs`` independent PGPE optimisations from the CLI and save results.

    Writes ``pgpe_histories.pkl`` and ``pgpe_run_args.yml`` to the results
    directory, then the plots produced by ``plot_pgpe_results``.
    """
    args = prep_arg_parser().parse_args()

    # Latent-space sources carry their latent size in the path (``dim_<N>``);
    # it is reused to name the output directory.
    dim_x: int | None = None
    if args.target_space == "latent":
        match = re.search(r"dim_(\d+)", str(args.source_dir))
        if match is None:
            raise ValueError("Could not find 'dim_x' in source_dir.")
        dim_x = int(match.group(1))
        print(f"Found dimension: {dim_x}")

    source_dir = resolve_source_dir(args.source_dir, project_root, current_path)
    # Resolve the output location before training, so a misconfiguration fails
    # immediately rather than after all runs have finished.
    dest_dir = _results_dir(args, source_dir, dim_x)

    device = "cpu"
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rng = set_seeds(args.seed)

    policy_getter, num_params, precompute_parameters, should_skip = (
        init_from_target_space(args, source_dir, device)
    )

    if should_skip:
        print(f"Skipping optimization for {source_dir} due to latent space collapse.")
        return

    # Environment setup
    env_str = args.env
    vec_env = make_vec_env(
        env_str,
        n_envs=args.num_envs,
        vec_env_cls=DummyVecEnv,
        wrapper_class=WRAPPER_CLASSES[env_str],  # type: ignore
        wrapper_kwargs={"reward_type": args.reward_type},
        seed=args.seed,
    )

    # Training loop
    histories: list[History] = []
    run_seeds = get_seed_sequence(rng, args.num_runs)
    try:
        for i, seed in enumerate(run_seeds):
            print(f"Run {i + 1}/{args.num_runs}")

            learner = PGPELearner(
                popsize=args.popsize,
                num_generations=args.num_generations,
                num_params=num_params,
                parameter_to_weights=precompute_parameters,
                env=vec_env,
                device=device,
                sample_policy=policy_getter(),
                exact_step_count=True,
                update_type=args.update_type,
                seed=seed,
                center_init_dist="zeros" if args.target_space == "latent" else "policy",
                pgpe_kwargs={
                    "center_learning_rate": args.center_learning_rate,
                    "stdev_learning_rate": args.stdev_learning_rate,
                    "stdev_init": args.stdev_init,
                    "use_lr_scheduler": env_str != "MountainCarContinuous-v0",
                },
            )
            histories.append(
                learner.learn(verbose=True, eval_freq_steps=args.eval_freq_steps)
            )
    finally:
        # Release the environments even if a run fails.
        vec_env.close()

    print(f"Saving results to {dest_dir} ...")
    dest_dir.mkdir(parents=True, exist_ok=True)

    history_dicts = [h.history for h in histories]

    # Persist raw data and run arguments before plotting, so a plotting failure
    # cannot lose the results of a long optimisation.
    with open(dest_dir / "pgpe_histories.pkl", "wb") as f:
        pickle.dump(history_dicts, f)

    with open(dest_dir / "pgpe_run_args.yml", "w") as f:
        yaml.dump(vars(args), f, default_flow_style=False)

    plot_pgpe_results(
        histories=history_dicts,
        dest_dir=dest_dir,
        target_space=args.target_space,
    )


if __name__ == "__main__":
    main()
