# ===== add python path ===== #
import glob
import sys
import os
PATH = os.getcwd()
# Project-root discovery: walk the ancestors of the current working directory
# from the filesystem root downwards and stop at the first (shallowest)
# directory that contains a `.git_package` marker entry. That directory is put
# on sys.path so the `utils.*` imports below resolve; if no marker is found the
# CWD itself is used.
_parts = PATH.split('/')
for _depth in range(1, len(_parts) + 1):
    _candidate = '/'.join(_parts[:_depth])
    _hidden_names = {os.path.basename(entry) for entry in glob.glob(f"{_candidate}/.*")}
    if '.git_package' in _hidden_names:
        PATH = _candidate
        break
if PATH not in sys.path:
    sys.path.append(PATH)
# =========================== #

from utils.vectorize import DobroSubprocVecEnv3
from utils.normalize import RunningMeanStd
from utils.vectorize import Callback
import utils.register

from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib import TQC
import numpy as np
import argparse
import torch
import wandb
import gym

def getParser():
    """Build the command-line parser for training / evaluating the TQC agent.

    Options are grouped into common settings, environment settings, network
    settings and RL hyper-parameters. After parsing, the ``__main__`` block
    adds ``save_dir``/``save_path`` and replaces ``device`` with a
    ``torch.device``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description='RL')
    # common
    parser.add_argument('--wandb',  action='store_true', help='use wandb?')
    parser.add_argument('--test',  action='store_true', help='test or train?')
    parser.add_argument('--resume',  action='store_true', help='resume?')
    parser.add_argument('--device', type=str, default='gpu', help='gpu or cpu.')
    parser.add_argument('--name', type=str, default='RL', help='save name.')
    parser.add_argument('--save_freq', type=int, default=int(1e6), help='# of time steps for save.')
    parser.add_argument('--total_steps', type=int, default=int(5e6), help='total training steps.')
    parser.add_argument('--seed', type=int, default=1, help='seed number.')
    parser.add_argument('--gpu_idx', type=int, default=0, help='GPU index.')
    # for env
    parser.add_argument('--env_name', type=str, default='MITCheetah-v1', help='gym environment name.')
    parser.add_argument('--max_episode_steps', type=int, default=500, help='# of maximum episode steps.')
    parser.add_argument('--n_envs', type=int, default=5, help='number of parallel environments.')
    # NOTE(review): --n_steps is not read by train()/test() in this file
    # (train_freq is fixed at 100); it may be consumed via `args` by the
    # project's vec-env / Callback — confirm.
    parser.add_argument('--n_steps', type=int, default=1000, help='update after collecting n_steps.')
    parser.add_argument('--n_start_steps', type=int, default=0, help='update after start_steps.')
    # the environment receives exp(log_lamX) as lamX
    parser.add_argument('--log_lam1', type=float, default=0.0, help='log of lambda1.')
    parser.add_argument('--log_lam2', type=float, default=0.0, help='log of lambda2.')
    parser.add_argument('--log_lam3', type=float, default=0.0, help='log of lambda3.')
    # for networks
    parser.add_argument('--activation', type=str, default='ReLU', help='activation function. ReLU, Tanh, Sigmoid...')
    parser.add_argument('--hidden_dim', type=int, default=512, help='the number of hidden layer\'s node.')
    parser.add_argument('--log_std_init', type=float, default=-3.0, help='log of initial std.')
    parser.add_argument('--len_replay_buffer', type=int, default=int(1e6), help='length of replay buffer.')
    # for RL
    parser.add_argument('--discount_factor', type=float, default=0.99, help='discount factor.')
    parser.add_argument('--lr', type=float, default=3e-4, help='learning rate.')
    # passed to TQC as `ent_coef`
    parser.add_argument('--ent_coeff', type=float, default=0.001, help='entropy coefficient.')
    return parser


def train(args):
    """Train a TQC agent and save it to ``args.save_path``.

    The function:

    1. Builds ``args.n_envs`` environments inside a ``DobroSubprocVecEnv3``
       (``spawn`` start method).
    2. Records env-derived values on ``args``.
    3. Either resumes from ``<args.save_path>.zip`` (only with ``--resume`` and
       when that file exists) or creates a fresh model.
    4. Optionally starts a wandb run that syncs TensorBoard logs.
    5. Runs ``model.learn`` for ``args.total_steps`` steps with the project
       ``Callback``, then saves the model.

    Args:
        args: namespace from ``getParser().parse_args()``. The ``__main__``
            block has already filled in ``save_dir``, ``save_path`` and
            ``device`` (a ``torch.device``). Mutated in place: ``obs_dim``,
            ``action_dim``, ``action_bound_min``, ``action_bound_max``,
            ``num_costs`` and ``setting_name`` are set here.
    """
    # define env: factory handed to make_vec_env; lamX = exp(log_lamX)
    env_id = lambda: gym.make(
        args.env_name, max_episode_length=args.max_episode_steps,
        lam1=np.exp(args.log_lam1),
        lam2=np.exp(args.log_lam2),
        lam3=np.exp(args.log_lam3),
    )
    vec_env = make_vec_env(
        env_id=env_id, n_envs=args.n_envs, seed=args.seed,
        vec_env_cls=DobroSubprocVecEnv3,
        vec_env_kwargs={'args': args, 'start_method': 'spawn'},
    )

    # set args value for env (args is also handed to the vec env and Callback)
    args.obs_dim = vec_env.observation_space.shape[0]
    args.action_dim = vec_env.action_space.shape[0]
    args.action_bound_min = vec_env.action_space.low
    args.action_bound_max = vec_env.action_space.high
    args.num_costs = vec_env.num_costs
    args.setting_name = 'TQC'

    # Policy architecture. It is defined before branching because both the
    # resume path and the fresh-model path pass it to TQC.
    # - `getattr` rather than `eval`: --activation comes from the command line.
    # - `use_sde=False` matches the kwargs used by `test`. SB3 appears to keep
    #   use_sde inside the saved policy_kwargs and `TQC.load` rejects
    #   policy_kwargs that differ from the stored ones (TODO confirm for the
    #   installed SB3 version).
    policy_kwargs = dict(
        activation_fn=getattr(torch.nn, args.activation),
        net_arch=[args.hidden_dim] * 2,
        log_std_init=args.log_std_init,
        use_sde=False,
    )

    # declare agent
    if args.resume and os.path.exists(f"{args.save_path}.zip"):
        # load
        model = TQC.load(
            args.save_path,
            env=vec_env,
            seed=args.seed,
            gamma=args.discount_factor,
            tensorboard_log=args.save_dir,
            device=args.device,
            policy="MlpPolicy",
            policy_kwargs=policy_kwargs,
            learning_rate=args.lr,
            learning_starts=args.n_start_steps,
            buffer_size=args.len_replay_buffer,
            ent_coef=args.ent_coeff,
            verbose=1,
        )
    else:
        # define agent
        model = TQC(
            env=vec_env,
            seed=args.seed,
            gamma=args.discount_factor,
            tensorboard_log=args.save_dir,
            device=args.device,
            policy="MlpPolicy",
            policy_kwargs=policy_kwargs,
            learning_rate=args.lr,
            learning_starts=args.n_start_steps,
            buffer_size=args.len_replay_buffer,
            ent_coef=args.ent_coeff,
            # ent_coef=f"auto_{args.ent_coeff}",
            target_entropy=-1*args.action_dim,
            train_freq=100,
            gradient_steps=100,
            verbose=1,
        )

        # to force log_std_init: re-initialise the actor's log-std head with
        # N(0, 0.01) weights and a bias drawn around args.log_std_init
        def initWeights(m):
            if isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.normal_(args.log_std_init, 0.01)
        model.actor.log_std.apply(initWeights)

    # wandb
    if args.wandb:
        project_name = f'[Constrained RL] {args.setting_name}'
        wandb.init(
            project=project_name,
            config=args,
            sync_tensorboard=True,
        )
        # keep wandb's numeric run suffix but replace the prefix with args.name
        run_idx = wandb.run.name.split('-')[-1]
        wandb.run.name = f"{args.name}-{run_idx}"

    # training
    # NOTE(review): `learn` uses SB3's default reset_num_timesteps, so a resumed
    # run restarts its timestep counter — confirm this is intended.
    callback = Callback(args)
    model.learn(total_timesteps=args.total_steps, callback=callback)

    # save
    model.save(args.save_path)


def test(args):
    """Roll out a TQC policy for 10 rendered episodes and print episode stats.

    The model is loaded from ``<args.save_path>.zip`` only when ``--resume`` is
    set and that file exists; otherwise a freshly initialised TQC model is
    evaluated. Observations are normalised with a ``RunningMeanStd`` built
    from ``args.save_dir`` (presumably it loads the statistics saved during
    training — confirm). Each episode sets ``cmd_lin_vel`` to ``[1, 0, 0]``
    and ``cmd_ang_vel`` to zeros on the unwrapped env. The printed "score"
    is the sum of ``info['costs'][3]`` over the episode, not the reward.

    Args:
        args: namespace from ``getParser().parse_args()`` with ``save_dir``,
            ``save_path`` and ``device`` filled in by the ``__main__`` block.
            Mutated in place with env-derived values (``obs_dim`` etc.).
    """
    # define env (single instance so it can be rendered); lamX = exp(log_lamX)
    env = gym.make(
        args.env_name, max_episode_length=args.max_episode_steps,
        lam1=np.exp(args.log_lam1),
        lam2=np.exp(args.log_lam2),
        lam3=np.exp(args.log_lam3),
    )
    obs_rms = RunningMeanStd(args.save_dir, env.observation_space.shape[0])

    # set args value for env
    args.obs_dim = env.observation_space.shape[0]
    args.action_dim = env.action_space.shape[0]
    args.action_bound_min = env.action_space.low
    args.action_bound_max = env.action_space.high
    args.num_costs = env.num_costs
    args.setting_name = 'TQC'

    # `getattr` rather than `eval`: --activation comes from the command line.
    policy_kwargs = dict(
        activation_fn=getattr(torch.nn, args.activation),
        net_arch=[args.hidden_dim] * 2,
        log_std_init=args.log_std_init,
        use_sde=False,
    )
    if args.resume and os.path.exists(f"{args.save_path}.zip"):
        # load
        model = TQC.load(
            args.save_path,
            env=env,
            seed=args.seed,
            gamma=args.discount_factor,
            tensorboard_log=args.save_dir,
            device=args.device,
            policy="MlpPolicy",
            policy_kwargs=policy_kwargs,
            learning_rate=args.lr,
            learning_starts=args.n_start_steps,
            buffer_size=args.len_replay_buffer,
            ent_coef=args.ent_coeff,
            verbose=1,
        )
    else:
        # define agent
        # NOTE(review): without --resume this evaluates an untrained policy,
        # even if a saved model exists — confirm this is intended.
        model = TQC(
            env=env,
            seed=args.seed,
            gamma=args.discount_factor,
            tensorboard_log=args.save_dir,
            device=args.device,
            policy="MlpPolicy",
            policy_kwargs=policy_kwargs,
            learning_rate=args.lr,
            learning_starts=args.n_start_steps,
            buffer_size=args.len_replay_buffer,
            ent_coef=args.ent_coeff,
            # ent_coef=f"auto_{args.ent_coeff}",
            target_entropy=-1*args.action_dim,
            train_freq=100,
            gradient_steps=100,
            verbose=1,
        )

        # to force log_std_init: every Linear layer of the actor gets N(0, 0.01)
        # weights. Biases are drawn around 0, except in the log-std head, where
        # they are drawn around args.log_std_init.
        def initWeights(m, init_bias=0.0):
            if isinstance(m, torch.nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.normal_(init_bias, 0.01)

        # Match the log-std head by identity, not by its position among the
        # actor's children, so a change in submodule order cannot silently
        # put the bias on the wrong layer.
        log_std_head = model.actor.log_std
        for module in model.actor.children():
            if module is log_std_head:
                print("init bias:", args.log_std_init)
                initializer = lambda m: initWeights(m, init_bias=args.log_std_init)
            else:
                initializer = initWeights
            module.apply(initializer)

    try:
        for _ in range(10):
            obs = env.reset()
            obs = obs_rms.normalize(obs)

            # fixed evaluation command
            env.unwrapped.cmd_lin_vel = np.array([1.0, 0.0, 0.0])
            env.unwrapped.cmd_ang_vel = np.array([0.0, 0.0, 0.0])

            score = 0.0
            cnt = 0
            while True:
                cnt += 1
                action, _ = model.predict(obs, deterministic=True)
                obs, _, dones, info = env.step(action)

                # "score" accumulates cost index 3, not the reward
                score += info['costs'][3]
                obs = obs_rms.normalize(obs)
                env.render()
                if dones:
                    break

            print(f"score: {score:.3f}\tep_len: {cnt}.")
    finally:
        # release the env (and whatever render() opened) even if a rollout fails
        env.close()


if __name__ == "__main__":
    parser = getParser()
    args = parser.parse_args()

    # ==== processing args ==== #
    # Every artifact of a run lives under results/<name>_s<seed>. train/test
    # look for the model at "<save_path>.zip".
    args.save_dir = f"results/{args.name}_s{args.seed}"
    args.save_path = f"{args.save_dir}/model"

    # Device: CUDA only when it is available AND `--device gpu` was requested;
    # anything else falls back to the CPU. args.device becomes a torch.device.
    use_cuda = torch.cuda.is_available() and args.device == 'gpu'
    device = torch.device(f'cuda:{args.gpu_idx}' if use_cuda else 'cpu')
    print('[torch] cuda is used.' if use_cuda else '[torch] cpu is used.')
    args.device = device
    # ========================= #

    entry_point = test if args.test else train
    entry_point(args)
