import argparse


import torch

from util import str2bool

# Module-level command-line parser; the sections below register its arguments.
#
# Boolean flags use `type=str2bool, nargs='?', const=True`: a bare `--flag`
# stores True, while `--flag <value>` is converted by `str2bool` (a project
# helper -- presumably accepts strings like "true"/"false"; confirm in util).
parser = argparse.ArgumentParser(description='RL')

# PPO Arguments.
parser.add_argument(
    '--algo',
    type=str,
    default='ppo',
    choices=['ppo', 'a2c', 'acktr', 'ucb', 'mixreg'],
    help='Which RL algorithm to use.')
parser.add_argument(
    '--lr',
    type=float,
    default=1e-4,
    help='learning rate')
parser.add_argument(
    '--eps',
    type=float,
    default=1e-5,
    help='RMSprop optimizer epsilon')
parser.add_argument(
    '--alpha',
    type=float,
    default=0.99,
    help='RMSprop optimizer alpha')
parser.add_argument(
    '--gamma',
    type=float,
    default=0.995,
    help='discount factor for rewards')
parser.add_argument(
    '--use_gae',
    type=str2bool, nargs='?', const=True, default=True,
    help='Use generalized advantage estimator.')
parser.add_argument(
    '--gae_lambda',
    type=float,
    default=0.95,
    help='gae lambda parameter')
parser.add_argument(
    '--entropy_coef',
    type=float,
    default=0.0,
    help='entropy term coefficient')
parser.add_argument(
    '--value_loss_coef',
    type=float,
    default=0.5,
    help='value loss coefficient (default: 0.5)')
parser.add_argument(
    '--max_grad_norm',
    type=float,
    default=0.5,
    help='max norm of gradients')
parser.add_argument(
    '--normalize_returns',
    type=str2bool, nargs='?', const=True, default=False,
    help='Whether to normalize returns')
parser.add_argument(
    '--use_popart',
    type=str2bool, nargs='?', const=True, default=False,
    help='Whether to normalize values via PopArt.')
parser.add_argument(
    '--seed',
    type=int,
    default=1,
    help='random seed')
parser.add_argument(
    '--num_processes',
    type=int,
    default=32,
    help='how many training CPU processes to use')
parser.add_argument(
    '--num_steps',
    type=int,
    default=256,
    help='number of forward steps in A2C')
parser.add_argument(
    '--ppo_epoch',
    type=int,
    default=5,
    help='number of ppo epochs')
parser.add_argument(
    '--num_mini_batch',
    type=int,
    default=1,
    help='number of batches for ppo')
parser.add_argument(
    '--clip_param',
    type=float,
    default=0.2,
    help='ppo clip parameter')
# nargs='?' + const=True matches the other boolean flags: a bare
# `--clip_value_loss` now means True; `--clip_value_loss False` still works.
parser.add_argument(
    '--clip_value_loss',
    type=str2bool, nargs='?', const=True, default=True,
    help='ppo clip value loss')
parser.add_argument(
    '--num_env_steps',
    type=int,
    default=500000,
    help='number of environment steps to train')

# Architecture arguments.
parser.add_argument(
    '--recurrent_arch',
    type=str,
    default='lstm',
    choices=['gru', 'lstm'],
    help='RNN architecture')
# The help text was previously copy-pasted from --no_cuda ("disables CUDA
# training"), which misdescribed this flag.
parser.add_argument(
    '--recurrent_agent',
    type=str2bool, nargs='?', const=True, default=True,
    help='Whether the agent uses a recurrent (RNN) architecture.')
parser.add_argument(
    '--recurrent_hidden_size',
    type=int,
    default=256,
    help='Recurrent hidden state size')

# Environment arguments: which env to train on and how its episodes and
# rewards are handled.
parser.add_argument('--env_name', type=str, default='MultiGrid-BinaryChoice-v0',
                    help='Environment to train on')
