######################
# For local launches #
######################
defaults:
  - override hydra/launcher: submitit_slurm
  # Explicit `_self_` pins Hydra >= 1.1's default composition order (this
  # file's own values are merged after the entries above) and avoids the
  # warning Hydra emits when a primary config's defaults list omits it.
  - _self_

hydra:
  # Single run: this template is the same path as sweep.dir/subdir below,
  # so a plain run and a multirun job produce the same directory layout.
  run:
    dir: runs/${env_name}/${seed}/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra:job.num}

  # Multirun (e.g. via the submitit_slurm launcher): one subdir per job.
  sweep:
    dir: runs/${env_name}/${seed}/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra:job.num}

# Run timestamp and output directory.
# NOTE(review): on Hydra >= 1.2, ${hydra:runtime.output_dir} resolves to the
# directory actually used in both single-run and multirun modes; consider it
# instead of relying on run.dir mirroring the sweep layout.
date: ${now:%Y-%m-%d}/${now:%H-%M-%S}
log_dir: ${hydra:run.dir}
logger_verbose: true
enable_tqdm: true

###############
# Main config #
###############
seed: 0
env_name: antmaze-ultra-diverse-v0

# Environment spec: `classname` names the factory, remaining keys are its
# arguments (presumably passed as kwargs, i.e. gym.make(id=...) — confirm
# against the config loader).
env:
  classname: gym.make
  id: ${env_name}

evaluation:
  env: ${env}  # same environment spec as the top-level `env`
  n_episodes: 50
  reward_variable: reward
  max_db_size: 32
  parallel: true
  bot_args:
    eval: true
    stochastic: false
    add_when_out: true
    replan_when_out: false
    replan_every_n_steps: false  # false or int
    add_when_out_goal: true
    replan_when_out_goal: false
    track_token_history: false
    plot_infos: true

# Read episodes directly with an episodes reader (instead of the on-disk DB
# below) to avoid having to create a temporary folder.
episodes_reader:
  classname: sgcrl.data.episodes_readers.d4rl.d4rl_episodes_readers.D4RLEpisodesReader
  env_name: ${env_name}  # presumably selects the D4RL dataset — confirm

# episodes_reader:
#   classname: sgcrl.data.dbs.on_disk.PytorchOnDiskEpisodesDB
#   env_name: ${env_name}

device: cuda
num_workers: 7

# Per-stage training toggles.
train_quantizer: true
train_transformer: true
train_subgoal_policy: true
train_goal_policy: true

model_goals: iql_goals.pickle

###### Quantizer parameters:
batch_size_quantizer: 16384
max_epoch_quantizer: 500
# Loss-term coefficients (read as floats by OmegaConf's YAML loader).
contrastive_coef: 2e1
commit_coef: 1e3
reconstruction_coef: 1e5
save_every_quantizer: 250
offset: 0
# NOTE(review): two entries, presumably one scale per obs/pos dimension.
norm:
  - 55
  - 40
noise: 2.0
quantizer_window: 100
random_negative: true
keys_to_tokenize:
  - obs/pos

###### Transformer parameters:
batch_size_transformer: 64
augment_transformer_data_prob: 0.5
remove_cycles: true
stitching: true
validation_split: 0.95  # NOTE(review): fraction kept for training? confirm
save_every_transformer: 200
max_epoch_transformer: 2000

###### Policy parameters:
low_level_lr: 5e-6         # policy optimizers
low_level_lr_value: 5e-6   # value-function optimizers
save_every_low_level_policy: 31250
batch_size_low_level_policy: 1024
# NOTE(review): `1e6` loads as a float (1000000.0); cast if used as an int.
max_gradient_step: 1e6
expectile: 0.9
beta: 3.0
discount: 0.995
# HIQL hyperparameters; scalars interpolate the top-level keys above so each
# value is defined in one place.
hiql_parameters:
  reward_scale: 1.0
  discount: ${discount}
  expectile: ${expectile}
  beta: ${beta}
  clip_score: 100.0
  v_update_period: 1
  policy_update_period: 1
  target_update_period: 1
  polyak_coef: 5e-3
  keys_to_tokenize: ${keys_to_tokenize}
  dual_policy_relabellers: ${dual_policy_relabellers}
