# @package _global_

# Default configuration for training a diffusion policy on the pusht_keypoints dataset.

# The keypoints are the vertices of the rectangles that make up the T-shaped block, as documented in the
# PushT environment:
# https://github.com/huggingface/gym-pusht/blob/5e2489be9ff99ed9cd47b6c653dda3b7aa844d24/gym_pusht/envs/pusht.py#L522-L534
# For completeness, the diagram is copied here:
#        0───────────1
#        │           │
#        3───4───5───2
#            │   │
#            │   │
#            │   │
#            │   │
#            7───6


# Note: The original work trains on keypoints only, conditioning via inpainting. Here, we encode the
# observation together with the agent position and use the encoding as global conditioning for the
# denoising U-Net.

# Note: We do not track EMA model weights as we discovered it does not improve the results. See
#       https://github.com/huggingface/lerobot/pull/134 for more details.

# Random seed and the dataset repository to train on.
seed: 100000
dataset_repo_id: 'lerobot/pusht_keypoints'

training:
  # Schedule: offline-only training (no online interaction steps).
  offline_steps: 200000
  online_steps: 0
  eval_freq: 5000
  save_freq: 5000
  log_freq: 250
  save_checkpoint: true

  # Optimization.
  batch_size: 64
  grad_clip_norm: 10
  lr: 1.0e-4
  lr_scheduler: cosine
  lr_warmup_steps: 500
  adam_betas: [0.95, 0.999]
  adam_eps: 1.0e-8
  adam_weight_decay: 1.0e-6
  online_steps_between_rollouts: 1

  # Relative timestamps (frame offset / fps) of the frames loaded for each sample:
  # observations cover the last `n_obs_steps` frames up to and including the current one;
  # actions cover `horizon` frames starting at the earliest observed frame.
  delta_timestamps:
    observation.environment_state: '[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]'
    observation.state: '[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1)]'
    action: '[i / ${fps} for i in range(1 - ${policy.n_obs_steps}, 1 - ${policy.n_obs_steps} + ${policy.horizon})]'

  # The original implementation does not sample frames from the last 7 steps of an episode,
  # which avoids excessive padding and improves training results.
  drop_n_last_frames: 7  # ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps} + 1

# Evaluation settings.
eval:
  batch_size: 50
  n_episodes: 50

policy:
  name: diffusion

  # Input / output structure.
  # Note: training.drop_n_last_frames is derived from these three values.
  n_obs_steps: 2
  horizon: 16
  n_action_steps: 8

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    # 16 = 8 keypoints (vertices 0-7 in the diagram at the top of this file) x 2 coordinates each
    # (presumably x, y; confirm against the dataset).
    observation.environment_state: [16]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.environment_state: min_max
    observation.state: min_max
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Vision backbone.
  # NOTE(review): input_shapes declares no image keys, so these vision-backbone settings are
  # presumably unused by this keypoints-only config. Confirm before removing them.
  vision_backbone: resnet18
  crop_shape: [84, 84]
  crop_is_random: true
  pretrained_backbone_weights: null
  use_group_norm: true
  spatial_softmax_num_keypoints: 32
  # Unet.
  down_dims: [256, 512, 1024]
  kernel_size: 5
  n_groups: 8
  diffusion_step_embed_dim: 128
  use_film_scale_modulation: true
  # Noise scheduler.
  noise_scheduler_type: DDIM
  num_train_timesteps: 100
  beta_schedule: squaredcos_cap_v2
  beta_start: 0.0001
  beta_end: 0.02
  prediction_type: epsilon  # epsilon / sample
  clip_sample: true
  clip_sample_range: 1.0

  # Inference
  num_inference_steps: 10  # if not provided, defaults to `num_train_timesteps`

  # Loss computation
  do_mask_loss_for_padding: false