parser.add_argument('--handle_timelimits', type=str2bool, nargs='?', const=True,
                    default=False,
                    help=("Bootstrap off of early termination states. "
                          "Requires env to be wrapped by envs.wrappers.TimeLimit."))
parser.add_argument('--singleton_env', type=str2bool, nargs='?', const=True,
                    default=False,
                    help=("When using a fixed env, whether the same environment "
                          "should also be reused across workers."))
parser.add_argument('--clip_reward', type=float, default=None,
                    help="Clip sparse rewards.")

# Stochastic choice environment arguments.
#
# `--p` was declared type=str with a float default (0.7). argparse only runs
# `type` on string defaults, so the default produced a float while any CLI
# value produced a str. type=float makes both paths yield a float.
parser.add_argument(
    '--p',
    type=float,
    default=0.7,
    help="Probability parameter p of the stochastic choice environment.")
parser.add_argument(
    '--reward_dist',
    type=str,
    default='uniform',
    help="Reward distribution")
# The two list-valued options below are comma-separated strings (see the
# defaults); parsing is presumably done by the consumer -- confirm there.
parser.add_argument(
    '--stochastic_choice_rewards',
    type=str,
    default='2,10',
    help="Stochastic reward means.")
parser.add_argument(
    '--stochastic_choice_reward_spreads',
    type=str,
    default='2,0',
    help="Stochastic reward standard deviations")
parser.add_argument(
    '--goal_hint_p',
    type=float,
    default=0.0,
    help="Probability of placing a goal hint in each level")
# Boolean flags below now accept a bare `--flag` (== True), consistent with
# the rest of the parser; explicit values are still accepted.
parser.add_argument(
    '--stochastic_choice_use_walls',
    type=str2bool, nargs='?', const=True, default=False,
    help="Put walls in the stochastic choice maze.")
parser.add_argument(
    '--force_obl_correction',
    type=str2bool, nargs='?', const=True, default=False,
    help="Force OBL correction.")
parser.add_argument(
    '--use_learned_beliefs',
    type=str2bool, nargs='?', const=True, default=False,
    help="Use learned belief model for OBL correction.")
parser.add_argument(
    '--fully_observable',
    type=str2bool, nargs='?', const=True, default=False,
    help="Make the environment fully observable.")


# PLR (Prioritized Level Replay) arguments.
parser.add_argument(
    "--use_plr",
    type=str2bool, nargs='?', const=True, default=False,
    help='Whether to use PLR.'
)
parser.add_argument(
    "--no_exploratory_grad_updates",
    type=str2bool, nargs='?', const=True, default=False,
    help='Only perform gradient updates for episodes on levels sampled via PLR.'
)
parser.add_argument(
    "--level_replay_score_transform",
    type=str,
    default='rank',
    choices=['constant', 'max', 'eps_greedy', 'rank', 'power', 'softmax', 'match', 'match_rank'],
    help="Level replay scoring strategy")
# Help text previously duplicated the --level_replay_score_transform help.
parser.add_argument(
    "--level_replay_temperature",
    type=float,
    default=0.1,
    help="Temperature for the level replay score transform")
parser.add_argument(
    "--level_replay_strategy",
    type=str,
    default='value_l1',
    choices=['off', 'random', 'uniform', 'sequential',
            'policy_entropy', 'least_confidence', 'min_margin',
            'gae', 'value_l1', 'signed_value_loss', 'positive_value_loss',
            'grounded_signed_value_loss', 'grounded_positive_value_loss',
            'one_step_td_error', 'alt_advantage_abs',
            'tscl_window'],
    help="Level replay scoring strategy")
parser.add_argument(
    "--level_replay_eps",
    type=float,
    default=0.05,
    help="Level replay epsilon for eps-greedy sampling")
parser.add_argument(
    "--level_replay_schedule",
    type=str,
    default='proportionate',
    help="Level replay schedule for sampling seen levels")