# A bare `relabellers:` is an implicit null; spell it out so the intent is
# explicit (same parsed value).
relabellers: null
batch_relabellers:
  - classname: sgcrl.relabel.ObsRelabeller
    obs_keys:
      - obs/pos
      - obs/other
    partial_obs_keys:
      - obs/pos
      - obs/other
  - classname: sgcrl.relabel.ValueBatchRelabeller
    probabilities:
      - 0.2
      - 0.8
      - 0.0
    use_obs_representation: true
  - classname: sgcrl.relabel.ObsGoalRelabeller
dual_policy_relabellers:
  - classname: sgcrl.relabel.PosOtherRelabeller
  - classname: sgcrl.relabel.ObsRelabeller
    obs_keys:
      - obs/pos
      - obs/other
    partial_obs_keys:
      - obs/pos
      - obs/other

###### Models
#############
# quantizer #
#############
# Quantizer model; `classname` names the class, remaining keys configure it.
quantizer:
  classname: sgcrl.models.quantizer.quantizer
  dim: 8
  hidden_dim: 16
  codebook_size: 48

optimizer_quantizer:
  classname: torch.optim.Adam
  lr: 3e-4  # read as a float by OmegaConf's YAML loader

###############
# transformer #
###############
# Transformer model and its optimizer. Indented with 2 spaces like the rest
# of the file (previously 4; the parsed values are identical).
transformer:
  classname: sgcrl.models.transformer.Transformer
  max_sequence_length: 128
  embedding_dim: 256
  dim: 128
  num_layers: 2
  nhead: 4
  dropout: 0.2

optimizer_transformer:
  classname: torch.optim.Adam
  lr: 1e-4

##################
# subgoal policy #
##################
# HIQL overlay for the subgoal policy: two value networks (vf1/vf2) and a
# low-/high-level Gaussian policy pair, plus one optimizer per component.
# NOTE: min_log_std/max_log_std were previously 2/-5, i.e. the lower bound
# sat above the upper bound. They are now ordered (-5 <= log_std <= 2).
# If earlier results depended on the old values, check how
# MLPGaussianPolicy applied the inverted bounds before reverting.
low_level_policy_subgoals:
  classname: sgcrl.models.hiql.HIQL_overlay
  vf1:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  vf2:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  policy_low:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2
  policy_high:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2

optimizer_low_level_policy_subgoals:
  classname: sgcrl.models.hiql.HIQL_optimizer_overlay
  parameters:
    # Value networks use low_level_lr_value; policies use low_level_lr.
    optimizer_vf1:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_vf2:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_policy_low:
      classname: torch.optim.Adam
      lr: ${low_level_lr}
    optimizer_policy_high:
      classname: torch.optim.Adam
      lr: ${low_level_lr}

###############
# goal policy #
###############
# HIQL overlay for the goal policy: two value networks (vf1/vf2) and a
# low-/high-level Gaussian policy pair, plus one optimizer per component.
# NOTE: min_log_std/max_log_std were previously 2/-5, i.e. the lower bound
# sat above the upper bound. They are now ordered (-5 <= log_std <= 2).
# If earlier results depended on the old values, check how
# MLPGaussianPolicy applied the inverted bounds before reverting.
low_level_policy_goals:
  classname: sgcrl.models.hiql.HIQL_overlay
  vf1:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  vf2:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  policy_low:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2
  policy_high:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2

optimizer_low_level_policy_goals:
  classname: sgcrl.models.hiql.HIQL_optimizer_overlay
  parameters:
    # Value networks use low_level_lr_value; policies use low_level_lr.
    optimizer_vf1:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_vf2:
      classname: torch.optim.Adam
      lr: ${low_level_lr_value}
    optimizer_policy_low:
      classname: torch.optim.Adam
      lr: ${low_level_lr}
    optimizer_policy_high:
      classname: torch.optim.Adam
      lr: ${low_level_lr}
