# ===========================================================================
# Launch settings (Hydra)
# NOTE(review): this section was headed "For local launches", but the
# launcher below is overridden to submitit_slurm — confirm the intent.
# ===========================================================================
defaults:
  - override hydra/launcher: submitit_slurm

hydra:
  # Single run: the trailing job number mirrors the multirun layout below,
  # so both modes produce runs/<env>/<seed>/<date>/<time>/<job_num>.
  run:
    dir: runs/${env_name}/${seed}/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra:job.num}
  # Multirun: one parent directory per launch, one subdirectory per job.
  sweep:
    dir: runs/${env_name}/${seed}/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra:job.num}

# Launch timestamp and the directory logs are written to (the run dir above).
date: ${now:%Y-%m-%d}/${now:%H-%M-%S}
log_dir: ${hydra:run.dir}
workdir: .
logger_verbose: true
enable_tqdm: true

# ===========================================================================
# Main config
# ===========================================================================
seed: 0
env_name: antmaze-extreme-play-v0
# Stitching mode — one of: aug, no_aug
stitching: aug

# Environment spec: `classname` names the factory; the remaining keys are
# presumably passed to it as keyword arguments — confirm in the instantiator.
env:
  classname: gym.make
  id: ${env_name}

# Evaluation settings; `env` reuses the top-level env spec via interpolation.
evaluation:
  env: ${env}
  n_episodes: 50
  reward_variable: reward
  max_db_size: 32
  parallel: false
  serial: false
  bot_args:
    eval: true
    stochastic: false
    add_when_out: false
    replan_when_out: false
    replan_every_n_steps: false  # either false or an int
    add_when_out_goal: false
    replan_when_out_goal: false
    track_token_history: false
    plot_infos: true
    n_plan_steps: null
    argmax: false
    pred_type: sample  # one of: sample, k_sample, beam_search
    k_samples: 10  # int
    beam_width: 3  # int
    choice: best_score  # one of: best_score, best_length

# Read episodes through the D4RL episodes reader so that no temporary folder
# has to be created.
episodes_reader:
  classname: sgcrl.data.episodes_readers.d4rl.d4rl_episodes_readers.D4RLEpisodesReader
  env_name: ${env_name}

# Disabled alternative: the on-disk PyTorch episodes DB.
# episodes_reader:
#   classname: sgcrl.data.dbs.on_disk.PytorchOnDiskEpisodesDB
#   env_name: ${env_name}

device: cuda
num_workers: 7

# Stage toggles (presumably select which components are trained this run).
train_quantizer: false
train_transformer: false
train_high_value: true
train_goal_policy: false
train_subgoal_policy: true

# --- Quantizer training parameters ---
batch_size_quantizer: 16384
max_epoch_quantizer: 1000
save_every_quantizer: 250
# Coefficients (presumably loss weights). Written with a mantissa dot and a
# signed exponent so they resolve to floats under YAML 1.1 loaders as well.
contrastive_coef: 2.0e+1
commit_coef: 1.0e+3
reconstruction_coef: 1.0e+5
offset: 0
norm: [90, 50]
noise: 2.0
quantizer_window: 100
random_negative: true
keys_to_tokenize:
  - obs/pos

# --- Transformer training parameters ---
batch_size_transformer: 64
augment_transformer_data_prob: 0.99
remove_cycles: true
# NOTE(review): 0.95 is large for a *validation* fraction — confirm whether
# the consumer treats this value as the train or the validation share.
validation_split: 0.95
save_every_transformer: 250
max_epoch_transformer: 250

# --- Low-level policy training parameters ---
policy: qgcbc
low_level_lr: 5.0e-6
low_level_lr_value: 5.0e-6
max_gradient_step: 1.0e+6
save_every_low_level_policy: 31250
batch_size_low_level_policy: 1024
expectile: 0.9
beta: 3.0
discount: 0.995
# HIQL hyper-parameters; several entries interpolate the top-level keys so
# each value is defined in exactly one place.
hiql_parameters:
  reward_scale: 1.0
  discount: ${discount}
  expectile: ${expectile}
  beta: ${beta}
  clip_score: 100.0
  v_update_period: 1
  policy_update_period: 1
  target_update_period: 1
  polyak_coef: 5.0e-3
  keys_to_tokenize: ${keys_to_tokenize}
  dual_policy_relabellers: ${dual_policy_relabellers}
