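######################################################
# For local launches                                 #
# (multirun sweeps with -m are submitted to SLURM    #
#  through the submitit launcher configured below)   #
######################################################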
# Hydra defaults list: replace the default launcher with the submitit SLURM
# launcher plugin; Hydra uses the launcher for multirun (-m) sweeps.
defaults:
  - override hydra/launcher: submitit_slurm

hydra:
  # Output directory of a single (non-multirun) run.
  run:
    dir: runs/${env_name}/${algorithm_name}/${seed}/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra:job.num}

  # Multirun (-m) output: each job writes to <sweep.dir>/<sweep.subdir>.
  sweep:
    dir: runs/${env_name}/${algorithm_name}/${seed}/${now:%Y-%m-%d}/${now:%H-%M-%S}
    subdir: ${hydra:job.num}

# Log into the directory Hydra actually uses for this job. Under the submitit
# multirun launcher, jobs write to sweep.dir/sweep.subdir, not run.dir, and the
# ${now:...} in run.dir may resolve to a different timestamp in the SLURM job
# process. runtime.output_dir covers both run and multirun (Hydra >= 1.2).
log_dir: ${hydra:runtime.output_dir}
logger_verbose: true
enable_tqdm: true

###############
# Main config #
###############
# Experiment identity: seed, env_name and algorithm_name are also the
# leading components of the hydra run/sweep output directories.
seed: 0
env_name: 'antmaze-large-diverse-v2'
algorithm_name: 'qphil'

# Environment spec; `id` tracks env_name.
env:
  classname: gym.make
  id: '${env_name}'

# Evaluation settings; reuses the `env` spec above.
evaluation:
  env: ${env}
  n_episodes: 50
  reward_variable: reward
  max_db_size: 32
  parallel: false
  serial: true
  # Arguments forwarded to the evaluated agent.
  bot_args:
    eval: true
    stochastic: false

# Use an episodes reader (rather than the on-disk episodes DB below) so that
# no temporary folder has to be created.
episodes_reader:
  classname: sgcrl.data.episodes_readers.d4rl.d4rl_episodes_readers.D4RLEpisodesReader
  env_name: '${env_name}'

# Alternative: on-disk episodes DB.
# episodes_reader:
#   classname: sgcrl.data.dbs.on_disk.PytorchOnDiskEpisodesDB
#   env_name: ${env_name}

device: cuda
num_workers: 7

# Which training stages to run.
train_quantizer: true
train_transformer: true
train_goal_policy: false
train_subgoal_policy: true

###### Quantizer parameters:
batch_size_quantizer: 16384
max_epoch_quantizer: 500
save_every_quantizer: 250
# Loss weights. Exponent floats are written with a '.' and a signed exponent
# so YAML 1.1 loaders (e.g. plain PyYAML) read them as floats, not strings.
contrastive_coef: 2.0e+1
commit_coef: 1.0e+3
reconstruction_coef: 1.0e+5
offset: 0
norm:
  - 40
  - 30
noise: 2.0
quantizer_window: 100
random_negative: true
# Observation keys fed to the quantizer (also reused by hiql_parameters).
keys_to_tokenize:
  - obs/pos

###### Transformer parameters:
save_every_transformer: 250
max_epoch_transformer: 250
batch_size_transformer: 64
augment_transformer_data_prob: 0.5
remove_cycles: true
stitching: true
# NOTE(review): 0.95 would be an unusually large validation fraction;
# confirm whether the consumer treats this as the train or validation share.
validation_split: 0.95

###### Policy parameters:
goal_model_name: 'iql_goals.pickle'
# Learning rates: low_level_lr feeds the policy optimizers,
# low_level_lr_value feeds the value-function optimizers.
# Written as 5.0e-6 so YAML 1.1 loaders also parse them as floats.
low_level_lr: 5.0e-6
low_level_lr_value: 5.0e-6
save_every_low_level_policy: 31250
# One million gradient steps (parsed as a float, as before).
max_gradient_step: 1.0e+6
batch_size_low_level_policy: 1024
expectile: 0.7
beta: 1.0
discount: 0.99
use_obs_representation: true
hiql_parameters:
  reward_scale: 1.0
  discount: ${discount}
  expectile: ${expectile}
  beta: ${beta}
  clip_score: 100.0
  v_update_period: 1
  policy_update_period: 1
  target_update_period: 1
  polyak_coef: 5.0e-3
  keys_to_tokenize: ${keys_to_tokenize}
  dual_policy_relabellers: ${dual_policy_relabellers}
# Explicit null (the original bare `relabellers:` key already loaded as null).
# NOTE(review): if batch_relabellers was meant to be nested under this key,
# the indentation below is wrong; confirm against the consumer.
relabellers: null
batch_relabellers:
  - classname: sgcrl.data.torch_datasets.relabel.ObsRelabeller
    obs_keys:
      - obs/pos
      - obs/other
    partial_obs_keys:
      - obs/pos
      - obs/other
  - classname: sgcrl.data.torch_datasets.relabel.ValueBatchRelabeller
    probabilities:
      - 0.2
      - 0.8
      - 0.0
    use_obs_representation: ${use_obs_representation}
  - classname: sgcrl.data.torch_datasets.relabel.ObsGoalRelabeller
# Referenced by hiql_parameters.dual_policy_relabellers.
dual_policy_relabellers:
  - classname: sgcrl.data.torch_datasets.relabel.PosOtherRelabeller
  - classname: sgcrl.data.torch_datasets.relabel.ObsRelabeller
    obs_keys:
      - obs/pos
      - obs/other
    partial_obs_keys:
      - obs/pos
      - obs/other

###### Models
#############
# quantizer #
#############
quantizer:
  classname: sgcrl.models.quantizer.quantizer
  dim: 8
  hidden_dim: 16
  codebook_size: 32

optimizer_quantizer:
  classname: torch.optim.Adam
  # 3.0e-4 rather than 3e-4 so YAML 1.1 loaders also read a float.
  lr: 3.0e-4

###############
# transformer #
###############
transformer:
  classname: sgcrl.models.transformer.Transformer
  max_sequence_length: 128
  embedding_dim: 128
  dim: 128
  num_layers: 4
  nhead: 4
  dropout: 0.2

optimizer_transformer:
  classname: torch.optim.Adam
  # 1.0e-5 rather than 1e-5 so YAML 1.1 loaders also read a float.
  lr: 1.0e-5

##################
# subgoal policy #
##################
low_level_policy_subgoals:
  classname: sgcrl.models.hiql.HIQL_overlay
  # Two value networks (vf1/vf2) with identical settings.
  vf1:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  vf2:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  # Log-std bounds were min=2 / max=-5 (min > max); ordered here as [-5, 2].
  # NOTE(review): confirm against MLPGaussianPolicy. If it clamps the log-std
  # (e.g. torch.clamp), min > max pins every log-std to max_log_std.
  policy_low:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2
  policy_high:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2

# One Adam optimizer per sub-network of low_level_policy_subgoals:
# value functions use low_level_lr_value, policies use low_level_lr.
optimizer_low_level_policy_subgoals:
  classname: sgcrl.models.hiql.HIQL_optimizer_overlay
  parameters:
    optimizer_vf1: {classname: torch.optim.Adam, lr: '${low_level_lr_value}'}
    optimizer_vf2: {classname: torch.optim.Adam, lr: '${low_level_lr_value}'}
    optimizer_policy_low: {classname: torch.optim.Adam, lr: '${low_level_lr}'}
    optimizer_policy_high: {classname: torch.optim.Adam, lr: '${low_level_lr}'}

###############
# goal policy #
###############
low_level_policy_goals:
  classname: sgcrl.models.hiql.HIQL_overlay
  # Two value networks (vf1/vf2) with identical settings.
  vf1:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  vf2:
    classname: sgcrl.models.hiql.GCMLPValue
    input_dim: 31
    embedding_dim: 0
    num_embedding: 0
    hidden_sizes: [512, 512, 512]
    layer_norm: true
  # Log-std bounds were min=2 / max=-5 (min > max); ordered here as [-5, 2].
  # NOTE(review): confirm against MLPGaussianPolicy. If it clamps the log-std
  # (e.g. torch.clamp), min > max pins every log-std to max_log_std.
  policy_low:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2
  policy_high:
    classname: sgcrl.models.hiql.MLPGaussianPolicy
    input_dim: 31
    embedding_dim: 0  # 32
    num_embedding: 0
    hidden_sizes: [256, 256]
    action_dim: 8
    min_log_std: -5
    max_log_std: 2

# One Adam optimizer per sub-network of low_level_policy_goals:
# value functions use low_level_lr_value, policies use low_level_lr.
optimizer_low_level_policy_goals:
  classname: sgcrl.models.hiql.HIQL_optimizer_overlay
  parameters:
    optimizer_vf1: {classname: torch.optim.Adam, lr: '${low_level_lr_value}'}
    optimizer_vf2: {classname: torch.optim.Adam, lr: '${low_level_lr_value}'}
    optimizer_policy_low: {classname: torch.optim.Adam, lr: '${low_level_lr}'}
    optimizer_policy_high: {classname: torch.optim.Adam, lr: '${low_level_lr}'}
