# @package _global_

# Default training configuration for the VQ-BeT policy on the PushT dataset.

# Seed value for the run (presumably used to seed the RNGs of the training script — confirm).
seed: 100000
# Dataset identifier; the `lerobot/pusht` form looks like a Hugging Face Hub repo id — confirm.
dataset_repo_id: lerobot/pusht

# Hard-coded normalization statistics. Judging by the key name these take precedence over
# statistics derived from the dataset; how they are applied is decided by the consuming code.
override_dataset_stats:
  # TODO(rcadene, alexander-soare): should we remove image stats as well? do we use a pretrained vision model?
  # Three entries (one per channel), each 0.5; pairs with the `mean_std` mode that
  # `policy.input_normalization_modes` declares for `observation.image`.
  observation.image:
    mean: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
    std: [[[0.5]], [[0.5]], [[0.5]]]  # (c,1,1)
  # TODO(rcadene, alexander-soare): we override state and action stats to use the same as the pretrained model
  # from the original codebase, but we should remove these and train our own pretrained model
  # Two-element min/max bounds; pairs with the `min_max` mode declared for `observation.state`.
  observation.state:
    min: [13.456424, 32.938293]
    max: [496.14618, 510.9579]
  # Two-element min/max bounds; pairs with the `min_max` mode declared for `action`.
  action:
    min: [12.0, 25.0]
    max: [511.0, 511.0]

# Training schedule, optimizer settings and temporal sampling windows.
training:
  # Step budget: 250k offline steps and no online steps.
  offline_steps: 250000
  online_steps: 0
  # Evaluation and checkpointing share the same interval (units presumably training steps — confirm).
  eval_freq: 25000
  save_freq: 25000
  save_checkpoint: true

  # Generic optimization settings.
  batch_size: 64
  grad_clip_norm: 10
  lr: 1.0e-4
  lr_scheduler: cosine
  lr_warmup_steps: 500
  adam_betas: [0.95, 0.999]
  adam_eps: 1.0e-8
  adam_weight_decay: 1.0e-6
  online_steps_between_rollouts: 1

  # VQ-BeT specific
  vqvae_lr: 1.0e-3
  # Also referenced from `policy.n_vqvae_training_steps` through interpolation.
  n_vqvae_training_steps: 20000
  # Floats are written with an explicit decimal point in the mantissa: YAML 1.1 loaders
  # (e.g. PyYAML) read a bare `2e-4` as a string rather than a float.
  bet_weight_decay: 2.0e-4
  bet_learning_rate: 5.5e-5
  bet_betas: [0.9, 0.999]

  # Each value is a string containing a Python list comprehension with `${...}` interpolations;
  # presumably the training code evaluates it after interpolation — confirm.
  # `fps` is not defined in this file and must be supplied by another config.
  delta_timestamps:
    # Observations: i in range(1 - n_obs_steps, 1), i.e. n_obs_steps indices ending at i = 0.
    observation.image: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    observation.state: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]"
    # Actions: same first index as the observations, extending up to (but excluding)
    # i = n_action_pred_token + action_chunk_size - 1.
    action: "[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, ${policy.n_action_pred_token} + ${policy.action_chunk_size} - 1)]"

# Evaluation settings.
# NOTE(review): batch_size equals n_episodes, so presumably all episodes are run in one batch — confirm.
eval:
  n_episodes: 50
  batch_size: 50

# Policy selection, input/output structure and normalization.
policy:
  name: vqbet

  # Input / output structure.
  # These three values are also interpolated into `training.delta_timestamps`.
  n_obs_steps: 5
  n_action_pred_token: 7
  action_chunk_size: 5

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    observation.image: [3, 96, 96]
    # State and action dimensions are interpolated from `env.*` keys defined outside this file.
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  # Matching statistics (mean/std for the image, min/max for state and action) are provided
  # under the top-level `override_dataset_stats` key.
  input_normalization_modes:
    observation.image: mean_std
    observation.state: min_max
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Booleans use the canonical lowercase `true` / `false` spelling and floats carry an
  # explicit fractional part, so every YAML loader infers the same types.
  # Vision backbone.
  vision_backbone: resnet18
  # Smaller than the 96x96 image size declared in `input_shapes`.
  crop_shape: [84, 84]
  crop_is_random: true
  # null: no pretrained backbone weights are specified.
  pretrained_backbone_weights: null
  use_group_norm: true
  spatial_softmax_num_keypoints: 32
  # VQ-VAE
  # Mirrors the value set in the `training` section through interpolation.
  n_vqvae_training_steps: ${training.n_vqvae_training_steps}
  vqvae_n_embed: 16
  vqvae_embedding_dim: 256
  vqvae_enc_hidden_dim: 128
  # VQ-BeT
  gpt_block_size: 500
  gpt_input_dim: 512
  gpt_output_dim: 512
  gpt_n_layer: 8
  gpt_n_head: 8
  gpt_hidden_dim: 512
  dropout: 0.1
  mlp_hidden_dim: 1024
  # Loss weights and sampling temperature.
  offset_loss_weight: 10000.0
  primary_code_loss_weight: 5.0
  secondary_code_loss_weight: 0.5
  bet_softmax_temperature: 0.1
  sequentially_select: false
