import os
import gym
import argparse
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3 import SAC
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.logger import configure
from robosuite.wrappers import GymWrapper
import robosuite as suite


# Command-line options for a single training run.
# Every option takes exactly one value; leaving an option out selects its
# default.  (Previously each option used nargs='?' without a `const`, so a
# bare flag such as `--seed` silently produced None instead of an error.)
parser = argparse.ArgumentParser(
    description="Train a SAC agent on a robosuite task with a Panda robot."
)
parser.add_argument(
    "--seed", type=int, default=1,
    help="random seed used for the environments and the agent (default: %(default)s)",
)
parser.add_argument(
    "--device", type=str, default='cuda:0',
    help="device handed to the SAC model (default: %(default)s)",
)
parser.add_argument(
    "--folder", type=str, default='',
    help="base directory under which source/logs/<env>/<seed>/ is written",
)
parser.add_argument(
    "--env", type=str, default='Door',
    help="robosuite environment name passed to suite.make (default: %(default)s)",
)
parser.add_argument(
    "--timesteps", type=int, default=int(5e5),
    help="total number of training timesteps (default: %(default)s)",
)
args = parser.parse_args()

def _make_env(monitored):
    """Create one robosuite task wrapped in robosuite's GymWrapper.

    The task name comes from ``args.env``; it is built with a single Panda
    robot, no camera observations, no on-screen or off-screen renderer,
    ``reward_shaping=True`` and ``control_freq=20``.

    Args:
        monitored: when true, additionally wrap the environment in
            stable-baselines3's ``Monitor``.

    Returns:
        The wrapped environment instance.
    """
    env = GymWrapper(
        suite.make(
            args.env,
            robots="Panda",
            use_camera_obs=False,
            has_offscreen_renderer=False,
            has_renderer=False,
            reward_shaping=True,
            control_freq=20,
        )
    )
    return Monitor(env) if monitored else env


# Build the log directory with os.path.join so that an empty --folder yields
# the relative path "source/logs/<env>/<seed>/" instead of an absolute
# "/source/logs/..." at the filesystem root.  The trailing "" keeps the
# trailing separator the original string had.
log_folder = os.path.join(
    args.folder or "", "source", "logs", args.env, str(args.seed), ""
)

# One-environment vectorized wrappers: a plain env for training and a
# Monitor-wrapped env for evaluation, both built from the same configuration.
vec_env = DummyVecEnv([lambda: _make_env(monitored=False)])
eval_env = DummyVecEnv([lambda: _make_env(monitored=True)])

# Report the observation and action spaces of the training environment.
print(f"State Dimension: {vec_env.observation_space}, Action Dimension: {vec_env.action_space}")

# Seed both vectorized environments and stable-baselines3's global seeding
# helper with the same user-supplied seed.
vec_env.seed(seed=args.seed)
eval_env.seed(seed=args.seed)
set_random_seed(seed=args.seed)

# Logger with stdout and CSV outputs, rooted at log_folder.
new_logger = configure(log_folder, ["stdout", "csv"])

# Evaluation callback: eval_freq=2000, 5 deterministic episodes per
# evaluation, no rendering; the best model and the evaluation log are both
# stored under log_folder.
eval_callback = EvalCallback(
    eval_env,
    best_model_save_path=log_folder,
    log_path=log_folder,
    eval_freq=2000,
    deterministic=True,
    render=False,
    n_eval_episodes=5,
)

# NOTE: no VecNormalize wrapper is applied here, so the agent is built
# directly on vec_env without observation/reward normalization.
model = SAC(
    "MlpPolicy",
    vec_env,
    verbose=1,
    device=args.device,
    seed=args.seed,
    learning_rate=3e-4,
    gamma=0.9,
)
print(model.actor)
model.set_logger(new_logger)

try:
    model.learn(total_timesteps=args.timesteps, progress_bar=True, callback=eval_callback)
finally:
    # Release both environments whether training finished or raised.
    vec_env.close()
    eval_env.close()
