# ===== add python path ===== #
import glob
import sys
import os
# Walk from the filesystem root down toward the current working directory and
# take the first (i.e. outermost) ancestor that contains a `.git_package`
# marker as the project root; if none is found, keep the working directory.
# That root is appended to sys.path so the project-local `utils.*` imports
# below resolve no matter which sub-directory the script is launched from.
PATH = os.getcwd()
_path_parts = PATH.split('/')
for _depth, _ in enumerate(_path_parts, start=1):
    dir_path = '/'.join(_path_parts[:_depth])
    # Build the candidate with an f-string (not os.path.join) so the empty
    # root prefix maps to "/.git_package". A direct existence check avoids
    # glob metacharacters ([, *, ?) in directory names breaking the lookup;
    # lexists (not exists) also accepts a dangling symlink, as a directory
    # listing would.
    if os.path.lexists(f"{dir_path}/.git_package"):
        PATH = dir_path
        break
if PATH not in sys.path:
    sys.path.append(PATH)
# =========================== #

from utils.vectorize import DobroSubprocVecEnv3
from utils.normalize import RunningMeanStd
from utils.vectorize import Callback
import utils.register

from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO
import numpy as np
import argparse
import torch
import wandb
import gym

def getParser():
    """Build the command-line parser for PPO training / evaluation.

    Returns:
        argparse.ArgumentParser: parser whose options are grouped as common
        run settings, environment settings, network settings, RL
        hyper-parameters and trust-region settings.
    """
    parser = argparse.ArgumentParser(description='RL')
    # common
    parser.add_argument('--wandb',  action='store_true', help='use wandb?')
    parser.add_argument('--test',  action='store_true', help='test or train?')
    parser.add_argument('--resume',  action='store_true', help='resume?')
    parser.add_argument('--device', type=str, default='gpu', help='gpu or cpu.')
    parser.add_argument('--name', type=str, default='RL', help='save name.')
    parser.add_argument('--save_freq', type=int, default=int(1e6), help='# of time steps for save.')
    parser.add_argument('--total_steps', type=int, default=int(5e6), help='total training steps.')
    parser.add_argument('--seed', type=int, default=1, help='seed number.')
    parser.add_argument('--gpu_idx', type=int, default=0, help='GPU index.')
    # for env
    parser.add_argument('--env_name', type=str, default='MITCheetah-v1', help='gym environment name.')
    parser.add_argument('--max_episode_steps', type=int, default=500, help='# of maximum episode steps.')
    # Help text was a copy-paste of --env_name; this is the number of envs.
    parser.add_argument('--n_envs', type=int, default=5, help='# of parallel environments.')
    parser.add_argument('--n_steps', type=int, default=1000, help='update after collecting n_steps.')
    # The lambdas are given in log space; callers exponentiate them.
    parser.add_argument('--log_lam1', type=float, default=0.0, help='log of lambda1.')
    parser.add_argument('--log_lam2', type=float, default=0.0, help='log of lambda2.')
    parser.add_argument('--log_lam3', type=float, default=0.0, help='log of lambda3.')
    # for networks
    parser.add_argument('--activation', type=str, default='ReLU', help='activation function. ReLU, Tanh, Sigmoid...')
    parser.add_argument('--hidden_dim', type=int, default=512, help='the number of hidden layer\'s node.')
    parser.add_argument('--log_std_init', type=float, default=-3.0, help='log of initial std.')
    # for RL
    parser.add_argument('--discount_factor', type=float, default=0.99, help='discount factor.')
    parser.add_argument('--lr', type=float, default=3e-5, help='learning rate.')
    parser.add_argument('--n_epochs', type=int, default=40, help='update epochs.')
    parser.add_argument('--batch_size', type=int, default=1000, help='batch size.')
    parser.add_argument('--gae_coeff', type=float, default=0.97, help='gae coefficient.')
    # Help text was a copy-paste of --gae_coeff; this weights the entropy bonus.
    parser.add_argument('--ent_coeff', type=float, default=0.001, help='entropy coefficient.')
    parser.add_argument('--vf_coeff', type=float, default=0.5, help='loss = policy + value*vf_coeff.')
    # trust region
    parser.add_argument('--max_kl', type=float, default=0.01, help='maximum kl divergence.')
    parser.add_argument('--clip_value', type=float, default=0.05, help='clip value.')
    return parser


