import os
import os.path as osp
import gym
import numpy as np
import torch
from large_rl.commons.args import get_all_args
from large_rl.commons.seeds import set_randomSeed
from large_rl.commons.utils import VideoFrameBuffer, logging, save_mp4

from sac_dir.stable_baselines3 import SAC
from sac_dir.stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv

# Wrapper that gives an environment a `render_mode="rgb_array"` attribute,
# which Stable-Baselines3 requires in order to render videos.
class RenderModeWrapper(gym.Wrapper):
    """Expose a ``render_mode`` attribute on an otherwise unchanged env.

    Stable-Baselines3's video rendering looks for ``render_mode`` on the
    environment; this wrapper only attaches that attribute and delegates
    everything else to the wrapped env via ``gym.Wrapper``.

    Args:
        env: the environment to wrap.
        render_mode: value stored on ``self.render_mode``
            (defaults to ``"rgb_array"``).
    """

    def __init__(self, env, render_mode="rgb_array"):
        gym.Wrapper.__init__(self, env)
        self.render_mode = render_mode


def create_vector_environment(args: dict):
    if args["env_name"] == "mujoco-reacher":
        from large_rl.envs.reacher.reacher import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-ant":
        from large_rl.envs.reacher.ant_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-half_cheetah":
        from large_rl.envs.reacher.half_cheetah_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-hopper":
        from large_rl.envs.reacher.hopper_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-humanoid":
        from large_rl.envs.reacher.humanoid_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-humanoidstandup":
        from large_rl.envs.reacher.humanoidstandup_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-inverted_double_pendulum":
        from large_rl.envs.reacher.inverted_double_pendulum_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-inverted_pendulum":
        from large_rl.envs.reacher.inverted_pendulum_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-pusher":
        from large_rl.envs.reacher.pusher_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-swimmer":
        from large_rl.envs.reacher.swimmer_v4 import create_environment as _fn
        create_environment = _fn
    elif args["env_name"] == "mujoco-walker2d":
        from large_rl.envs.reacher.walker2d_v4 import create_environment as _fn
        create_environment = _fn
    else: raise ValueError

    create_environment_r = lambda: RenderModeWrapper(create_environment(args=args, seed=None), render_mode="rgb_array")

    train_env = DummyVecEnv([create_environment_r for _ in range(args["num_envs"])])
    # eval_env = DummyVecEnv([create_environment_r for _ in range(args["num_envs"])])
    eval_env = create_environment_r()
    train_env.seed(args["seed"])
    eval_env.seed(args["seed"])
    return train_env, eval_env

def train(model, train_step):
    """Continue training ``model`` for ``train_step`` additional timesteps.

    ``reset_num_timesteps=False`` is passed so that, per Stable-Baselines3's
    ``learn`` semantics, successive calls resume the model's timestep counter
    instead of restarting it.

    Returns:
        The same ``model`` object that was passed in.
    """
    learn_kwargs = dict(total_timesteps=train_step, reset_num_timesteps=False)
    model.learn(**learn_kwargs)
    return model

def eval(model, eval_env, current_step, args, wlogger=None):
    """Run deterministic evaluation episodes and report the mean return.

    Args:
        model: agent exposing ``predict(obs, deterministic=...)`` that returns
            ``(action, state)``.
        eval_env: a single (non-vectorized) environment whose ``step`` returns
            ``(obs, reward, done, info)``.
        current_step: training step count, used to name the saved video.
        args: config dict. Reads ``reacher_save_video`` and
            ``eval_num_episodes``; when a video is saved it also reads
            ``video_saving_dir`` and ``vid_fps``.
        wlogger: optional experiment logger, kept for interface compatibility.
            NOTE(review): its logging API is not visible here, so metrics are
            not forwarded to it.

    Returns:
        Mean undiscounted episode return, or ``nan`` when no episodes ran.
    """
    record_video = bool(args["reacher_save_video"])
    frame_buffer = []
    eval_ep_rewards = []
    for eval_epoch in range(args["eval_num_episodes"]):
        # Only the first episode is recorded.
        record_this_episode = record_video and eval_epoch == 0
        eval_ep_reward = 0.0
        done = False
        obs = eval_env.reset()
        while not done:
            action, _states = model.predict(obs, deterministic=True)
            obs, reward, done, _info = eval_env.step(action)
            if record_this_episode:
                frame_buffer.append(eval_env.render(mode="rgb_array"))
            eval_ep_reward += reward
        eval_ep_rewards.append(eval_ep_reward)

    # Avoid np.mean([]) (RuntimeWarning + nan) when no episodes were requested.
    mean_return = float(np.mean(eval_ep_rewards)) if eval_ep_rewards else float("nan")
    print("eval step {}: mean return {:.3f} over {} episode(s)".format(
        current_step, mean_return, len(eval_ep_rewards)))

    # Saving is driven by the video flag alone; it used to be nested under
    # `wlogger is not None`, which silently discarded recorded frames.
    if frame_buffer:
        save_name = '%s_%s' % (str(current_step), 'eval')
        save_dir = os.path.join(args['video_saving_dir'], 'test')
        save_mp4(frame_buffer, save_dir, save_name, fps=args['vid_fps'], no_frame_drop=True)
        saved_file_name = '%s.mp4' % osp.join(save_dir, save_name)
        print('Rendered frames to %s' % saved_file_name)

    return mean_return


if __name__ == "__main__":
    args = vars(get_all_args())

    # Reproducibility: disable cuDNN autotuning / nondeterministic kernels.
    # startswith() also covers device strings such as "cuda:0".
    if str(args["device"]).startswith("cuda"):
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.enabled = False
    set_randomSeed(seed=args["seed"])

    wlogger = None

    # Directory where evaluation videos are written.
    if args["save_dir"] != "":
        args['video_saving_dir'] = os.path.join(args['reacher_video_dir'], args["save_dir"])
    else:
        args['video_saving_dir'] = os.path.join(args['reacher_video_dir'],
                                                'seed{:05d}debug2022'.format(args["seed"]))

    # Derive the training schedule before building anything expensive.
    total_step = args["total_ts"]
    if args["per_train_ts"] is not None:
        args["num_epochs"] = (args["total_ts"] // args["per_train_ts"]) // args["num_envs"]
    else:
        args["per_train_ts"] = (args["total_ts"] // args["num_epochs"]) // args["num_envs"]
    # Timesteps trained between consecutive evaluations; `step` advances by it.
    eval_freq = args["eval_freq"] * args["num_envs"] * args["per_train_ts"]
    print("eval_freq is {}".format(eval_freq))
    if eval_freq <= 0:
        raise ValueError(
            "Computed eval_freq={} must be positive; check eval_freq, num_envs, "
            "per_train_ts / num_epochs and total_ts".format(eval_freq))

    train_env, eval_env = create_vector_environment(args)
    try:
        model = SAC("MlpPolicy", train_env, verbose=0, wlogger=wlogger, device=args["device"], gradient_steps=-1)

        step = 0
        while step < total_step:
            logging("start step:{}".format(step))
            # Each iteration trains exactly eval_freq timesteps, so evaluate
            # at the start of every iteration.
            eval(model, eval_env=eval_env, args=args, current_step=step, wlogger=wlogger)
            logging("eval step:{} Finished".format(step))
            model = train(model, eval_freq)
            step += eval_freq

        # Evaluate the final model, which the loop above never sees.
        eval(model, eval_env=eval_env, args=args, current_step=step, wlogger=wlogger)
        logging("eval step:{} Finished".format(step))
    finally:
        train_env.close()
        eval_env.close()