from stable_baselines3.sac import SAC
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.logger import configure

import gymnasium as gym
import numpy as np
import argparse

from wrapper import *


# Joint masks per environment, keyed by the --sub_name CLI value. In each mask
# the joints of the legs named by the key are 1 and all others are 0.
# Presumably 1 means "actuated" (HalfCheetah 'full' is all ones) -- confirm
# against DisabledHalfCheetahOnly / DisabledAntOnly in wrapper.py.
JOINTS_STATUS = {
    'HalfCheetah-v5': {
        'front': [0, 0, 0, 1, 1, 1],
        'back': [1, 1, 1, 0, 0, 0],
        'full': [1, 1, 1, 1, 1, 1],
    },
    'Ant-v5': {
        'front_left_back_left': [0, 0, 1, 1, 0, 0, 1, 1],
        'front_right_back_right': [1, 1, 0, 0, 1, 1, 0, 0],
        'front_left_back_right': [1, 1, 1, 1, 0, 0, 0, 0],
        'front_right_back_left': [0, 0, 0, 0, 1, 1, 1, 1],
        'front_left_front_right': [0, 0, 1, 1, 1, 1, 0, 0],
        'back_left_back_right': [1, 1, 0, 0, 0, 0, 1, 1],
    },
}

N_ENVS = 4    # parallel environments in the training VecEnv
SEED = 2025   # seed passed to make_vec_env


def resolve_joints_status(env_id, sub_name):
    """Return a fresh copy of the joint mask for ``sub_name`` in ``env_id``.

    Raises:
        ValueError: if ``env_id`` is not in JOINTS_STATUS or ``sub_name`` is
            not a configured mask for that environment.
    """
    masks = JOINTS_STATUS.get(env_id)
    if masks is None:
        raise ValueError(
            f"unsupported env {env_id!r}; expected one of {sorted(JOINTS_STATUS)}"
        )
    if sub_name not in masks:
        raise ValueError(
            f"unknown sub_name {sub_name!r} for {env_id}; "
            f"expected one of {sorted(masks)}"
        )
    # Copy so a wrapper that mutates its mask cannot corrupt the shared table.
    return list(masks[sub_name])


def make_train_env(env_id, sub_name, n_envs=N_ENVS, seed=SEED):
    """Build the vectorized, joint-masked training env for ``env_id``/``sub_name``.

    Raises:
        ValueError: for an unsupported ``env_id`` or unknown ``sub_name``.
    """
    wrapper_kwargs = {'joints_status': resolve_joints_status(env_id, sub_name)}
    if env_id == 'HalfCheetah-v5':
        return make_vec_env(env_id, n_envs=n_envs, seed=seed,
                            wrapper_class=DisabledHalfCheetahOnly,
                            wrapper_kwargs=wrapper_kwargs)
    if env_id == 'Ant-v5':
        return make_vec_env(env_id, n_envs=n_envs, seed=seed,
                            env_kwargs={'terminate_when_unhealthy': False,
                                        'include_cfrc_ext_in_observation': False},
                            wrapper_class=DisabledAntOnly,
                            wrapper_kwargs=wrapper_kwargs)
    # Reached only if JOINTS_STATUS gains an env without a matching branch here.
    raise ValueError(f"no environment factory configured for {env_id!r}")


def main():
    """Parse CLI arguments and train SAC on the selected joint-masked env."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='Ant-v5',
                        choices=sorted(JOINTS_STATUS))
    parser.add_argument('--sub_name', type=str, default='front_left_front_right')
    args = parser.parse_args()

    # Validate up front so a bad --sub_name yields a usage error instead of
    # an unbound-variable crash deep inside the setup.
    try:
        resolve_joints_status(args.env, args.sub_name)
    except ValueError as err:
        parser.error(str(err))

    env = make_train_env(args.env, args.sub_name)

    run_name = f'{args.env}_{args.sub_name}'
    saved_model_path = f"runs/rl/{run_name}/ckpt/"
    logger_path = f"runs/rl/{run_name}/logs/"

    # save_freq counts callback calls, i.e. vectorized env.step() calls, so
    # with N_ENVS parallel envs a checkpoint is written every
    # 10_000 * N_ENVS environment timesteps.
    checkpoint_callback = CheckpointCallback(save_freq=10000,
                                             save_path=saved_model_path,
                                             name_prefix="sac",
                                             save_replay_buffer=False,
                                             )

    n_actions = env.action_space.shape[0]
    action_noise = OrnsteinUhlenbeckActionNoise(mean=np.zeros(n_actions),
                                                sigma=0.2 * np.ones(n_actions))
    # NOTE(review): net_arch [300, 400] is the reverse of the common
    # [400, 300] DDPG/TD3 layout -- confirm the ordering is intended.
    model = SAC("MlpPolicy", env, verbose=1, policy_kwargs={'net_arch': [300, 400]},
                stats_window_size=20, learning_starts=10_000,
                action_noise=action_noise, buffer_size=500_000,)

    new_logger = configure(logger_path, ["stdout", "csv", "tensorboard", "log"])
    model.set_logger(new_logger)

    model.learn(total_timesteps=10000000, callback=checkpoint_callback)


if __name__ == '__main__':
    main()

