# Hydra defaults list. Only `_self_` is listed, so this file composes no other
# config groups of its own.
defaults:
  - _self_

# Policy node. `_target_` names the class Hydra instantiates; every argument is
# an interpolation of an `algo.*` key whose default is one of the flat keys
# defined further down in this file (presumably this file is mounted at the
# `algo` package — confirm against the config tree).
policy:
  _target_: ript.algos.rl_optimizers.openvla_oft_interface.OpenVLA_OFT_Policy
  # Argument name differs from the source key: fed from `algo.checkpoint_path`.
  pretrained_checkpoint: ${algo.checkpoint_path}
  header_checkpoint: ${algo.header_checkpoint}
  task_suite_name: ${algo.task_suite_name}
  lora_rank: ${algo.lora_rank}
  lora_dropout: ${algo.lora_dropout}
  lora_adaptor_ckpt: ${algo.lora_adaptor_ckpt}
  # Argument name differs from the source key: fed from `algo.model_seed`.
  seed: ${algo.model_seed}
  fix_scale_head: ${algo.fix_scale_head}
  log_scale_clip: ${algo.log_scale_clip}

# Partial constructor for torch.optim.AdamW. `_partial_: true` makes Hydra
# return a partially-applied callable with the keyword arguments below
# pre-bound; the remaining arguments are supplied by whoever calls it.
# NOTE(review): `header_lr` is defined in this file but is not wired into this
# factory — confirm where it is consumed.
optimizer_factory:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: ${algo.lr}
  # Hard-coded here (not exposed as an `algo.*` key).
  betas: [0.9, 0.999]
  weight_decay: ${algo.weight_decay}

# Partial constructor for torch.optim.lr_scheduler.CosineAnnealingLR.
# `_partial_: true` leaves the remaining arguments to be supplied by the caller.
scheduler_factory:
  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
  _partial_: true
  # Written with an explicit mantissa dot (1.0e-5, not 1e-5) so that YAML 1.1
  # loaders such as plain PyYAML also read it as a float rather than a string.
  eta_min: 1.0e-5
  last_epoch: -1
  # Comes from the training config, not from an `algo.*` key.
  T_max: ${training.n_epochs}

# Environment runner node. There is no `_partial_` key here, unlike the
# `*_factory` nodes in this file.
env_runner:
  _target_: ript.env_runner.openvla_oft_libero_runner.OpenVLAOFTLiberoRunner
  # Taken from the `task` config.
  benchmark_name: ${task.benchmark_name}
  rollouts_per_env: ${algo.rollouts_per_env}
  num_parallel_envs: ${algo.num_parallel_envs}
  max_episode_length: ${algo.max_episode_length}
  # Taken from the `task` config.
  task_names_to_use: ${task.task_names_to_use}
  use_laplace_sampling: ${algo.use_laplace_sampling}
  scale_factor: ${algo.scale_factor}

# Partial constructor for the rollout generator (`_partial_: true`): the
# arguments below are pre-bound and the rest are supplied by the caller.
rollout_generator_factory:
  _target_: ript.algos.rl_optimizers.rollout_generator.RolloutGenerator
  _partial_: true
  rloo_batch_size: ${algo.rloo_batch_size}
  # Taken from the train dataloader config rather than from an `algo.*` key.
  demo_batch_size: ${train_dataloader.batch_size}
  # Hard-coded here (not exposed as an `algo.*` key).
  use_tqdm: true
  early_stop_percentage: ${algo.early_stop_percentage}
  enable_dynamic_sampling: ${algo.enable_dynamic_sampling}
  use_val_init: ${algo.use_val_init}
  mix_val_init_in_rloo: ${algo.mix_val_init_in_rloo}
  rollout_stats_path: ${algo.rollout_stats_path}

# Partial constructor for the RL optimizer (`_partial_: true`): the arguments
# below are pre-bound and the rest are supplied by the caller.
rl_optimizer_factory:
  _target_: ript.algos.rl_optimizers.rl_optimizer_openvla_oft.RLOptimizerOpenVLAOFT
  _partial_: true
  gradient_accumulation_steps: ${algo.gradient_accumulation_steps}
  ppo_clip_range: ${algo.ppo_clip_range}
  ppo_clip_high: ${algo.ppo_clip_high}
  num_ppo_epochs: ${algo.num_ppo_epochs}
  # The only hard-coded argument value in this node; every other argument is an
  # `algo.*` interpolation.
  # NOTE(review): presumably tied to the model's action tokenization — confirm
  # before changing.
  tokens_per_step: 8
  max_step_batch_size: ${algo.max_step_batch_size}
  rloo_over_all_rollouts: ${algo.rloo_over_all_rollouts}
  log_prob_mode: ${algo.log_prob_mode}
  grad_norm_clip_model: ${algo.grad_norm_clip_model}
  grad_norm_clip_header: ${algo.grad_norm_clip_header}


