import os
from argparse import ArgumentParser

import gymnasium as gym
from stable_baselines3 import SAC
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecVideoRecorder

from customPolicy import CustomActorCriticPolicy
from protosac import ProtoSAC

# Default training budget; used as the default value of the --episodes option.
TOTAL_TIMESTEPS = 30_000

# Base keyword arguments forwarded to the model constructor. The
# environment-specific branches below override or extend these entries.
model_kwargs = dict(
    learning_rate=1e-3,
    verbose=1,
)

# Maps the integer accepted by --environment to the environment id handed to
# gym.make(). Keys are assigned in order starting from 0.
ENVIRONMENT = dict(enumerate((
    "Pendulum-v1",
    "LunarLanderContinuous-v3",
    "MountainCarContinuous-v0",
)))

def str_to_bool(value):
    """Convert a command-line token to a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy.

    Args:
        value: A bool (returned unchanged) or a string such as "true"/"false",
            "yes"/"no", "on"/"off", "1"/"0" (case-insensitive). The empty
            string is treated as False.

    Returns:
        The parsed boolean.

    Raises:
        ValueError: If the text is not a recognised boolean spelling. argparse
            turns this into a regular usage error.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise ValueError(f"not a boolean value: {value!r}")


def make_video_trigger(total_timesteps, video_length=200):
    """Build the ``record_video_trigger`` predicate for the video recorder.

    The predicate is true ``video_length`` steps before every quarter of the
    training budget.

    Args:
        total_timesteps: Total number of training timesteps.
        video_length: Number of steps recorded per video.

    Returns:
        A callable taking the current step and returning a bool.
    """
    # Clamp to 1 so a budget smaller than 4 steps does not divide by zero.
    interval = max(1, total_timesteps // 4)

    def trigger(step):
        return (step + video_length) % interval == 0

    return trigger


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read sys.argv.

    Returns:
        The namespace with ``environment``, ``episodes`` and ``baseline``.
    """
    env_listing = ", ".join(f"{key}: {name}" for key, name in ENVIRONMENT.items())

    parser = ArgumentParser()
    parser.add_argument('--environment', type=int, choices=list(ENVIRONMENT),
                        help=f"Choose the gym environment to use ({env_listing}). "
                             f"Default: {ENVIRONMENT[0]}",
                        default=0)
    # The option keeps its historical name, but the value is a timestep budget.
    parser.add_argument('--episodes', type=int,
                        help=f"The total number of training timesteps to run in the given "
                             f"environment. Default: {TOTAL_TIMESTEPS}",
                        default=TOTAL_TIMESTEPS)
    # nargs='?' + const lets both `--baseline` and `--baseline True/False` work.
    parser.add_argument('--baseline', type=str_to_bool, nargs='?', const=True,
                        help="Specify if you want to use the baseline. Default: False",
                        default=False)
    return parser.parse_args(argv)


def build_model_kwargs(environment_id):
    """Return the model keyword arguments for the chosen environment.

    Starts from a copy of the module-level ``model_kwargs`` (the global is left
    untouched) and applies the environment-specific overrides.

    Args:
        environment_id: Key into ``ENVIRONMENT``.

    Returns:
        A new dict of keyword arguments for the model constructor.
    """
    kwargs = dict(model_kwargs)

    if environment_id == 1:
        # LunarLanderContinuous-v3; hyperparameters at
        # https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/sac.yml
        kwargs.update(
            learning_rate=7.3e-4,
            tau=0.01,
            learning_starts=10_000,
            policy_kwargs=dict(net_arch=[400, 300]),
        )
    elif environment_id == 2:
        # MountainCarContinuous-v0
        kwargs.update(
            learning_rate=3e-4,
            buffer_size=50000,
            batch_size=512,
            ent_coef=0.1,
            train_freq=32,
            gradient_steps=32,
            gamma=0.9999,
            tau=0.01,
            learning_starts=0,
            use_sde=True,
            policy_kwargs=dict(log_std_init=-3.67, net_arch=[64, 64]),
        )

    return kwargs


def main():
    """Train SAC or ProtoSAC on the selected environment and save the model."""
    args = parse_args()
    name_env = ENVIRONMENT[args.environment]
    kwargs = build_model_kwargs(args.environment)
    video_length = 200

    def make_env():
        env = gym.make(name_env, render_mode="rgb_array")
        return Monitor(env)  # record stats such as returns

    env = VecVideoRecorder(
        DummyVecEnv([make_env]),
        f"videos/{name_env}",
        record_video_trigger=make_video_trigger(args.episodes, video_length),
        video_length=video_length,
    )

    try:
        if args.baseline:
            model = SAC('MlpPolicy', env, **kwargs)
        else:
            model = ProtoSAC(CustomActorCriticPolicy, env, **kwargs)

        # start training
        model.learn(total_timesteps=args.episodes, log_interval=4)

        # NOTE(review): the directory and the save path are the same string;
        # SB3 presumably appends ".zip" to a suffix-less path, so the archive
        # would land next to this directory rather than inside it. Confirm the
        # intended layout before changing either path.
        os.makedirs(f"models/{name_env}", exist_ok=True)
        model.save(f"models/{name_env}")
    finally:
        # Close the env even if training or saving fails.
        env.close()


if __name__ == "__main__":
    main()
