# NOTE(review): presumably the index of the local CUDA device to use — confirm
# against the code that reads it.
local_cuda_rank: 0
# The relative paths in this file are written against the Hydra run directory,
# e.g. "....../mtm/outputs/2024-03-05/13-19-37" (hydra.job.chdir is enabled),
# so "../../../" climbs back up to the "mtm" directory.
pretrain_model_path: ../../../outputs/omtm_zero/walker2d-medium-v2.pt
pretrain_args:
  # Dataset names listed by the author:
  # hopper-medium-v2, walker2d-medium-v2, halfcheetah-medium-replay-v2
  env_name: walker2d-medium-v2
  traj_length: 8
goal_mask: "piid"  # one of: [piid, id]
waypoints_path: ../../../research/zeroshot_omtm/waypoint_gen/walker-splits.txt

# Arguments for the dataset factory named by `_target_`.
pretrain_dataset:
  _target_: research.omtm.datasets.d4rl_ds.get_datasets
  # `???` is OmegaConf's mandatory-value marker: a real value has to be
  # supplied (by an override or by code) before this field is read.
  seq_steps: ???
  # Dataset names listed by the author: hopper-medium-v2, walker2d-medium-v2.
  # NOTE(review): duplicates pretrain_args.env_name; presumably the two must
  # match — confirm.
  env_name: walker2d-medium-v2
  use_reward: true
  seed: 0
  # NOTE(review): same value as finetune_args.pretrain_discount. It is above 1,
  # which is unusual if this is a reward discount factor — confirm it is
  # intended.
  discount: 1.5
  train_val_split: 0.95

# One tokenizer entry per field (states, actions, returns, rewards). All four
# point at the same `ContinuousTokenizer.create` target.
tokenizers:
  states:
    _target_: research.omtm.tokenizers.continuous.ContinuousTokenizer.create
  actions:
    _target_: research.omtm.tokenizers.continuous.ContinuousTokenizer.create
  returns:
    _target_: research.omtm.tokenizers.continuous.ContinuousTokenizer.create
  rewards:
    _target_: research.omtm.tokenizers.continuous.ContinuousTokenizer.create

# Arguments for the `omtmConfig` target.
model_config:
  _target_: research.omtm.models.mtm_model.omtmConfig
  # Quoted on purpose: this is the string "none", not a YAML null.
  norm: "none"
  n_embd: 512
  n_enc_layer: 2
  n_dec_layer: 1
  n_head: 4
  dropout: 0.1
  # The next two are real YAML nulls.
  # NOTE(review): how the consumer treats a null here is not visible in this
  # file — confirm.
  loss_keys: null
  latent_dim: null
  init_temperature: 0.1
  target_entropy: -3

# Arguments for the `RunConfig` target of the fine-tuning run.
finetune_args:
  _target_: research.finetune_omtm.finetune.RunConfig
  discount: 0.99
  seed: 0
  traj_batch_size: 512
  trans_batch_size: 256
  using_online_threshold: 5000
  warmup_steps: 1000000  # 1000000
  # very large number, process actually controlled by explore_steps
  num_train_steps: 10000000
  explore_steps: 1000000
  # Exponent floats carry an explicit decimal point so that YAML 1.1 loaders
  # (e.g. plain PyYAML) read them as floats instead of strings.
  learning_rate: 1.0e-4
  critic_lr: 3.0e-4
  v_lr: 3.0e-4
  weight_decay: 5.0e-3
  device: cuda
  mtm_iter_per_rollout: 200
  v_iter_per_mtm: 10

  # train
  print_every: 1
  log_every: 100
  eval_every: 1
  save_every: 40000

  # debug
  # Alternative values; swap them in for the "train" values above rather than
  # keeping both, since duplicate keys are invalid YAML.
  # print_every: 1
  # log_every: 1
  # eval_every: 1
  # save_every: 1

  # plan
  plan: true
  rtg_percent: 1.0  # [0.8 1.0]
  # one of: [critic_guiding, critic_lambda_guiding, noise_adding]
  plan_guidance: "critic_guiding"
  lmbda: 0.6
  expectile: 0.7
  horizon: 4
  action_samples: 625
  temperature: 1.0

  index_jump: 40

  # replay buffer
  traj_buffer_size: 1000
  trans_buffer_size: 1000000  # 200000
  buffer_init_ratio: 1.0
  # NOTE(review): same value as pretrain_dataset.discount; presumably the two
  # must stay in sync — confirm.
  pretrain_discount: 1.5
  clip_min: -1.0
  clip_max: 1.0
  tau: 1.0e-3
  select_mode: "prob"

  # mask
  mask_ratio: [0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.95, 1.0]
  # NOTE(review): four weights summing to 1.0 — presumably sampling
  # probabilities; confirm what the four slots correspond to.
  p_weights: [0.2, 0.1, 0.6, 0.1]

  critic_hidden_size: 256

# NOTE(review): presumably Weights & Biases logging settings — confirm against
# the code that reads this stanza.
wandb:
  project: omtm_finetune
  entity: f69
  # `resume` and `group` are explicitly null (unset) in this config.
  resume: null
  experiment_name: finetune
  group: null

# Hydra job settings.
hydra:
  job:
    name: finetune
    # Run inside the per-run output directory. The relative "../../../" paths
    # elsewhere in this file are written against that directory.
    chdir: true