def train(args):
    """Train a PPO agent on ``args.env_name`` using a vectorized environment.

    Builds ``args.n_envs`` environments through ``make_vec_env`` with
    ``DobroSubprocVecEnv3``. If ``--resume`` is set and
    ``<args.save_path>.zip`` exists, the model is loaded from it; otherwise a
    fresh PPO model is created. Optionally starts a wandb run, trains for
    ``args.total_steps`` timesteps and saves the final model to
    ``args.save_path``. The vectorized environment is always closed on exit.

    Args:
        args: parsed CLI namespace, post-processed by the ``__main__`` block.
            It must carry ``save_dir``, ``save_path`` and ``device`` (a
            ``torch.device``). This function also writes env-derived fields
            onto it: ``obs_dim``, ``action_dim``, ``action_bound_min``,
            ``action_bound_max``, ``num_costs`` and ``setting_name``.
    """
    # define env: zero-argument factory handed to make_vec_env. The lambdas
    # are given on the CLI in log space, so exponentiate them here.
    def env_id():
        return gym.make(
            args.env_name, max_episode_length=args.max_episode_steps,
            lam1=np.exp(args.log_lam1),
            lam2=np.exp(args.log_lam2),
            lam3=np.exp(args.log_lam3),
        )

    vec_env = make_vec_env(
        env_id=env_id, n_envs=args.n_envs, seed=args.seed,
        vec_env_cls=DobroSubprocVecEnv3,
        vec_env_kwargs={'args': args, 'start_method': 'spawn'},
    )

    try:
        # set args value for env
        args.obs_dim = vec_env.observation_space.shape[0]
        args.action_dim = vec_env.action_space.shape[0]
        args.action_bound_min = vec_env.action_space.low
        args.action_bound_max = vec_env.action_space.high
        args.num_costs = vec_env.num_costs
        args.setting_name = 'PPO'

        # Hyper-parameters shared by the resume and fresh-start paths.
        ppo_kwargs = dict(
            learning_rate=args.lr,
            n_steps=args.n_steps, batch_size=args.batch_size, n_epochs=args.n_epochs,
            gamma=args.discount_factor, gae_lambda=args.gae_coeff, clip_range=args.clip_value,
            vf_coef=args.vf_coeff, target_kl=args.max_kl, ent_coef=args.ent_coeff,
            tensorboard_log=args.save_dir,
            verbose=1,
        )

        if args.resume and os.path.exists(f"{args.save_path}.zip"):
            # load: extra kwargs passed to PPO.load override the values stored
            # in the checkpoint (per the SB3 load() documentation).
            # Previously this read `args.torch_device`, which is never set.
            model = PPO.load(
                args.save_path, env=vec_env, device=args.device, **ppo_kwargs
            )
        else:
            # define agent. getattr (instead of eval) resolves the activation
            # class without evaluating arbitrary CLI text as code.
            policy_kwargs = dict(
                activation_fn=getattr(torch.nn, args.activation),
                net_arch=[dict(pi=[args.hidden_dim]*2, vf=[args.hidden_dim]*2)],
                log_std_init=args.log_std_init,
            )
            # Use args.device rather than the `device` global, which only
            # exists when this file is run as a script.
            model = PPO(
                policy='MlpPolicy', env=vec_env, policy_kwargs=policy_kwargs,
                device=args.device, **ppo_kwargs
            )

        # wandb
        if args.wandb:
            project_name = f'[Constrained RL] {args.setting_name}'
            wandb.init(
                project=project_name,
                config=args,
                sync_tensorboard=True,
            )
            # Keep wandb's numeric run suffix but prefix it with our run name.
            run_idx = wandb.run.name.split('-')[-1]
            wandb.run.name = f"{args.name}-{run_idx}"

        # training
        callback = Callback(args)
        model.learn(total_timesteps=args.total_steps, callback=callback)

        # save
        model.save(args.save_path)
    finally:
        # Shut down the worker processes even if training raised.
        vec_env.close()


