import argparse
from utils.helpers import boolean_argument

def _int_or_none(value):
    """Parse an integer command-line value, mapping the literal 'None' to None.

    Several options document "None" as a meaningful setting (e.g. "None uses
    all"), which plain ``type=int`` cannot express from the command line.
    Matching is case-insensitive and ignores surrounding whitespace; any other
    value is parsed with ``int`` (raising ValueError, which argparse reports as
    an invalid argument).
    """
    if value.strip().lower() == 'none':
        return None
    return int(value)


def get_args(rest_args):
    """Build the SDVT/GMVAE training argument parser and parse ``rest_args``.

    Args:
        rest_args: list of command-line strings to parse. If None, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace holding every option defined below.

    Raises:
        SystemExit: raised by argparse on unrecognised or invalid arguments.
    """
    parser = argparse.ArgumentParser()

    # --- GENERAL ---

    # training parameters
    # int(5e7), not 5e7: argparse does not apply `type` to non-string defaults,
    # so a bare 5e7 would reach callers as a float.
    parser.add_argument('--num_frames', type=int, default=int(5e7), help='number of frames to train')
    parser.add_argument('--max_rollouts_per_task', type=int, default=4)
    parser.add_argument('--exp_label', default='SDVT_GMVAE', help='label (typically name of method)')
    parser.add_argument('--env_name', default='ML1PushEnv-v2', help='environment to train on')

    # --- POLICY ---

    # using separate GRU
    parser.add_argument('--policy_separate_gru', type=boolean_argument, default=False, help='use a separate GRU for the policy')

    # what to pass to the policy (note this is after the encoder)
    parser.add_argument('--pass_state_to_policy', type=boolean_argument, default=True, help='condition policy on state')
    parser.add_argument('--pass_latent_to_policy', type=boolean_argument, default=True, help='condition policy on VAE latent')
    parser.add_argument('--pass_belief_to_policy', type=boolean_argument, default=False, help='condition policy on ground-truth belief')
    parser.add_argument('--pass_task_to_policy', type=boolean_argument, default=False, help='condition policy on ground-truth task description')
    parser.add_argument('--pass_prob_to_policy', type=boolean_argument, default=True, help='condition policy on mixture VAEs probs, only when VAE mixture num >1')

    # using separate encoders for the different inputs ("None" uses no encoder;
    # can be passed on the command line as the literal string None)
    parser.add_argument('--policy_state_embedding_dim', type=_int_or_none, default=64)
    parser.add_argument('--policy_latent_embedding_dim', type=_int_or_none, default=64)
    parser.add_argument('--policy_belief_embedding_dim', type=_int_or_none, default=64)
    parser.add_argument('--policy_task_embedding_dim', type=_int_or_none, default=None)
    parser.add_argument('--policy_prob_embedding_dim', type=_int_or_none, default=64, help='only when VAE mixture num >1')

    # normalising (inputs/rewards/outputs)
    parser.add_argument('--norm_state_for_policy', type=boolean_argument, default=True, help='normalise state input')
    parser.add_argument('--norm_latent_for_policy', type=boolean_argument, default=True, help='normalise latent input')
    parser.add_argument('--norm_belief_for_policy', type=boolean_argument, default=True, help='normalise belief input')
    parser.add_argument('--norm_task_for_policy', type=boolean_argument, default=True, help='normalise task input')
    parser.add_argument('--norm_prob_for_policy', type=boolean_argument, default=True, help='normalise prob input, only when VAE mixture num >1')
    parser.add_argument('--norm_rew_for_policy', type=boolean_argument, default=True, help='normalise rew for RL train')
    parser.add_argument('--norm_actions_pre_sampling', type=boolean_argument, default=False, help='normalise policy output (before sampling)')
    parser.add_argument('--norm_actions_post_sampling', type=boolean_argument, default=False, help='normalise policy output (after sampling)')

    # network
    # type=int so sizes given on the command line are not left as strings
    # (matches the other *_layers options below).
    parser.add_argument('--policy_layers', nargs='+', type=int, default=[256, 256])
    parser.add_argument('--policy_activation_function', type=str, default='tanh', help='tanh/relu/leaky-relu')
    parser.add_argument('--policy_initialisation', type=str, default='normc', help='normc/orthogonal')
    parser.add_argument('--policy_anneal_lr', type=boolean_argument, default=False, help='anneal LR over time')

    # RL algorithm
    parser.add_argument('--policy', type=str, default='ppo', help='choose: a2c, ppo')
    parser.add_argument('--policy_optimiser', type=str, default='adam', help='choose: rmsprop, adam')

    # PPO specific
    parser.add_argument('--ppo_num_epochs', type=int, default=5, help='number of epochs per PPO update')
    parser.add_argument('--ppo_num_minibatch', type=int, default=10, help='number of minibatches to split the data')
    parser.add_argument('--ppo_use_huberloss', type=boolean_argument, default=True, help='use huberloss instead of MSE')
    parser.add_argument('--ppo_use_clipped_value_loss', type=boolean_argument, default=True, help='clip value loss')
    parser.add_argument('--ppo_clip_param', type=float, default=0.1, help='clamp param')
    parser.add_argument('--ppo_disc', type=boolean_argument, default=False, help='dimension-wise clipping')

    # other hyperparameters
    parser.add_argument('--lr_policy', type=float, default=7e-4, help='learning rate (default: 7e-4)')
    parser.add_argument('--num_processes', type=int, default=10,
                        help='how many training CPU processes / parallel environments to use (default: 10)')
    parser.add_argument('--policy_num_steps', type=int, default=5000,
                        help='number of env steps to do (per process) before updating')
    parser.add_argument('--policy_eps', type=float, default=1e-8, help='optimizer epsilon (1e-8 for ppo, 1e-5 for a2c)')
    parser.add_argument('--policy_init_std', type=float, default=1.0, help='only used for continuous actions')
    parser.add_argument('--policy_min_std', type=float, default=0.5, help='minimum std of policy only used for continuous actions')
    parser.add_argument('--policy_max_std', type=float, default=1.5, help='maximum std of policy only used for continuous actions')
    parser.add_argument('--policy_value_loss_coef', type=float, default=0.5, help='value loss coefficient')
    parser.add_argument('--policy_entropy_coef', type=float, default=0.001, help='entropy term coefficient')
    parser.add_argument('--policy_gamma', type=float, default=0.99, help='discount factor for rewards')
    parser.add_argument('--policy_use_gae', type=boolean_argument, default=True,
                        help='use generalized advantage estimation')
    parser.add_argument('--policy_tau', type=float, default=0.90, help='gae parameter')
    parser.add_argument('--use_proper_time_limits', type=boolean_argument, default=True,
                        help='treat timeout and death differently (important in mujoco)')
    parser.add_argument('--policy_max_grad_norm', type=float, default=0.5, help='max norm of gradients')
    parser.add_argument('--encoder_max_grad_norm', type=float, default=1.0, help='max norm of gradients')
    parser.add_argument('--decoder_max_grad_norm', type=float, default=1.0, help='max norm of gradients')

    # --- VAE TRAINING ---
    parser.add_argument('--dropout_rate', type=float, default=0.7, help='dropout rate for non-latent input of decoder')
    # general
    parser.add_argument('--lr_vae', type=float, default=0.001)
    parser.add_argument('--size_vae_buffer', type=int, default=1000,
                        help='how many trajectories (!) to keep in VAE buffer')
    parser.add_argument('--precollect_len', type=int, default=5000,
                        help='how many frames to pre-collect before training begins (useful to fill VAE buffer)')
    parser.add_argument('--vae_buffer_add_thresh', type=float, default=1.0,
                        help='probability of adding a new trajectory to buffer')
    parser.add_argument('--vae_batch_num_trajs', type=int, default=10,
                        help='how many trajectories to use for VAE update')
    # The next three accept the literal string None on the command line (see help texts).
    parser.add_argument('--tbptt_stepsize', type=_int_or_none, default=50,
                        help='stepsize for truncated backpropagation through time; None uses max (horizon of BAMDP)')
    parser.add_argument('--vae_subsample_elbos', type=_int_or_none, default=100,
                        help='for how many timesteps to compute the ELBO; None uses all')
    parser.add_argument('--vae_subsample_decodes', type=_int_or_none, default=100,
                        help='number of reconstruction terms to subsample; None uses all')
    parser.add_argument('--vae_avg_elbo_terms', type=boolean_argument, default=True,
                        help='Average ELBO terms (instead of sum)')
    parser.add_argument('--vae_avg_reconstruction_terms', type=boolean_argument, default=True,
                        help='Average reconstruction terms (instead of sum)')
    parser.add_argument('--num_vae_updates', type=int, default=10,
                        help='how many VAE update steps to take per meta-iteration')
    parser.add_argument('--pretrain_len', type=int, default=0, help='for how many updates to pre-train the VAE')
    parser.add_argument('--kl_weight', type=float, default=0.1, help='weight for the KL term')

    parser.add_argument('--split_batches_by_task', type=boolean_argument, default=False,
                        help='split batches up by task (to save memory or if tasks are of different length)')
    parser.add_argument('--split_batches_by_elbo', type=boolean_argument, default=False,
                        help='split batches up by elbo term (to save memory or if ELBOs are of different length)')

    # --- GMVAE SPECIFIC ---
    # GMVAE latent structure: w (style), z (skill), y (analytic mixture weights)
    parser.add_argument('--vae_mixture_num', type=int, default=5,
                        help='how many mixture components K to use for GMVAE')

    # GMVAE loss coefficients (separate control for each KL term)
    parser.add_argument('--kl_w_loss_coeff', type=float, default=1.0,
                        help='coefficient for KL(q(w|h)||p(w)) loss term')
    parser.add_argument('--kl_z_loss_coeff', type=float, default=1.0,
                        help='coefficient for KL(q(z|h)||p(z|w,y)) loss term')
    parser.add_argument('--kl_y_loss_coeff', type=float, default=1.0,
                        help='coefficient for categorical y loss term')

    # Policy integration
    parser.add_argument('--pass_w_to_policy', type=boolean_argument, default=True,
                        help='condition policy on w (style latent)')
    parser.add_argument('--policy_w_embedding_dim', type=int, default=64,
                        help='embedding dimension for w latent when passed to policy')
    parser.add_argument('--norm_w_for_policy', type=boolean_argument, default=True,
                        help='normalise w input for policy')

    # Legacy/backward compatibility (deprecated for GMVAE)
    parser.add_argument('--gauss_loss_coeff', type=float, default=1.0,
                        help='[DEPRECATED for GMVAE] use kl_w_loss_coeff and kl_z_loss_coeff instead')
    parser.add_argument('--cat_loss_coeff', type=float, default=1.0,
                        help='[DEPRECATED for GMVAE] use kl_y_loss_coeff instead')

    # Other mixture settings
    parser.add_argument('--gumbel_temperature', type=float, default=1.0,
                        help='Gumbel softmax temperature, when nearly 0, hardmax')
    parser.add_argument('--occ_loss_coeff', type=float, default=0.0, help='Occupancy regularization coefficient')
    parser.add_argument('--occ_loss_type', type=str, default='exp',
                        help='choose: linear, log, exp')

    # - encoder
    parser.add_argument('--action_embedding_size', type=int, default=16)
    parser.add_argument('--state_embedding_size', type=int, default=32)
    parser.add_argument('--reward_embedding_size', type=int, default=16)
    parser.add_argument('--encoder_layers_before_gru', nargs='+', type=int, default=[])
    parser.add_argument('--encoder_gru_hidden_size', type=int, default=256, help='dimensionality of RNN hidden state')
    parser.add_argument('--encoder_layers_after_gru', nargs='+', type=int, default=[])
    parser.add_argument('--latent_dim', type=int, default=5, help='dimensionality of latent space')

    # --- encoder: RNN type ---
    parser.add_argument('--rnn_type', default='gru', help='gru or block-rnn')

    # - decoder: rewards
    parser.add_argument('--decode_reward', type=boolean_argument, default=True, help='use reward decoder')
    parser.add_argument('--normalise_rew_targets', type=boolean_argument, default=False, help='divide reward targets by largest rew seen')
    parser.add_argument('--rew_loss_coeff', type=float, default=10.0, help='weight for reward loss')
    parser.add_argument('--input_prev_state', type=boolean_argument, default=True, help='use prev state for rew pred')
    parser.add_argument('--input_action', type=boolean_argument, default=True, help='use prev action for rew pred')
    parser.add_argument('--reward_decoder_layers', nargs='+', type=int, default=[64, 64, 32])
    parser.add_argument('--multihead_for_reward', type=boolean_argument, default=False,
                        help='one head per reward pred (i.e. per state)')
    # The choice descriptions are implicitly concatenated, so each needs its own separator.
    parser.add_argument('--rew_pred_type', type=str, default='deterministic',
                        help='choose: '
                             'bernoulli (predict p(r=1|s)), '
                             'categorical (predict p(r=1|s) but use softmax instead of sigmoid), '
                             'deterministic (treat as regression problem)')

    # - decoder: state transitions
    parser.add_argument('--decode_state', type=boolean_argument, default=True, help='use state decoder')
    parser.add_argument('--state_loss_coeff', type=float, default=1000.0, help='weight for state loss')
    parser.add_argument('--state_decoder_layers', nargs='+', type=int, default=[64, 64, 32])
    parser.add_argument('--state_pred_type', type=str, default='deterministic', help='choose: deterministic, gaussian')

    # - decoder: ground-truth task ("varibad oracle", after Humplik et al. 2019)
    parser.add_argument('--decode_task', type=boolean_argument, default=False, help='use task decoder')
    parser.add_argument('--task_loss_coeff', type=float, default=1.0, help='weight for task loss')
    parser.add_argument('--task_decoder_layers', nargs='+', type=int, default=[64, 32])
    parser.add_argument('--task_pred_type', type=str, default='task_id', help='choose: task_id, task_description')

    # --- ABLATIONS ---

    # for the VAE
    parser.add_argument('--disable_decoder', type=boolean_argument, default=False,
                        help='train without decoder')
    parser.add_argument('--disable_stochasticity_in_latent', type=boolean_argument, default=False,
                        help='use auto-encoder (non-variational)')
    parser.add_argument('--disable_kl_term', type=boolean_argument, default=False,
                        help="don't use the KL regularising loss term")
    parser.add_argument('--decode_only_past', type=boolean_argument, default=False,
                        help='only decode past observations, not the future')
    parser.add_argument('--kl_to_gauss_prior', type=boolean_argument, default=False,
                        help='KL term in ELBO to fixed Gaussian prior (instead of prev approx posterior)')

    # --- Extrapolate ----
    parser.add_argument('--vae_extrapolate', type=boolean_argument, default=True,
                        help='Use extrapolate structured VAE')
    parser.add_argument('--ext_loss_coeff', type=float, default=10.0,
                        help='extrapolation loss coefficient')

    # combining vae and RL loss
    parser.add_argument('--rlloss_through_encoder', type=boolean_argument, default=False,
                        help='backprop rl loss through encoder')
    parser.add_argument('--add_nonlinearity_to_latent', type=boolean_argument, default=False,
                        help='Use relu before feeding latent to policy')
    parser.add_argument('--vae_loss_coeff', type=float, default=1.0,
                        help='weight for VAE loss (vs RL loss)')

    # for the policy training
    parser.add_argument('--sample_embeddings', type=boolean_argument, default=False,
                        help='sample embedding for policy, instead of full belief')

    # for other things
    parser.add_argument('--disable_metalearner', type=boolean_argument, default=False,
                        help='Train feedforward policy')
    parser.add_argument('--single_task_mode', type=boolean_argument, default=False,
                        help='train policy on one (randomly chosen) environment only')

    # --- OTHERS ---

    # logging, saving, evaluation
    parser.add_argument('--log_interval', type=int, default=50, help='log interval, one log per n updates')
    parser.add_argument('--save_interval', type=int, default=200, help='save interval, one save per n updates')
    parser.add_argument('--save_intermediate_models', type=boolean_argument, default=True, help='save all models')
    parser.add_argument('--eval_interval', type=int, default=200, help='eval interval, one eval per n updates')
    parser.add_argument('--vis_interval', type=int, default=500, help='visualisation interval, one visualisation per n updates')
    parser.add_argument('--results_log_dir', default=None, help='directory to save results (None uses ./logs)')
    parser.add_argument('--render', type=boolean_argument, default=False,
                        help='render during eval')
    parser.add_argument('--parametric_num', type=int, default=50, help='number of parametric variations for evaluation')

    # general settings
    parser.add_argument('--seed', nargs='+', type=int, default=[20])
    parser.add_argument('--deterministic_execution', type=boolean_argument, default=False,
                        help='Make code fully deterministic. Expects 1 process and uses deterministic CUDNN')

    # --- Virtual ---
    parser.add_argument('--virtual_ratio', type=float, default=0.0, help='virtual training ratio')
    parser.add_argument('--virtual_ratio_increment', type=float, default=0.05, help='virtual ratio increased per 100M steps')
    parser.add_argument('--num_virtual_skills', type=int, default=3)
    parser.add_argument('--include_smaller', type=boolean_argument, default=False,
                        help='allow smaller number of virtual skills')
    parser.add_argument('--virtual_dist', default='rms', help='virtual skill distribution, currently uni, dir, dir-interpolate, and rms')
    parser.add_argument('--virtual_intrinsic', type=float, default=0.0,
                        help='weight for virtual reward')

    # --- Resample ---
    parser.add_argument('--resample_tasks', type=boolean_argument, default=False, help='resample tasks given first state')

    # --- Gradient Correction --- (work in progress)
    parser.add_argument('--grad_correction', default='none', help='gradient correction method: none, pcgrad, cagrad, nash, amtl')
    parser.add_argument('--task_identification', default='none', help='unimodal has always none. mixture gaussian can have argmax, kmeans, sampling possible')

    return parser.parse_args(rest_args)