# No generic relabellers configured (a bare `relabellers:` also reads as null).
# NOTE(review): confirm a list was not meant to be nested under this key.
relabellers: null
# Relabellers for training batches (presumably applied in list order).
batch_relabellers:
  - classname: sgcrl.data.torch_datasets.relabel.ObsRelabeller
    obs_keys:
      - obs/pos
      - obs/other
    partial_obs_keys:
      - obs/pos
      - obs/other
  - classname: sgcrl.data.torch_datasets.relabel.ValueBatchRelabeller
    probabilities: [0.2, 0.8, 0.0]
    use_obs_representation: true
  - classname: sgcrl.data.torch_datasets.relabel.ObsGoalRelabeller
# Forwarded into hiql_parameters.dual_policy_relabellers via interpolation.
dual_policy_relabellers:
  - classname: sgcrl.data.torch_datasets.relabel.PosOtherRelabeller
  - classname: sgcrl.data.torch_datasets.relabel.ObsRelabeller
    obs_keys:
      - obs/pos
      - obs/other
    partial_obs_keys:
      - obs/pos
      - obs/other

# ===========================================================================
# Models
# ===========================================================================

# --- Quantizer ---
quantizer:
  classname: sgcrl.models.quantizer.quantizer
  dim: 8
  hidden_dim: 16
  codebook_size: 96

optimizer_quantizer:
  classname: torch.optim.Adam
  lr: 3.0e-4

# --- Transformer ---
transformer:
  classname: sgcrl.models.transformer.Transformer
  max_sequence_length: 128
  embedding_dim: 128
  dim: 128
  num_layers: 4
  nhead: 4
  dropout: 0.2

optimizer_transformer:
  classname: torch.optim.Adam
  lr: 1.0e-5

# --- Subgoal policy (HIQL overlay: two value nets + low/high Gaussian heads) ---
low_level_policy_subgoals:
  classname: sgcrl.models.hiql.HIQL_overlay
  vf1:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  vf2:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  policy_low:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    # Log-std bounds of the Gaussian head: min must not exceed max.
    # NOTE(review): confirm how MLPGaussianPolicy applies them (clamp vs.
    # tanh rescale); with a clamp, min > max pins log-std to the max value.
    min_log_std: -5
    max_log_std: 2
  policy_high:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    # Same log-std bounds as policy_low (min <= max).
    min_log_std: -5
    max_log_std: 2

# Optimizers for the subgoal HIQL overlay: value nets use low_level_lr_value,
# policy heads use low_level_lr.
optimizer_low_level_policy_subgoals:
  classname: sgcrl.models.hiql.HIQL_optimizer_overlay
  parameters:
    optimizer_vf1:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_vf2:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_policy_low:
      classname: torch.optim.Adam
      lr: ${low_level_lr}
    optimizer_policy_high:
      classname: torch.optim.Adam
      lr: ${low_level_lr}
  
# --- Goal policy (HIQL overlay: two value nets + low/high Gaussian heads) ---
low_level_policy_goals:
  classname: sgcrl.models.hiql.HIQL_overlay
  vf1:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  vf2:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  policy_low:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    # Log-std bounds of the Gaussian head: min must not exceed max.
    # NOTE(review): confirm how MLPGaussianPolicy applies them (clamp vs.
    # tanh rescale); with a clamp, min > max pins log-std to the max value.
    min_log_std: -5
    max_log_std: 2
  policy_high:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    # Same log-std bounds as policy_low (min <= max).
    min_log_std: -5
    max_log_std: 2

# Optimizers for the goal HIQL overlay: value nets use low_level_lr_value,
# policy heads use low_level_lr.
optimizer_low_level_policy_goals:
  classname: sgcrl.models.hiql.HIQL_optimizer_overlay
  parameters:
    optimizer_vf1:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_vf2:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_policy_low:
      classname: torch.optim.Adam
      lr: ${low_level_lr}
    optimizer_policy_high:
      classname: torch.optim.Adam
      lr: ${low_level_lr}
