import os
import sys

from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TypeVar
import mo_gymnasium as mo_gym
from stable_baselines3.dqn.dqn import DQN, MaxminMFQ    # DQN: just for baseline, not used
from stable_baselines3.ppo.ppo import MaxminPPO, PPO
from stable_baselines3.sac.sac import MaxminSAC
from stable_baselines3.common.vec_env import SubprocVecEnv

import pdb
import wandb
import random
import argparse
import numpy as np

if __name__ == "__main__":
    prs = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter,
                                  description="""Multi-objective maxmin RL on SUMO traffic environments""")
    ### 1. Environment parameters
    # Mandatory: without it env_name would be None, wandb would be initialized
    # under project "Mo_None", and the run would only fail later at env creation.
    prs.add_argument("-ename", dest="env_name", type=str, required=True,
                     choices=['traffic-big', 'traffic-asym', 'traffic-asym-4000'],
                     help="Env Name. No Default (mandatory) \n")

    # 16 for single-intersection, lane-wise reward
    prs.add_argument("-rd", dest="reward_dim", type=int, default=16,
                     help="Reward dimension (number of objectives)\n")
    prs.add_argument("-rdp", dest="r_dim_policy", type=int, default=1,
                     help="equals reward_dim if Naive Maxmin DQN, 1 otherwise\n")

    prs.add_argument("-tt", dest="total_timesteps", type=int, default=100000,
                     help="Total Timesteps. We have total episodes of total_timesteps/(num_seconds/delta_time).\n")

    prs.add_argument("-bf", dest="buffer_size", type=int, default=50000, help="Buffer size\n")  # not needed for ppo
    prs.add_argument("-se", dest="seed", type=int, default=0, help="Random seed\n")

    ### 2. Common algorithm parameters
    prs.add_argument("-o", dest="ours", type=str, default="maxmin_ppo",
                     choices=["ppo_baseline", "baseline", "maxmin_sql", "maxmin_ppo"],
                     help="Choose between ppo_baseline, baseline, maxmin_sql and maxmin_ppo.\n")

    ## Main Q learning rate, or PPO learning rate
    prs.add_argument("-mlr", dest="main_learning_rate", type=float, default=0.003,
                     help="Learning Rate of Main Q function\n")

    ## averaging window size
    prs.add_argument("-avwin", dest="stats_window_size", type=int, default=32,
                     help="The number of episodes to average\n")

    ## soft target update to incorporate updated w information
    prs.add_argument("-tgit", dest="target_update_interval", type=int, default=1, help="Target_update_interval\n")
    prs.add_argument("-tau", dest="tau", type=float, default=0.001, help="Soft Target update ratio\n")

    ## exploration for Q-function
    prs.add_argument("-epinit", dest="exploration_initial_eps", type=float, default=1.0,
                     help="exploration_initial_eps\n")
    prs.add_argument("-epfin", dest="exploration_final_eps", type=float, default=0.05,
                     help="exploration_final_eps\n")
    prs.add_argument("-epfr", dest="exploration_fraction", type=float, default=0.1,
                     help="exploration_fraction\n")

    ### 3. DQN Only. Double DQN argument.
    prs.add_argument("-dbq", dest="double_q", action="store_true", help="Run DDQN or not.\n")

    ### 4. SQL only.
    prs.add_argument("-al", dest="ent_alpha", type=float, default=0.01,
                     help="Entropy coefficient for training and final action selection\n")

    ## alpha scheduling
    # The help strings below are concatenated literals; the trailing space keeps
    # the two sentences from running together.
    prs.add_argument("-alin", dest="ent_alpha_act_init", type=float, default=0.5,
                     help="Entropy coefficient for initial action selection. "
                          "Less than ent_alpha.\n")
    prs.add_argument("-alann", dest="annealing_step", type=int, default=10000,
                     help="Length of Linear Entropy schedule. "
                          "Less than total timesteps.\n")

    ## 4-1. Perturbation parameters
    prs.add_argument("-init", dest="init_frac", type=float, default=0.,
                     help="Initialization ratio for Soft-Q Update\n")
    prs.add_argument("-pwlr", dest="perturb_w_learning_rate", type=float, default=0.01, help="Learning Rate of w\n")

    prs.add_argument("-perw", dest="period_cal_w_grad", type=int, default=1, help="Period of calculating w gradient\n")
    prs.add_argument("-pqlr", dest="perturb_q_learning_rate", type=float, default=0.001,
                     help="Learning Rate of each q_w\n")
    prs.add_argument("-pgst", dest="perturb_grad_step", type=int, default=1,
                     help="Number of gradient steps for perturbation\n")
    prs.add_argument("-pqnum", dest="perturb_q_copy_num", type=int, default=20,
                     help="Number of copied q networks for perturbation\n")
    prs.add_argument("-pstd", dest="perturb_std_dev", type=float, default=0.01,
                     help="Standard deviation for perturbed noise\n")

    ## Main Q grad step
    prs.add_argument("-mqst", dest="q_grad_st_after_init", type=int, default=1,
                     help="Number of gradient steps for main Q function after init state\n")
    def parse_input(arg):
        """Convert the ``-winit`` command-line value.

        A comma-separated string such as ``"0.2,0.8"`` becomes a list of
        floats. Any other string (e.g. ``'uniform'`` or ``'dirichlet'``) is
        passed through unchanged.

        Raises:
            NotImplementedError: if *arg* has no comma and is not a string.
        """
        if ',' not in arg:
            if not isinstance(arg, str):
                raise NotImplementedError
            return arg
        return list(map(float, arg.split(',')))

    prs.add_argument("-winit", dest="weight_initialize", type=parse_input, nargs='?', default='uniform', help='Initialize Weight w')

    def str2bool(value):
        """Parse a boolean command-line value.

        ``argparse``'s ``type=bool`` treats every non-empty string (including
        ``"False"``) as ``True``, so the text is interpreted explicitly here.

        Args:
            value: text such as "true"/"false", "yes"/"no", "1"/"0"
                (case-insensitive), or an actual bool.

        Returns:
            The parsed boolean.

        Raises:
            argparse.ArgumentTypeError: if the text is not a recognised boolean.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")

    ### 5. PPO only.
    prs.add_argument("-nsteps", dest="n_steps", type=int, default=2048, help="The number of steps to run for each environment per update")
    prs.add_argument("-bsppo", dest="batch_size_ppo", type=int, default=64, help="Minibatch size")
    prs.add_argument("-nepoch", dest="n_epoch_ppo", type=int, default=64, help="Number of epoch when optimizing the surrogate loss")
    prs.add_argument("-entco", dest="ent_loss_coef", type=float, default=0.1, help="entropy loss coefficient")
    prs.add_argument("-vco", dest="value_loss_coef", type=float, default=0.5, help="value loss coefficient")
    # Still takes a value (e.g. `-rdwn True` / `-rdwn False`), but now parsed correctly.
    prs.add_argument("-rdwn", dest="r_dim_wise_normalize", type=str2bool, default=False, help="dimension wise normalization for Advantage")
    prs.add_argument("-ecw", dest="ent_coef_weight", type=float, default=0.1, help="Entropy coefficient for weight")
    prs.add_argument("-nenv", dest="n_envs", type=int, default=4, help="Number of envs for vec_env")

    ### 6. for w updates
    # Both flags are `store_false`: CI / MD are ON by default and passing the
    # flag turns them OFF.
    prs.add_argument("-useci", dest="use_ci", action="store_false",
                     help="Pass to DISABLE correlation injection for w-update (enabled by default)")
    prs.add_argument("-usemd", dest="use_md", action="store_false",
                     help="Pass to DISABLE mirror descent for w-update (enabled by default)")
    prs.add_argument("-corr", dest="corr", type=str, default='inner_product', choices=['inner_product'], help="How to define correlation")
    prs.add_argument("-reg", dest="reg", type=str, default='kl2', choices=['kl2'], help="regularization function for MD and CI")
    prs.add_argument("-cic", dest="ci_coef", type=float, default=2.0, help="CI coefficient, eta")
    prs.add_argument("-mdc", dest="md_coef", type=float, default=1.0, help="MD coefficient, lambda")
    prs.add_argument("-cp", dest="corr_period", type=int, default=1, help="period for calculating correlation vector")    # calculate everytime

    ### 7. SAC only
    def str_or_float(value):
        """Parse SAC's entropy coefficient argument.

        Returns a float when *value* is numeric; otherwise returns *value*
        unchanged if it contains the substring ``"auto"`` (e.g. ``"auto"``,
        ``"auto_0.1"``).

        Raises:
            argparse.ArgumentTypeError: for any other non-numeric text.
        """
        try:
            return float(value)
        except ValueError:
            pass
        if "auto" not in value:
            raise argparse.ArgumentTypeError(f"Invalid value: {value}")
        return value
    def int_or_tuple(value):
        """Parse SAC's ``train_freq`` argument.

        Accepts either a plain integer (``"4"`` -> ``4``) or a two-field
        ``"(int, str)"`` pair (``"(1, episode)"`` -> ``(1, "episode")``);
        surrounding parentheses are optional.

        Raises:
            argparse.ArgumentTypeError: if the text is neither form.
        """
        try:
            return int(value)
        except ValueError:
            pass

        fields = value.strip('()').split(',')
        if len(fields) != 2:
            raise argparse.ArgumentTypeError("Invalid tuple format, expected (int, str)")
        count_text, unit = (field.strip() for field in fields)
        try:
            count = int(count_text)
        except ValueError:
            raise argparse.ArgumentTypeError("Invalid value, expected int or (int, str) tuple")
        return (count, unit)
    
    prs.add_argument("-tfsac", dest="train_freq_sac", type=int_or_tuple, default=1, help="train frequency")
    prs.add_argument("-gssac", dest="grad_steps_sac", type=int, default=1, help="How many gradient steps to do after each rollout")
    prs.add_argument("-ecsac", dest="ent_coef_sac", type=str_or_float, default="auto", help="entropy coefficient, corresponds to ent_alpha in SQL")
    prs.add_argument("-bssac", dest="batch_size_sac", type=int, default=32, help="Minibatch size")
    prs.add_argument("-tuisac", dest="target_update_interval_sac", type=int, default=1, help="target update interval for SAC")

    ### 8. Traffic Only Environment
    prs.add_argument("-dt", dest="delta_time", type=int, default=20, help="Action period\n")
    prs.add_argument("-yt", dest="yellow_time", type=int, default=6, help="Yellow light period\n")
    prs.add_argument("-ns", dest="num_seconds", type=int, default=100000, help="Total Seconds per Episode. "
                                                    " We have total episodes of total_timesteps/(num_seconds/delta_time)\n")
    # NOTE: reward-exponent / reward-scale-factor options (for a nonlinear,
    # rescaled reward) existed previously and are intentionally not exposed.

    ##################### 9. Others: All of these are set as default values.
    ## Choose scalarization: Maxmin Naive DQN or mean over Target. Only for Baseline DQN.
    ## For our algorithm, there is no effect.
    prs.add_argument("-scal", dest="scalarize", type=str, choices=['min', 'mean'],
                     default='min', help="Choose scalarization for the baseline DQN\n")

    ## Weight decay. set as zero
    prs.add_argument("-wd", dest="weight_decay", type=float, default=0,
                     help="Weight for L2 regularization in Adam optimizer\n")
    ## New version of us. Only applicable for SQLPolicy.
    prs.add_argument("-expw", dest="explicit_w_input", action="store_true", help="Give preference as an input or not.\n")

    ## Option for w scheduling. For now, we set this as 'sqrt_inverse'
    prs.add_argument("-wsch", dest="w_schedule_option", type=str, choices=['sqrt_inverse', 'inverse', 'linear'],
                     default='sqrt_inverse', help="Option for w scheduling\n")

    #### New environment: Species control problem (not consumed elsewhere in this script)
    prs.add_argument("-fr", dest="ifr", type=int, default=2, help="Functional Response for SC\n")
    prs.add_argument("-fnum", dest="ifrnum", type=int, default=2, help="Functional Response Num for SC\n")
    prs.add_argument("-epl", dest="episode_length", type=int, default=5000, help="Episode Length \n")

    args = prs.parse_args()

    r_dim = args.reward_dim
    if r_dim <= 1:
        # Explicit check rather than `assert`, which is stripped under `python -O`;
        # prs.error reports a usage error and exits.
        prs.error(f"reward_dim (-rd) must be greater than 1, got {r_dim}")

    ours = args.ours
    print("algorithm: ", ours)

    env_name = args.env_name
    print("env_name: ", env_name)

    n_envs = args.n_envs

    # ## wandb initialize
    # Runs are grouped by the hyper-parameters that identify a configuration,
    # so different seeds of one configuration share a group; each run is then
    # named after its seed. Only maxmin_sql uses a different group string for
    # traffic vs. non-traffic envs; the other algorithms use the same one.
    is_traffic_env = env_name in ('traffic-big', 'traffic-asym', 'traffic-asym-4000')

    if ours == 'maxmin_ppo':
        # str(bool)[0] abbreviates use_ci / use_md to "T" / "F".
        group = (f"{args.ours}_{str(args.use_ci)[0]}_{str(args.use_md)[0]}"
                 f"_{args.corr}_{args.ci_coef}_{args.md_coef}"
                 f"_nepoch={args.n_epoch_ppo}"
                 f"_nsteps={args.n_steps}"
                 f"_bsppo={args.batch_size_ppo}"
                 f"_mlr={args.main_learning_rate}")
    elif ours == 'maxmin_sql':
        # weight_initialize[0] is the first character of a string option
        # (e.g. 'u' for 'uniform') or the first float of an explicit w.
        group = f"{args.weight_initialize[0]}_algo_{args.ours}"
        if is_traffic_env:
            group += (f"_tt={args.total_timesteps}"
                      f"_init={args.init_frac}"
                      f"_mlr={args.main_learning_rate}"
                      f"_alin={args.ent_alpha_act_init}"
                      f"_al={args.ent_alpha}"
                      f"_pqlr={args.perturb_q_learning_rate}"
                      f"_pwlr={args.perturb_w_learning_rate}"
                      f"_mqst={args.q_grad_st_after_init}"
                      f"_tau={args.tau}"
                      f"_pstd={args.perturb_std_dev}"
                      f"_pgst={args.perturb_grad_step}")
        else:
            group += f"_mlr={args.main_learning_rate}_pgst={args.perturb_grad_step}"
    elif ours == 'ppo_baseline':
        # NOTE: only the first two digits of total_timesteps are kept, so e.g.
        # 100000 and 1000000 both map to "_tt=10".
        group = f"tune_{args.ours}_tt={str(args.total_timesteps)[:2]}"
    elif ours == 'baseline':
        group = (f"{args.weight_initialize[0]}_algo_{args.ours}_rdp_{args.r_dim_policy}"
                 f"_tt={args.total_timesteps}")
    else:
        raise NotImplementedError

    wandb.init(project="Mo_" + str(env_name), group=group, job_type="train")

    wandb.run.name = "seed=" + str(args.seed)

    # random seed ## Already in set_random_seed in utils.py, but set randomness fixed in env
    random.seed(args.seed)
    np.random.seed(args.seed)

    # define env
    # All traffic scenarios share the big-intersection network and differ only
    # in the route (demand) file and, for 'traffic-asym-4000', the reward fn.
    traffic_scenarios = {
        'traffic-big': {            # Base-4
            'route_file': "nets/big-intersection/routes.rou_asym_10000_long.xml",
        },
        'traffic-asym': {           # Asym-4: same network as baseline, but car generation and turn prob are asymmetric
            'route_file': "nets/big-intersection/routes.rou_2p0s_4000.xml",
        },
        'traffic-asym-4000': {      # Asym-16: same asymmetric routes, lane-wise waiting-time reward
            'route_file': "nets/big-intersection/routes.rou_2p0s_4000.xml",
            'reward_fn': "waiting-time-lane",
        },
    }
    if env_name not in traffic_scenarios:
        raise NotImplementedError

    # SUMO's python tools (traci, ...) live under $SUMO_HOME/tools.
    if "SUMO_HOME" in os.environ:
        tools = os.path.join(os.environ["SUMO_HOME"], "tools")
        sys.path.append(tools)
    else:
        sys.exit("Please declare the environment variable 'SUMO_HOME'")
    # Not referenced directly; presumably kept so a broken SUMO install fails
    # fast here rather than inside the env. TODO(review): confirm it is needed.
    import traci  # noqa: F401

    from sumo_rl.environment.env import SumoEnvironment

    # Keyword arguments common to every env instance of the chosen scenario.
    scenario_kwargs = dict(
        net_file="nets/big-intersection/big-intersection.net.xml",
        single_agent=True,
        use_gui=False,
        delta_time=args.delta_time,
        yellow_time=args.yellow_time,
        num_seconds=args.num_seconds,
        **traffic_scenarios[env_name],
    )

    if ours == 'maxmin_ppo' or ours == 'ppo_baseline':    # use vec_env
        def make_env(seed_offset):
            """Return a thunk building one env whose SUMO seed is offset by *seed_offset*."""
            def _init():
                return SumoEnvironment(**scenario_kwargs, sumo_seed=args.seed + seed_offset)
            return _init
        # Each worker gets a distinct SUMO seed (args.seed, args.seed + 1, ...).
        env = SubprocVecEnv([make_env(i) for i in range(n_envs)])
    else:
        env = SumoEnvironment(**scenario_kwargs, sumo_seed=args.seed)

    # Build the selected algorithm and train it. Hyper-parameters that are not
    # passed explicitly are left at the algorithm classes' own defaults.
    if ours == "maxmin_sql":
        model = MaxminMFQ(
            env=env,
            env_name=env_name,
            policy="SQLPolicy",
            learning_rate=args.main_learning_rate,
            learning_starts=0,
            train_freq=1,
            target_update_interval=args.target_update_interval,
            tau=args.tau,
            exploration_initial_eps=args.exploration_initial_eps,
            exploration_final_eps=args.exploration_final_eps,
            exploration_fraction=args.exploration_fraction,
            verbose=1,
            seed=args.seed,
            r_dim=r_dim,
            # Not r_dim: the policy network (Q) outputs ac_dim * r_dim_policy
            # values, and the proposed method uses r_dim_policy=1.
            r_dim_policy=1,
            buffer_size=args.buffer_size,
            ent_alpha=args.ent_alpha,
            weight_decay=args.weight_decay,  # L2 regularization in Adam
            # perturbation parameters
            soft_q_init_fraction=args.init_frac,
            perturb_w_learning_rate=args.perturb_w_learning_rate,
            period_cal_w_grad=args.period_cal_w_grad,
            perturb_q_copy_num=args.perturb_q_copy_num,
            perturb_std_dev=args.perturb_std_dev,
            perturb_q_learning_rate=args.perturb_q_learning_rate,
            perturb_grad_step=args.perturb_grad_step,
            q_grad_st_after_init=args.q_grad_st_after_init,
            # preference (w) handling
            explicit_w_input=args.explicit_w_input,
            weight_initialize=args.weight_initialize,
            w_schedule_option=args.w_schedule_option,
            stats_window_size=args.stats_window_size,
            # alpha scheduling for SQL variants
            ent_alpha_act_init=args.ent_alpha_act_init,
            annealing_step=args.annealing_step,
        )

        print(f'env name: {model.env_name}')
        model.learn(total_timesteps=args.total_timesteps,
                    tb_log_name="MaxminMFQ")

    elif ours == "maxmin_ppo":
        # gamma, gae_lambda, clip_range(_vf), normalize_advantage, max_grad_norm,
        # use_sde, target_kl, policy_kwargs, device, ... are left at their defaults.
        model = MaxminPPO(
            env=env,
            env_name=env_name,
            policy='MlpPolicy',
            learning_rate=args.main_learning_rate,
            n_steps=args.n_steps,
            batch_size=args.batch_size_ppo,
            n_epochs=args.n_epoch_ppo,
            ent_coef=args.ent_loss_coef,
            vf_coef=args.value_loss_coef,
            stats_window_size=args.stats_window_size,
            verbose=1,
            seed=args.seed,
            # multi-objective extensions
            r_dim=r_dim,
            r_dim_wise_normalize=args.r_dim_wise_normalize,
            weight_initialize=args.weight_initialize,
            ent_coef_weight=args.ent_coef_weight,
            use_md=args.use_md,
            use_ci=args.use_ci,
            corr=args.corr,
            reg=args.reg,
            ci_coef=args.ci_coef,
            md_coef=args.md_coef,
        )
        print(f'env name: {model.env_name}\n')

        model.learn(total_timesteps=args.total_timesteps,
                    tb_log_name="MaxminPPO")

    elif ours == "baseline":
        model = DQN(
            env=env,
            policy="DQNPolicy",
            learning_rate=args.main_learning_rate,
            learning_starts=0,
            train_freq=1,
            target_update_interval=args.target_update_interval,
            tau=args.tau,
            exploration_initial_eps=args.exploration_initial_eps,
            exploration_final_eps=args.exploration_final_eps,
            exploration_fraction=args.exploration_fraction,
            verbose=1,
            seed=args.seed,
            r_dim=r_dim,
            # Output of the policy network (Q) is ac_dim * r_dim_policy.
            r_dim_policy=args.r_dim_policy,
            buffer_size=args.buffer_size,
            weight_decay=args.weight_decay,
            double_q=args.double_q,
            scalarize=args.scalarize,
            stats_window_size=args.stats_window_size,
        )
        model.learn(
            total_timesteps=args.total_timesteps)

    elif ours == 'ppo_baseline':
        # Remaining PPO hyper-parameters are left at their defaults.
        model = PPO(
            policy='MlpPolicy',
            env=env,
            learning_rate=args.main_learning_rate,
            n_steps=args.n_steps,
            batch_size=args.batch_size_ppo,
            n_epochs=args.n_epoch_ppo,
            ent_coef=args.ent_loss_coef,
            vf_coef=args.value_loss_coef,
            stats_window_size=args.stats_window_size,
            verbose=1,
            seed=args.seed,
            # multi-objective extensions
            r_dim=r_dim,
            r_dim_wise_normalize=args.r_dim_wise_normalize,
            env_name=env_name,
        )
        model.learn(total_timesteps=args.total_timesteps,
                    tb_log_name="GGF-PPO")

    else:
        # ValueError (an Exception subclass) keeps any `except Exception` callers working.
        raise ValueError(" choose from ppo_baseline, baseline, maxmin_sql and maxmin_ppo")

    wandb.finish()