import argparse
import os
import sys
from pathlib import Path

import numpy as np
import yaml

# import yaml
from stable_baselines3 import DDPG, PPO, SAC, TD3
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.noise import OrnsteinUhlenbeckActionNoise

# Directory containing this script, and its parent directory (treated as the
# project root for imports and for the results directory).
current_path = Path(os.path.abspath(__file__)).parent
project_root = current_path.parent
# Make the project root importable so the `scripts` package resolves when this
# file is executed directly. This must run before the import below.
sys.path.append(project_root.as_posix())

# Mapping indexed by environment id; each value is passed to `make_vec_env`
# as `wrapper_class`.
from scripts.constants import WRAPPER_CLASSES

# Load the stored run parameters that sit next to this script; the "env" entry
# supplies the default value of the --env command-line option.
# The encoding is given explicitly so the file is decoded the same way on every
# platform instead of depending on the locale's default encoding.
with open(current_path / "parameters.yml", encoding="utf-8") as f:
    # safe_load only constructs plain Python objects (same as Loader=SafeLoader).
    stored_params = yaml.safe_load(f)


def prep_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the baseline training script.

    The options are declared as a table of ``(flag, keyword-arguments)`` pairs
    and registered in order. ``--algorithm`` and ``--policy_shape`` are
    required; every other option has a default. The default of ``--env`` is
    read from the module-level ``stored_params`` mapping.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    reward_choices = [
        "length",
        "standard",
        "speed",
        "left",
        "height",
        "rot_ccw",
        "rot_cw",
        "radial_speed",
        "quad_I",
        "quad_II",
        "quad_III",
        "quad_IV",
    ]

    # One entry per option, in the order it is registered on the parser.
    option_table = [
        (
            "--algorithm",
            {
                "type": str,
                "choices": ["td3", "ddpg", "sac", "ppo"],
                "help": "RL algorithm to use",
                "required": True,
            },
        ),
        (
            "--env",
            {
                "type": str,
                # "Hopper-v5" is currently not offered as a choice.
                "choices": ["MountainCarContinuous-v0", "Reacher-v5"],
                "default": stored_params["env"],
                "help": "Environment to run",
            },
        ),
        (
            "--reward",
            {
                "type": str,
                "default": "height",
                "choices": reward_choices,
                "help": "Which reward to consider",
            },
        ),
        (
            "--num_train",
            {"type": int, "default": 1, "help": "Number of models to train"},
        ),
        (
            "--start_seed",
            {"type": int, "default": 0, "help": "Initial seed for training"},
        ),
        (
            "--policy_shape",
            {
                "type": int,
                "nargs": "+",
                "help": (
                    "Sizes of hidden layers in the policy network, e.g. --policy_shape 8 4 "
                    "uses a net with two hidden layers of 8 and 4 neurons respectively"
                ),
                "required": True,
            },
        ),
        (
            "--total_timesteps",
            {
                "type": int,
                "default": 300000,
                "help": "Total number of samples (env steps) to train on",
            },
        ),
        (
            "--eval_frequency",
            {
                "type": int,
                "default": 1000,
                "help": "Number of timesteps between evaluations",
            },
        ),
    ]

    cli = argparse.ArgumentParser(description="Run Baseline RL Algorithms")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli


def _make_action_noise(vec_env, sigma: float) -> OrnsteinUhlenbeckActionNoise:
    """Create Ornstein-Uhlenbeck exploration noise sized to the action space.

    Args:
        vec_env: Vectorized environment whose ``action_space`` defines the
            number of action dimensions.
        sigma: Noise scale applied to every action dimension.

    Returns:
        A noise process with one zero-mean component per action dimension.
    """
    # Size the noise from the environment instead of hard-coding one dimension,
    # so each action dimension receives its own noise component.
    n_actions = vec_env.action_space.shape[-1]
    return OrnsteinUhlenbeckActionNoise(
        mean=np.zeros(n_actions),
        sigma=np.full(n_actions, sigma),
    )