parser.add_argument(
    "--level_replay_rho",
    type=float,
    default=1.0,
    help="Minimum size of replay set relative to total number of levels before sampling replays.")
parser.add_argument(
    "--level_replay_prob",
    type=float,
    default=0.,
    help="Probability of sampling a new level instead of a replay level.")
parser.add_argument(
    "--level_replay_alpha",
    type=float,
    default=1.0,
    help="Level score EWA smoothing factor")
parser.add_argument(
    "--staleness_coef",
    type=float,
    default=0.3,
    help="Staleness weighing")
parser.add_argument(
    "--staleness_transform",
    type=str,
    default='power',
    choices=['max', 'eps_greedy', 'rank', 'power', 'softmax'],
    help="Staleness normalization transform")
parser.add_argument(
    "--staleness_temperature",
    type=float,
    default=1.0,
    help="Staleness normalization temperature.")
parser.add_argument(
    "--train_full_distribution",
    type=str2bool, nargs='?', const=True, default=True,
    help='Train on the full distribution of levels.'
)
parser.add_argument(
    "--level_replay_seed_buffer_size",
    type=int,
    default=4000,
    help="Size of seed buffer, a min-priority queue.")
parser.add_argument(
    "--level_replay_seed_buffer_priority",
    type=str,
    default='replay_support',
    choices=['score', 'replay_support'],
    help="How to prioritize seed buffer indices.")

# Hardware arguments: device selection.
parser.add_argument('--no_cuda', help='disables CUDA training',
                    default=False, type=str2bool, nargs='?', const=True)

# Logging arguments: run naming, log/checkpoint cadence, optional diagnostics,
# and periodic evaluation on test environments.
parser.add_argument("--verbose", help="Whether to print logs",
                    default=False, type=str2bool, nargs='?', const=True)
parser.add_argument('--xpid', default='latest',
                    help='name for the run - prefix to log files')
parser.add_argument('--log_dir', default='~/logs/samplr/',
                    help='directory to save agent logs')

# Periodic logging / saving cadence.
parser.add_argument('--log_interval', default=1, type=int,
                    help='log interval, one log per n updates')
parser.add_argument("--save_interval", default=5, type=int,
                    help="Save model every this many minutes.")
parser.add_argument("--archive_interval", default=0, type=int,
                    help="Save an archived model every this many updates.")
parser.add_argument("--weight_log_interval", default=0, type=int,
                    help="Save level weights every this many updates")
parser.add_argument("--screenshot_interval", default=1, type=int,
                    help="Save screenshot of environment every this many updates.")

# Boolean toggles for rendering, checkpointing and extra diagnostics.
parser.add_argument('--render', help='Render to screen.',
                    default=False, type=str2bool, nargs='?', const=True)
parser.add_argument("--checkpoint", help="Begin training from checkpoint.",
                    default=False, type=str2bool, nargs='?', const=True)
parser.add_argument("--disable_checkpoint", help="Disable saving checkpoint.",
                    default=False, type=str2bool, nargs='?', const=True)
parser.add_argument('--log_grad_norm',
                    help="Log the gradient norm of the actor critic",
                    default=False, type=str2bool, nargs='?', const=True)
parser.add_argument('--log_action_complexity',
                    help="Log action trajectory complexity measures throughout training",
                    default=False, type=str2bool, nargs='?', const=True)
parser.add_argument('--log_replay_complexity',
                    help="Log complexity of replay levels instead of those generated by adversary.",
                    default=False, type=str2bool, nargs='?', const=True)

# Evaluation on test environments.
parser.add_argument('--test_interval', default=10, type=int,
                    help='Evaluate on test envs every this many updates.')
parser.add_argument('--test_num_episodes', default=10, type=int,
                    help='Number of test episodes per environment')
parser.add_argument('--test_num_processes', default=2, type=int,
                    help='Number of test processes per environment')
parser.add_argument('--test_env_names', default='', type=str,
                    help='Environment to evaluate on')