# @package _global_

# Top-level run settings.
# NOTE(review): presumably the global RNG seed for the run — confirm against the consumer.
seed: 1
# Dataset to train on. Looks like an `<owner>/<name>` repository id — confirm which hub
# or loader resolves it.
dataset_repo_id: lerobot/xarm_lift_medium

# Training-loop settings.
# `${fps}` is not defined in this file; presumably it is supplied by another config that
# is composed with this one — confirm.
training:
  offline_steps: 25000
  # TODO(alexander-soare): restore `online_steps` to 25000 once online training is
  # reinstated; it is pinned to 0 until then.
  online_steps: 0  # intended value: 25000 (online training not implemented yet)
  eval_freq: 5000
  # NOTE(review): the online_* keys below presumably only take effect when
  # `online_steps` > 0 — confirm against the training script.
  online_steps_between_rollouts: 1
  online_sampling_ratio: 0.5
  online_env_seed: 10000
  log_freq: 100

  batch_size: 256
  grad_clip_norm: 10.0
  # Written with an explicit decimal point: YAML 1.1 loaders (e.g. PyYAML) resolve a bare
  # `3e-4` as a string, whereas `3.0e-4` is resolved as a float by every loader.
  lr: 3.0e-4

  # Per-key timestamp offsets, expressed as list-comprehension strings over the
  # `${fps}` and `${policy.horizon}` interpolations. Observations span `horizon + 1`
  # steps; actions and rewards span `horizon` steps.
  # NOTE(review): these strings are presumably evaluated by the dataset loader — confirm.
  delta_timestamps:
    observation.image: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]"
    action: "[i / ${fps} for i in range(${policy.horizon})]"
    next.reward: "[i / ${fps} for i in range(${policy.horizon})]"

# Policy settings for the `tdmpc` policy.
# `${env.state_dim}` and `${env.action_dim}` are interpolations into an `env` config
# that is not defined in this file.
policy:
  name: tdmpc

  # Explicit null instead of a bare empty value, so the "no path set" intent is
  # unambiguous.
  pretrained_model_path: null

  # Input / output structure.
  n_action_repeats: 2
  # Also referenced by `training.delta_timestamps` via `${policy.horizon}`.
  horizon: 5

  input_shapes:
    # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env?
    # NOTE(review): presumably (channels, height, width) — confirm.
    observation.image: [3, 84, 84]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  # Inputs have no normalization modes configured; only the action output uses min_max.
  input_normalization_modes: null
  output_normalization_modes:
    action: min_max

  # Architecture / modeling.
  # Neural networks.
  image_encoder_hidden_dim: 32
  state_encoder_hidden_dim: 256
  latent_dim: 50
  q_ensemble_size: 5
  mlp_dim: 512
  # Reinforcement learning.
  discount: 0.9

  # Inference.
  use_mpc: true
  cem_iterations: 6
  max_std: 2.0
  min_std: 0.05
  n_gaussian_samples: 512
  n_pi_samples: 51
  uncertainty_regularizer_coeff: 1.0
  n_elites: 50
  elite_weighting_temperature: 0.5
  gaussian_mean_momentum: 0.1

  # Training and loss computation.
  # NOTE(review): 0.0476 is approximately 4 / 84, i.e. presumably a 4-pixel shift on the
  # 84x84 image configured above — confirm.
  max_random_shift_ratio: 0.0476
  # Loss coefficients.
  reward_coeff: 0.5
  expectile_weight: 0.9
  value_coeff: 0.1
  consistency_coeff: 20.0
  advantage_scaling: 3.0
  pi_coeff: 0.5
  temporal_decay_coeff: 0.5
  # Target model.
  target_model_momentum: 0.995