def _build_model(algorithm: str, vec_env, policy_shape: list, seed: int):
    """Instantiate the requested algorithm with its fixed hyperparameters.

    Args:
        algorithm: One of ``"td3"``, ``"ddpg"``, ``"sac"`` or ``"ppo"``.
        vec_env: Vectorized training environment.
        policy_shape: Hidden layer sizes of the policy network.
        seed: Random seed handed to the model.

    Returns:
        The freshly constructed (untrained) model.

    Raises:
        ValueError: If ``algorithm`` is not one of the supported names.
    """
    if algorithm == "td3":
        return TD3(
            "MlpPolicy",
            vec_env,
            verbose=1,
            gamma=0.99,
            action_noise=_make_action_noise(vec_env, 0.75),
            seed=seed,
            policy_kwargs={"net_arch": {"pi": policy_shape, "qf": [400, 300]}},
            device="cpu",
        )
    if algorithm == "ddpg":
        # No explicit device here: DDPG keeps the library's default device.
        return DDPG(
            "MlpPolicy",
            vec_env,
            buffer_size=50000,
            verbose=1,
            gamma=0.99,
            action_noise=_make_action_noise(vec_env, 0.65),
            policy_kwargs={"net_arch": {"pi": policy_shape, "qf": [400, 300]}},
            seed=seed,
        )
    if algorithm == "sac":
        return SAC(
            "MlpPolicy",
            vec_env,
            verbose=1,
            gamma=0.99,
            batch_size=256,
            buffer_size=50000,
            # Fixed entropy coefficient, passed as a number rather than a string.
            ent_coef=0.1,
            gradient_steps=32,
            learning_rate=0.0003,
            learning_starts=0,
            tau=0.01,
            train_freq=32,
            use_sde=True,
            policy_kwargs={"net_arch": policy_shape, "log_std_init": -3.67},
            device="cpu",
            seed=seed,
        )
    if algorithm == "ppo":
        return PPO(
            "MlpPolicy",
            vec_env,
            normalize_advantage=True,
            verbose=1,
            learning_rate=1e-04,
            n_steps=32,
            batch_size=256,
            n_epochs=4,
            gamma=0.9999,
            gae_lambda=0.9,
            clip_range=0.1,
            ent_coef=0.01,
            vf_coef=0.19,
            max_grad_norm=5,
            use_sde=True,
            device="cpu",
            policy_kwargs={
                "net_arch": policy_shape,
                "log_std_init": -3.29,
                "ortho_init": False,
            },
            seed=seed,
        )
    raise ValueError(f"Unsupported algorithm: {algorithm!r}")


def main() -> None:
    """Train one model per seed and store the results under ``baseline_results``.

    Parses the command line, creates a training and an evaluation vectorized
    environment (8 sub-environments each, wrapped with the reward wrapper
    registered for the chosen environment), then trains ``--num_train`` models
    with consecutive seeds starting at ``--start_seed``. For every seed the
    evaluation callback writes into ``<dest_dir>/seed_<seed>`` and the final
    model is saved there as ``last_model``.
    """
    parser = prep_arg_parser()
    args = parser.parse_args()

    # Short environment names used in the results directory name.
    env_dirs = {
        "MountainCarContinuous-v0": "mountain_car",
        "Reacher-v5": "reacher",
    }
    # argparse does not check a *default* value against `choices`, so a stored
    # default outside this mapping would otherwise surface as an unbound local.
    if args.env not in env_dirs:
        parser.error(
            f"unsupported environment {args.env!r}; "
            f"expected one of {sorted(env_dirs)}"
        )
    env_dir = env_dirs[args.env]

    # Policy shape as it appears in the directory name, e.g. [8, 4] -> "8_4".
    policy_shape = "_".join(map(str, args.policy_shape))

    dest_dir = (
        project_root
        / "baseline_results"
        / f"{env_dir}_{args.algorithm}_{args.reward}_policy_{policy_shape}_seed_{args.start_seed}"
    )

    n_envs = 8
    env_str = args.env
    env_kwargs = {
        "n_envs": n_envs,
        "wrapper_class": WRAPPER_CLASSES[env_str],
        "wrapper_kwargs": {"reward_type": args.reward},
    }

    # EvalCallback counts calls to the vectorized env's step, and each call
    # advances `n_envs` timesteps. Convert the requested number of timesteps
    # into a number of calls; non-positive values are passed through unchanged.
    if args.eval_frequency > 0:
        eval_freq = max(args.eval_frequency // n_envs, 1)
    else:
        eval_freq = args.eval_frequency

    # Parallel environments: one set for training, one for evaluation.
    vec_env = make_vec_env(env_str, **env_kwargs)
    eval_env = make_vec_env(env_str, **env_kwargs)

    try:
        for seed in range(args.start_seed, args.start_seed + args.num_train):
            seed_dir = dest_dir / f"seed_{seed}"
            eval_callback = EvalCallback(
                eval_env,
                best_model_save_path=seed_dir,
                log_path=seed_dir,
                eval_freq=eval_freq,
                deterministic=True,
                render=False,
                n_eval_episodes=8,
            )

            model = _build_model(args.algorithm, vec_env, args.policy_shape, seed)
            model.learn(
                total_timesteps=args.total_timesteps,
                progress_bar=True,
                callback=eval_callback,
            )
            model.save(seed_dir / "last_model")
    finally:
        # Release the sub-environments even if training is interrupted.
        vec_env.close()
        eval_env.close()


if __name__ == "__main__":
    main()