def test(args):
    """Roll out a PPO policy for 10 rendered episodes and print per-episode stats.

    Creates a single (non-vectorized) environment and builds a
    ``RunningMeanStd`` from ``args.save_dir`` to normalize observations. If
    ``--resume`` is set and ``<args.save_path>.zip`` exists, the policy is
    loaded from it; otherwise a fresh PPO model is created. Each episode
    prints the accumulated score and the episode length. The environment is
    always closed on exit.

    Args:
        args: parsed CLI namespace, post-processed by the ``__main__`` block.
            It must carry ``save_dir``, ``save_path`` and ``device``.
    """
    # define env (lambdas are given on the CLI in log space)
    env = gym.make(
        args.env_name, max_episode_length=args.max_episode_steps,
        lam1=np.exp(args.log_lam1),
        lam2=np.exp(args.log_lam2),
        lam3=np.exp(args.log_lam3),
    )
    try:
        # presumably restores observation statistics saved during training
        # under save_dir -- TODO confirm against utils.normalize.
        obs_rms = RunningMeanStd(args.save_dir, env.observation_space.shape[0])

        if args.resume and os.path.exists(f"{args.save_path}.zip"):
            model = PPO.load(
                args.save_path, env=env, device=args.device,
                tensorboard_log=args.save_dir,
                learning_rate=args.lr,
                n_steps=args.n_steps, batch_size=args.batch_size, n_epochs=args.n_epochs,
                gamma=args.discount_factor, gae_lambda=args.gae_coeff, clip_range=args.clip_value,
                vf_coef=args.vf_coeff, target_kl=args.max_kl, ent_coef=args.ent_coeff,
                verbose=1
            )
        else:
            # NOTE(review): without --resume this evaluates a freshly
            # initialized (untrained) policy -- confirm this is intended.
            # getattr (instead of eval) resolves the activation class safely.
            policy_kwargs = dict(
                activation_fn=getattr(torch.nn, args.activation),
                net_arch=[dict(pi=[args.hidden_dim]*2, vf=[args.hidden_dim]*2)],
                log_std_init=args.log_std_init,
            )
            # Use args.device rather than the `device` global, which only
            # exists when this file is run as a script.
            model = PPO(
                policy='MlpPolicy', env=env, policy_kwargs=policy_kwargs, learning_rate=args.lr,
                n_steps=args.n_steps, batch_size=args.batch_size, n_epochs=args.n_epochs,
                gamma=args.discount_factor, gae_lambda=args.gae_coeff, clip_range=args.clip_value,
                vf_coef=args.vf_coeff, device=args.device, target_kl=args.max_kl,
                ent_coef=args.ent_coeff,
                verbose=1
            )

        for _ in range(10):
            obs = obs_rms.normalize(env.reset())

            score = 0.0
            cnt = 0
            while True:
                cnt += 1
                action, _ = model.predict(obs, deterministic=True)
                obs, _, done, info = env.step(action)

                # The score accumulates entry 3 of info['costs'], not the env
                # reward. NOTE(review): confirm which cost index 3 denotes.
                score += info['costs'][3]
                obs = obs_rms.normalize(obs)
                env.render()
                if done:
                    break

            print(f"score: {score:.3f}\tep_len: {cnt}.")
    finally:
        env.close()


if __name__ == "__main__":
    parser = getParser()
    args = parser.parse_args()

    # ==== processing args ==== #
    # Every artifact of a run lives under results/<name>_s<seed>; the model
    # checkpoint path is <save_dir>/model.
    args.save_dir = f"results/{args.name}_s{args.seed}"
    args.save_path = f"{args.save_dir}/model"

    # Turn the requested device string into a torch.device. Falls back to the
    # CPU whenever CUDA is unavailable, even if 'gpu' was requested.
    # `device` is intentionally kept as a module-level global as well.
    use_cuda = torch.cuda.is_available() and args.device == 'gpu'
    device = torch.device(f'cuda:{args.gpu_idx}' if use_cuda else 'cpu')
    print('[torch] cuda is used.' if use_cuda else '[torch] cpu is used.')
    args.device = device
    # ========================= #

    # Dispatch to evaluation or training.
    run = test if args.test else train
    run(args)
