import argparse
from utils.helpers import boolean_argument

def get_args(rest_args):
    """Build the command-line parser for this configuration and parse ``rest_args``.

    Most defaults are deliberately small "fast test" values; the help strings
    marked ``TEST:`` record what the value was reduced from.

    Args:
        rest_args: list of command-line tokens handed to
            ``ArgumentParser.parse_args`` (e.g. ``['--num_frames', '1000']``).

    Returns:
        The parsed ``argparse.Namespace``. Besides the declared options it
        carries a ``cuda`` attribute that is always set to ``True``.
        ``seed`` and every ``*_layers`` option are lists (``nargs='+'``).

    Note:
        Boolean options are converted by ``boolean_argument`` from
        ``utils.helpers``; the accepted spellings are defined there.
    """
    parser = argparse.ArgumentParser()

    # --- GENERAL ---

    # training parameters - FAST TEST VALUES
    parser.add_argument('--num_frames', type=int, default=50000, help='number of frames to train (TEST: reduced from 40e7)')
    parser.add_argument('--max_rollouts_per_task', type=int, default=2, help='TEST: reduced from 10')
    parser.add_argument('--exp_label', default='SDVT_GMVAE_test', help='label (typically name of method)')
    parser.add_argument('--env_name', default='ML45Env-v2', help='environment to train on')

    # --- POLICY ---

    # using separate GRU
    parser.add_argument('--policy_separate_gru', type=boolean_argument, default=False, help='use a separate GRU for the policy')

    # what to pass to the policy (note this is after the encoder)
    parser.add_argument('--pass_state_to_policy', type=boolean_argument, default=True, help='condition policy on state')
    parser.add_argument('--pass_latent_to_policy', type=boolean_argument, default=True, help='condition policy on VAE latent')
    parser.add_argument('--pass_belief_to_policy', type=boolean_argument, default=False, help='condition policy on ground-truth belief')
    parser.add_argument('--pass_task_to_policy', type=boolean_argument, default=False, help='condition policy on ground-truth task description')
    parser.add_argument('--pass_prob_to_policy', type=boolean_argument, default=True, help='condition policy on mixture VAEs probs, only when VAE mixture num >1')

    # using separate encoders for the different inputs ("None" uses no encoder)
    parser.add_argument('--policy_state_embedding_dim', type=int, default=32, help='TEST: reduced from 64')
    parser.add_argument('--policy_latent_embedding_dim', type=int, default=32, help='TEST: reduced from 64')
    parser.add_argument('--policy_belief_embedding_dim', type=int, default=32, help='TEST: reduced from 64')
    parser.add_argument('--policy_task_embedding_dim', type=int, default=None)
    parser.add_argument('--policy_prob_embedding_dim', type=int, default=32, help='TEST: reduced from 64')

    # normalising (inputs/rewards/outputs)
    parser.add_argument('--norm_state_for_policy', type=boolean_argument, default=True, help='normalise state input')
    parser.add_argument('--norm_latent_for_policy', type=boolean_argument, default=True, help='normalise latent input')
    parser.add_argument('--norm_belief_for_policy', type=boolean_argument, default=True, help='normalise belief input')
    parser.add_argument('--norm_task_for_policy', type=boolean_argument, default=True, help='normalise task input')
    parser.add_argument('--norm_prob_for_policy', type=boolean_argument, default=True, help='normalise prob input, only when VAE mixture num >1')
    parser.add_argument('--norm_rew_for_policy', type=boolean_argument, default=True, help='normalise rew for RL train')
    parser.add_argument('--norm_actions_pre_sampling', type=boolean_argument, default=False, help='normalise policy output')
    parser.add_argument('--norm_actions_post_sampling', type=boolean_argument, default=False, help='normalise policy output')

    # network - SMALLER FOR TESTING
    # type=int keeps CLI-supplied layer sizes as ints, matching the default
    # and the other *_layers options below.
    parser.add_argument('--policy_layers', nargs='+', type=int, default=[128, 128], help='TEST: reduced from [256, 256]')
    parser.add_argument('--policy_activation_function', type=str, default='tanh', help='tanh/relu/leaky-relu')
    parser.add_argument('--policy_initialisation', type=str, default='normc', help='normc/orthogonal')
    parser.add_argument('--policy_anneal_lr', type=boolean_argument, default=False, help='anneal LR over time')

    # RL algorithm
    parser.add_argument('--policy', type=str, default='ppo', help='choose: a2c, ppo')
    parser.add_argument('--policy_optimiser', type=str, default='adam', help='choose: rmsprop, adam')

    # PPO specific - REDUCED FOR TESTING
    parser.add_argument('--ppo_num_epochs', type=int, default=2, help='TEST: reduced from 5')
    parser.add_argument('--ppo_num_minibatch', type=int, default=4, help='TEST: reduced from 10')
    parser.add_argument('--ppo_use_huberloss', type=boolean_argument, default=True, help='use huberloss instead of MSE')
    parser.add_argument('--ppo_use_clipped_value_loss', type=boolean_argument, default=True, help='clip value loss')
    parser.add_argument('--ppo_clip_param', type=float, default=0.1, help='clamp param')
    parser.add_argument('--ppo_disc',  type=boolean_argument, default=False, help='dimension-wise clipping')

    # Additional policy attributes expected by metalearner
    # NOTE(review): these are independent options, not true argparse aliases;
    # setting --ppo_use_huberloss does not change --policy_use_huber_loss.
    parser.add_argument('--policy_use_huber_loss', type=boolean_argument, default=True, help='use huber loss in policy (alias for ppo_use_huberloss)')
    parser.add_argument('--policy_use_clipped_value_loss', type=boolean_argument, default=True, help='clip value loss in policy (alias for ppo_use_clipped_value_loss)')

    # other hyperparameters - FAST TEST VALUES
    parser.add_argument('--lr_policy', type=float, default=7e-4, help='learning rate (default: 7e-4)')
    parser.add_argument('--num_processes', type=int, default=4, help='TEST: reduced from 10')
    parser.add_argument('--policy_num_steps', type=int, default=1000, help='TEST: reduced from 5000')
    parser.add_argument('--policy_eps', type=float, default=1e-8, help='optimizer epsilon (1e-8 for ppo, 1e-5 for a2c)')
    parser.add_argument('--policy_init_std', type=float, default=1.0, help='only used for continuous actions')
    parser.add_argument('--policy_min_std', type=float, default=0.5, help='minimum std of policy only used for continuous actions')
    parser.add_argument('--policy_max_std', type=float, default=1.5, help='maximum std of policy only used for continuous actions')
    parser.add_argument('--policy_value_loss_coef', type=float, default=0.5, help='value loss coefficient')
    parser.add_argument('--policy_entropy_coef', type=float, default=0.001, help='entropy term coefficient')
    parser.add_argument('--policy_gamma', type=float, default=0.99, help='discount factor for rewards')
    parser.add_argument('--policy_use_gae', type=boolean_argument, default=True,
                        help='use generalized advantage estimation')
    parser.add_argument('--policy_tau', type=float, default=0.90, help='gae parameter')
    parser.add_argument('--use_proper_time_limits', type=boolean_argument, default=True,
                        help='treat timeout and death differently (important in mujoco)')
    parser.add_argument('--policy_max_grad_norm', type=float, default=0.5, help='max norm of gradients')
    parser.add_argument('--encoder_max_grad_norm', type=float, default=1.0, help='max norm of gradients')
    parser.add_argument('--decoder_max_grad_norm', type=float, default=1.0, help='max norm of gradients')

    # --- VAE TRAINING --- FAST TEST VALUES
    parser.add_argument('--dropout_rate', type=float, default=0.0, help='TEST: disabled dropout for speed')
    # general
    parser.add_argument('--lr_vae', type=float, default=0.001)
    parser.add_argument('--size_vae_buffer', type=int, default=100, help='TEST: reduced from 1000')
    parser.add_argument('--precollect_len', type=int, default=1000, help='TEST: reduced from 5000')
    parser.add_argument('--vae_buffer_add_thresh', type=float, default=1,
                        help='probability of adding a new trajectory to buffer')
    parser.add_argument('--vae_batch_num_trajs', type=int, default=4, help='TEST: reduced from 10')
    parser.add_argument('--tbptt_stepsize', type=int, default=20, help='TEST: reduced from 50')
    parser.add_argument('--vae_subsample_elbos', type=int, default=20, help='TEST: reduced from 100')
    parser.add_argument('--vae_subsample_decodes', type=int, default=20, help='TEST: reduced from 100')
    parser.add_argument('--vae_avg_elbo_terms', type=boolean_argument, default=True,
                        help='Average ELBO terms (instead of sum)')
    parser.add_argument('--vae_avg_reconstruction_terms', type=boolean_argument, default=True,
                        help='Average reconstruction terms (instead of sum)')
    parser.add_argument('--num_vae_updates', type=int, default=2, help='TEST: reduced from 20')
    parser.add_argument('--pretrain_len', type=int, default=0, help='for how many updates to pre-train the VAE')
    parser.add_argument('--kl_weight', type=float, default=0.1, help='weight for the KL term')

    parser.add_argument('--split_batches_by_task', type=boolean_argument, default=False,
                        help='split batches up by task (to save memory or if tasks are of different length)')
    parser.add_argument('--split_batches_by_elbo', type=boolean_argument, default=False,
                        help='split batches up by elbo term (to save memory of if ELBOs are of different length)')

    # --- GMVAE SPECIFIC ---
    # GMVAE latent structure: w (style), z (skill), y (analytic mixture weights)
    parser.add_argument('--vae_mixture_num', type=int, default=5, help='TEST: reduced from 12')

    # GMVAE loss coefficients (separate control for each KL term)
    parser.add_argument('--kl_w_loss_coeff', type=float, default=1.0,
                        help='coefficient for KL(q(w|h)||p(w)) loss term')
    parser.add_argument('--kl_z_loss_coeff', type=float, default=1.0,
                        help='coefficient for KL(q(z|h)||p(z|w,y)) loss term')
    parser.add_argument('--kl_y_loss_coeff', type=float, default=1.0,
                        help='coefficient for categorical y loss term')

    # Policy integration
    parser.add_argument('--pass_w_to_policy', type=boolean_argument, default=False,
                        help='condition policy on w (style latent)')
    parser.add_argument('--policy_w_embedding_dim', type=int, default=32, help='TEST: reduced from 64')
    parser.add_argument('--norm_w_for_policy', type=boolean_argument, default=True,
                        help='normalise w input for policy')

    # Legacy/backward compatibility (deprecated for GMVAE)
    parser.add_argument('--gauss_loss_coeff', type=float, default=1.0,
                        help='[DEPRECATED for GMVAE] use kl_w_loss_coeff and kl_z_loss_coeff instead')
    parser.add_argument('--cat_loss_coeff', type=float, default=1.0,
                        help='[DEPRECATED for GMVAE] use kl_y_loss_coeff instead')

    # Other mixture settings
    parser.add_argument('--gumbel_temperature', type=float, default=1.0,
                        help='Gumbel softmax temperature, when nearly 0, hardmax')
    parser.add_argument('--occ_loss_coeff', type=float, default=0.0, help='TEST: disabled for speed')
    parser.add_argument('--occ_loss_type', type=str, default='exp',
                        help='choose: linear, log, exp')

    # - encoder - SMALLER FOR TESTING
    parser.add_argument('--action_embedding_size', type=int, default=8, help='TEST: reduced from 16')
    parser.add_argument('--state_embedding_size', type=int, default=16, help='TEST: reduced from 32')
    parser.add_argument('--reward_embedding_size', type=int, default=8, help='TEST: reduced from 16')
    parser.add_argument('--encoder_layers_before_gru', nargs='+', type=int, default=[])
    parser.add_argument('--encoder_gru_hidden_size', type=int, default=128, help='TEST: reduced from 256')
    parser.add_argument('--encoder_layers_after_gru', nargs='+', type=int, default=[])
    parser.add_argument('--latent_dim', type=int, default=5, help='TEST: reduced from 10')

    # --- encoder: RNN type ---
    parser.add_argument('--rnn_type', default='gru', help='gru or block-rnn')

    # - decoder: rewards - SMALLER FOR TESTING
    parser.add_argument('--decode_reward', type=boolean_argument, default=True, help='use reward decoder')
    parser.add_argument('--normalise_rew_targets', type=boolean_argument, default=False, help='divide reward targets by largest rew seen')
    parser.add_argument('--rew_loss_coeff', type=float, default=10, help='weight for reward loss (vs state loss)')
    parser.add_argument('--input_prev_state', type=boolean_argument, default=True, help='use prev state for rew pred')
    parser.add_argument('--input_action', type=boolean_argument, default=True, help='use prev action for rew pred')
    parser.add_argument('--reward_decoder_layers', nargs='+', type=int, default=[32, 32], help='TEST: reduced from [64, 64, 32]')
    parser.add_argument('--multihead_for_reward', type=boolean_argument, default=False,
                        help='one head per reward pred (i.e. per state)')
    # Adjacent literals are concatenated verbatim, so each choice carries its
    # own separator to keep the rendered help readable.
    parser.add_argument('--rew_pred_type', type=str, default='deterministic',
                        help='choose: '
                             'bernoulli (predict p(r=1|s)), '
                             'categorical (predict p(r=1|s) but use softmax instead of sigmoid), '
                             'deterministic (treat as regression problem)')

    # - decoder: state transitions - SMALLER FOR TESTING
    parser.add_argument('--decode_state', type=boolean_argument, default=True, help='use state decoder')
    parser.add_argument('--state_loss_coeff', type=float, default=100, help='TEST: reduced from 1000')
    parser.add_argument('--state_decoder_layers', nargs='+', type=int, default=[32, 32], help='TEST: reduced from [64, 64, 32]')
    parser.add_argument('--state_pred_type', type=str, default='deterministic', help='choose: deterministic, gaussian')

    # - decoder: ground-truth task ("varibad oracle", after Humplik et al. 2019)
    parser.add_argument('--decode_task', type=boolean_argument, default=False, help='use task decoder')
    parser.add_argument('--task_loss_coeff', type=float, default=1.0, help='weight for task loss')
    parser.add_argument('--task_decoder_layers', nargs='+', type=int, default=[32, 16], help='TEST: reduced from [64, 32]')
    parser.add_argument('--task_pred_type', type=str, default='task_id', help='choose: task_id, task_description')

    # --- ABLATIONS ---

    # for the VAE
    parser.add_argument('--disable_decoder', type=boolean_argument, default=False,
                        help='train without decoder')
    parser.add_argument('--disable_stochasticity_in_latent', type=boolean_argument, default=False,
                        help='use auto-encoder (non-variational)')
    parser.add_argument('--disable_kl_term', type=boolean_argument, default=False,
                        help='dont use the KL regularising loss term')
    parser.add_argument('--decode_only_past', type=boolean_argument, default=False,
                        help='only decoder past observations, not the future')
    parser.add_argument('--kl_to_gauss_prior', type=boolean_argument, default=False,
                        help='KL term in ELBO to fixed Gaussian prior (instead of prev approx posterior)')

    # --- Extrapolate ----
    parser.add_argument('--vae_extrapolate', type=boolean_argument, default=False,
                        help='TEST: disabled for speed')
    parser.add_argument('--ext_loss_coeff', type=float, default=0.0,
                        help='TEST: disabled for speed')

    # combining vae and RL loss
    parser.add_argument('--rlloss_through_encoder', type=boolean_argument, default=False,
                        help='backprop rl loss through encoder')
    parser.add_argument('--add_nonlinearity_to_latent', type=boolean_argument, default=False,
                        help='Use relu before feeding latent to policy')
    parser.add_argument('--vae_loss_coeff', type=float, default=1.0,
                        help='weight for VAE loss (vs RL loss)')

    # for the policy training
    parser.add_argument('--sample_embeddings', type=boolean_argument, default=False,
                        help='sample embedding for policy, instead of full belief')

    # for other things
    parser.add_argument('--disable_metalearner', type=boolean_argument, default=False,
                        help='Train feedforward policy')
    parser.add_argument('--single_task_mode', type=boolean_argument, default=False,
                        help='train policy on one (randomly chosen) environment only')

    # --- RESAMPLE ---
    parser.add_argument('--resample_tasks', type=boolean_argument, default=False, help='resample tasks given first state')

    # --- GRADIENT CORRECTION ---
    parser.add_argument('--grad_correction', default='none', help='gradient correction method: none, pcgrad, cagrad, nash, amtl')
    parser.add_argument('--task_identification', default='none', help='unimodal has always none. mixture gaussian can have argmax, kmeans, sampling possible')

    # --- VIRTUAL TRAINING ---
    parser.add_argument('--virtual_ratio', type=float, default=0.0, help='TEST: disabled virtual training for speed')
    parser.add_argument('--virtual_ratio_increment', type=float, default=0.0)
    parser.add_argument('--num_virtual_skills', type=int, default=1)
    parser.add_argument('--include_smaller', type=boolean_argument, default=False)
    parser.add_argument('--virtual_dist', type=str, default='dir')
    parser.add_argument('--virtual_intrinsic', type=float, default=0.0,
                        help='weight for virtual reward')

    # --- PRETRAINING ---
    parser.add_argument('--pretrainer', type=str, default=None, help='pretrainer type')
    parser.add_argument('--pretrain_env_name', type=str, default=None, help='pretraining environment name')
    parser.add_argument('--pretrain_exp_name', type=str, default=None, help='pretraining experiment name')
    parser.add_argument('--pretrain_seed', type=int, default=None, help='pretraining seed')
    parser.add_argument('--pretrain_frames', type=int, default=None, help='pretraining frames')
    parser.add_argument('--num_vae_updates_per_pretrain', type=int, default=3, help='VAE updates per pretrain step')

    # --- LOGGING ---
    parser.add_argument('--log_interval', type=int, default=1, help='log interval, one log per n updates')
    parser.add_argument('--save_interval', type=int, default=2, help='TEST: reduced from 500')
    parser.add_argument('--eval_interval', type=int, default=2, help='TEST: reduced from 250')
    parser.add_argument('--vis_interval', type=int, default=2, help='visualisation interval for latent space')
    parser.add_argument('--save_intermediate_models', type=boolean_argument, default=False,
                        help='save models at every save_interval (vs only at the end)')
    parser.add_argument('--render', type=boolean_argument, default=False,
                        help='render during eval')
    parser.add_argument('--parametric_num', type=int, default=10, help='TEST: reduced from 50')

    # --- GENERAL ---
    parser.add_argument('--seed', type=int, nargs='+', default=[73], help='random seed')
    parser.add_argument('--deterministic_execution', type=boolean_argument, default=False,
                        help='Make code fully deterministic. Expects 1 process and uses deterministic CUDNN')
    parser.add_argument('--results_log_dir', default='./logs', help='directory to save agent logs (default: ./logs)')

    args = parser.parse_args(rest_args)
    # NOTE(review): cuda is forced on here rather than detected; confirm that
    # downstream code handles machines without a GPU.
    args.cuda = True
    return args