# Default FOCAL experiment settings.
# Individual experiment configs should override only the settings they need to change.
default_config = dict(
    # ---- experiment / task split ----
    env_name='cheetah-dir',
    n_train_tasks=2,
    n_eval_tasks=2,
    latent_size=20, # dimension of the latent context vector
    net_size=256, # number of units per FC layer in each network
    path_to_weights=None, # path to pre-trained weights to load into networks
    seed_list=[0], # list of random seeds
    env_params=dict(
        # Number of distinct tasks in this domain. The usual guidance is
        # n_tasks == n_train_tasks + n_eval_tasks, but these defaults (2 vs. 2 + 2)
        # do not satisfy it; presumably train and eval tasks overlap for
        # cheetah-dir -- confirm before relying on the invariant.
        n_tasks=2,
        randomize_tasks=True, # shuffle the tasks after creating them
        max_episode_steps=200, # built-in max episode length for this environment
    ),
    algo_params=dict(
        # ---- data collection / training schedule ----
        meta_batch=16, # number of tasks to average the gradient across
        num_iterations=500, # number of data sampling / training iterations
        num_initial_steps=2000, # number of transitions collected per task before training
        num_tasks_sample=5, # number of randomly sampled tasks to collect data for each iteration
        num_steps_prior=400, # number of transitions to collect per task with z ~ prior
        num_steps_posterior=0, # number of transitions to collect per task with z ~ posterior
        num_extra_rl_steps_posterior=400, # number of additional transitions to collect per task with z ~ posterior that are only used to train the policy and NOT the encoder
        num_train_steps_per_itr=2000, # number of meta-gradient steps taken per iteration
        num_evals=2, # number of independent evals
        num_steps_per_eval=600,  # number of transitions to eval on
        batch_size=256, # number of transitions in the RL batch
        embedding_batch_size=64, # number of transitions in the context batch
        embedding_mini_batch_size=64, # number of context transitions to backprop through (should equal the arg above except in the recurrent encoder case)
        max_path_length=200, # max path length for this environment

        # ---- RL / optimizer hyperparameters ----
        discount=0.99, # RL discount factor
        soft_target_tau=0.005, # for SAC target network update
        policy_lr=3e-4,
        qf_lr=3e-4,
        vf_lr=3e-4,
        context_lr=3e-4,
        c_lr=1e-4, # dual critic learning rate (BRAC dual)
        alpha_lr=1, # alpha learning rate (BRAC)
        c_iter=3, # number of dual critic steps per iteration
        reward_scale=5., # scale rewards before constructing Bellman update, effectively controls weight on the entropy of the policy
        sparse_rewards=False, # whether to sparsify rewards as determined in env
        kl_lambda=.1, # weight on KL divergence term in encoder loss
        use_information_bottleneck=False, # False makes latent context deterministic
        update_post_train=1, # how often to resample the context when collecting data during training (in trajectories)
        num_exp_traj_eval=1, # how many exploration trajs to collect before beginning posterior sampling at test time
        recurrent=False, # recurrent or permutation-invariant encoder
        dump_eval_paths=False, # whether to save evaluation trajectories

        # ---- offline data (FOCAL only) ----
        sample=1, # whether to train with stochastic (noise-sampled) trajectories; 1/0 used as a boolean flag
        # NOTE(review): 6e5 is a float literal (600000.0); confirm consumers accept a float epoch.
        train_epoch=6e5, # epoch of the model used to generate meta-training trajectories
        eval_epoch=6e5, # epoch of the model used to generate meta-testing trajectories
        divergence_name='kl', # divergence type in BRAC algo

        # ---- BRAC regularization ----
        use_brac=True, # whether to use BRAC regularization (compare with batch PEARL)
        value_penalty=False, # False to regularize only the policy (no value penalty)
        train_alpha=True, # whether to train alpha (BRAC)
        alpha_init=500., # initial value for alpha (BRAC)
        alpha_max=2000., # maximum value for alpha
        target_divergence=0.05, # for training alpha adaptively. As in BEAR, if train_alpha=True, increase alpha when div > target_divergence, lower alpha when div < target_divergence (BRAC)
        max_entropy=True, # whether to include max-entropy term (as in SAC and PEARL) in value function

        # ---- attention / context encoder ----
        use_transformer_qvp=False, # deprecated: whether to use attention in qvp networks
        use_transformer_sequence=True, # whether to use sequence attention for context encoder
        use_multihead_attention=True, # when using sequence attention, whether to use multihead attention or self attention
        use_channel_attention=True, # whether to use batch (channel) attention
        attention_mode='parallel', # "parallel", "serialize" or "gate"; only takes effect when both sequence attention (use_transformer_sequence) and batch/channel attention (use_channel_attention) are enabled

        # Removed options, kept for reference:
        #use_transformer_batch=True, # deprecated
        #use_conv_transformer=False, # deprecated

        use_next_obs_in_context=False, # use next obs if it is useful in distinguishing tasks
        transformer_hidden_size=128,
        n_multihead=8, # must divide transformer_hidden_size (transformer_hidden_size % n_multihead == 0)
        latent_policy=8, # deprecated: policy latent dimension
        allow_logging=True, # deprecated
        allow_backward_z=False, # whether to allow gradients to flow back through z
        allow_eval=True, # if True, enable evaluation

        # ---- contrastive learning ----
        T=0.5, # contrastive learning scale (presumably a softmax temperature -- confirm)
        m=0.9, # contrastive learning soft-update momentum
        agent_type='Contrastive', # deprecated: whether to use contrastive agent
        contrastive=True, # use contrastive loss if True, else z_loss
        contrastive_encoder_type='soft', # contrastive encoder update type: 'soft' or 'hard'
        mb_replace=False, # meta-batch sampling with or without replacement

        # ---- regularization of networks ----
        dropout=0.1, # dropout for context encoder
        use_qvp_dropout=False, # whether to include dropout in qvp networks
        use_qvp_layerNorm=False, # whether to use LayerNorm in qvp networks
        qvp_network_type='encoder_decoder', # TODO: document available qvp network types

        # Optional settings, disabled by default:
        #model_test=False,
        #pretrained_model_basedir="",
    ),
    util_params=dict(
        base_log_dir='./log',
        use_gpu=True,
        gpu_id=0,
        debug=False, # debugging triggers printing and writes logs to debug directory
        docker=False, # TODO: docker is not yet supported
        machine='gpu',
    ),
    algo_type='FOCAL', # 'FOCAL' or 'PEARL'
)