# Values forwarded from other parts of the composed config.
shape_meta: ${task.shape_meta}
# Quoted so the value is unambiguously the string 'none'.
obs_reduction: 'none'
# `${device}` is an absolute interpolation, i.e. it resolves against the root
# of the composed config (assumes this file is mounted under a sub-package such
# as `algo`, so this is not a self-reference — TODO confirm).
device: ${device}

name: openvla

# openvla parameters
# All keys in this group are read by the `policy` node in this file through
# `${algo.*}` interpolations.
lora_rank: 32
lora_dropout: 0.0
# The three checkpoint/adaptor paths default to null and are expected to be
# overridden when needed (presumably from the command line or an experiment
# config — confirm).
checkpoint_path: null
header_checkpoint: null
lora_adaptor_ckpt: null
# Follows the benchmark selected in the `task` config.
task_suite_name: ${task.benchmark_name}


# training parameters
# Floats in exponent form are written with an explicit mantissa dot (1.0e-4,
# not 1e-4) so that YAML 1.1 loaders such as plain PyYAML also read them as
# floats rather than strings.
# `lr` and `weight_decay` feed `optimizer_factory` in this file.
lr: 1.0e-4
# NOTE(review): `header_lr` is not referenced by any interpolation in the
# visible part of this file — confirm where it is consumed.
header_lr: 5.0e-5
weight_decay: 0.0001
# The two gradient-norm clip values feed `rl_optimizer_factory` in this file.
grad_norm_clip_model: 1.0
grad_norm_clip_header: 1.0
# env runner parameters
# All keys in this group are read by the `env_runner` node in this file through
# `${algo.*}` interpolations.
num_parallel_envs: 4
rollouts_per_env: 16
# null by default; NOTE(review): presumably the runner falls back to its own
# limit when this is null — confirm in the runner implementation.
max_episode_length: null
use_laplace_sampling: true
scale_factor: 1.0
# rollout generator parameters
# Unless noted otherwise, the keys in this group are read by the
# `rollout_generator_factory` node in this file through `${algo.*}`.
# number of rollouts per env
# NOTE(review): this is a separate key from `rollouts_per_env` (env runner
# parameters) — confirm how the two relate.
rloo_batch_size: 4
early_stop_percentage: 1.0
enable_dynamic_sampling: false
use_val_init: false
mix_val_init_in_rloo: false
# Despite sitting under this heading, `log_scale_clip` is read by the `policy`
# node in this file, not by the rollout generator.
log_scale_clip: null

# ppo optimizer parameters
# All keys in this group are read by the `rl_optimizer_factory` node in this
# file through `${algo.*}` interpolations.
gradient_accumulation_steps: 1
num_ppo_epochs: 1
# NOTE(review): `ppo_clip_range` and `ppo_clip_high` are passed as two separate
# arguments and currently share the same value — confirm how the optimizer
# combines them before changing only one.
ppo_clip_range: 0.2
ppo_clip_high: 0.2
max_step_batch_size: 2
rloo_over_all_rollouts: false
log_prob_mode: 'sum_on_action_dim'
# eval only
# NOTE(review): `eval_only` is not referenced by any interpolation in the
# visible part of this file — presumably read directly by the training script.
eval_only: false

# Read by the `rollout_generator_factory` node in this file.
rollout_stats_path: null

# `model_seed` (as the `seed` argument) and `fix_scale_head` are read by the
# `policy` node in this file.
model_seed: 7
fix_scale_head: false
# NOTE(review): not referenced by any interpolation in the visible part of this
# file; `dataset.frame_stack` below is set independently to the same value —
# confirm whether the two are meant to stay in sync.
frame_stack: 1

# NOTE(review): not referenced by any interpolation in the visible part of this
# file — confirm where it is consumed.
rollout_training_task_names: null

# Dataset options. All values here are literals; nothing in the visible part of
# this file interpolates them, so they are presumably read by the dataset
# construction code directly — confirm.
dataset:
  seq_len: 600
  # Set independently of the top-level `frame_stack` key in this file (both
  # are 1).
  frame_stack: 1
  obs_seq_len: 1
  lowdim_obs_seq_len: null  # this denotes future timestep lowdim observations
  # NOTE(review): the remark below mentions an autoencoder training stage,
  # which nothing else in this file refers to — possibly carried over from
  # another config; confirm it still applies.
  load_obs_for_pretrain: false  # since autoencoder training stage does not require obs
  load_next_obs: true
  get_pad_mask: true
  pad_seq_length: false
  load_state: true