# Hydra runtime settings.
hydra:
  run:
    # Per-run output directory, built from the launch time plus the env, policy
    # and job names defined in this file.
    # NOTE(review): the timestamp is time-of-day only (%H-%M-%S, no date), so two
    # runs started at the same clock time on different days would resolve to the
    # same directory — confirm this is intended.
    dir: outputs/train/${now:%H-%M-%S}_${env.name}_${policy.name}_${hydra.job.name}
  job:
    name: default

# General run settings.
resume: false
device: cpu # alternative: cuda:1
# NOTE(review): presumably toggles automatic mixed precision — confirm against
# the training script.
use_amp: false
seed: 1000

# Dataset selection.
# NOTE(review): the FIXME below carries no explanation — presumably the repo id
# is a placeholder to be replaced for the target task; confirm and resolve.
dataset_repo_id: lerobot/aloha_sim_insertion_human # FIXME
dataset_root: ./data/
video_backend: pyav

# Weights & Biases logging options.
wandb:
  enable: false
  # Set to true to disable saving an artifact even when
  # `training.save_checkpoint` is true.
  disable_artifact: true
  project: aloha
  notes: ""
  log_video: false

# Frame rate shared by the environment (`env.fps`) and the
# `training.delta_timestamps` expressions.
fps: 50
# Environment settings.
env:
  name: aloha
  task: AlohaInsertion-v0  # FIXME
  # Referenced by `policy.input_shapes` / `policy.output_shapes`.
  state_dim: 14
  action_dim: 14
  # Interpolated from the top-level `fps`.
  fps: ${fps}
  episode_length: 400
  # NOTE(review): presumably forwarded as keyword arguments when the gym
  # environment is created — confirm against the env factory.
  gym:
    obs_type: pixels_agent_pos
    render_mode: rgb_array


# Training-loop settings.
training:
  offline_steps: 100000
  online_steps: 0
  eval_freq: 10000
  save_freq: 20000
  log_freq: 200

  lr_scheduler: default
  lr_warmup_steps: 500

  save_checkpoint: true
  num_workers: 4

  batch_size: 8
  lr: 1e-5
  lr_backbone: 1e-5
  weight_decay: 1e-4
  grad_clip_norm: 10
  online_steps_between_rollouts: 1

  # Each value is a Python list-comprehension string built from the `fps` and
  # `policy.chunk_size` interpolations; with the values in this file it
  # describes the offsets 0, 1/50, ..., 99/50.
  # NOTE(review): presumably evaluated by the dataset loading code into per-key
  # time offsets in seconds — confirm.
  delta_timestamps:
    action: "[i / ${fps} for i in range(${policy.chunk_size})]"
    left_pts_2d: "[i / ${fps} for i in range(${policy.chunk_size})]"
    right_pts_2d: "[i / ${fps} for i in range(${policy.chunk_size})]"

  online_sampling_ratio: 0.5
  # `???` is OmegaConf's marker for a mandatory value with no default: it must
  # be supplied (e.g. through a command-line override) before it is read.
  online_env_seed: ???

  image_transforms:
    enable: false
    # Maximum number of transforms (sampled from those listed below) that will
    # be applied to each frame.
    # It's an integer in the interval [1, number of available transforms].
    max_num_transforms: 3
    # By default, transforms are applied in Torchvision's suggested order
    # (shown below). Set this to true to apply them in a random order.
    random_order: false
    brightness:
      weight: 1
      min_max: [0.8, 1.2]
    contrast:
      weight: 1
      min_max: [0.8, 1.2]
    saturation:
      weight: 1
      min_max: [0.5, 1.5]
    hue:
      weight: 1
      min_max: [-0.05, 0.05]
    sharpness:
      weight: 1
      min_max: [0.8, 1.2]


# Evaluation settings.
# NOTE(review): batch_size equals n_episodes here, so presumably all evaluation
# episodes run in a single batch — confirm.
eval:
  n_episodes: 50
  batch_size: 50
  use_async_envs: false


# See `configuration_act.py` for more details.
# NOTE(review): `policy.name` is `autoregressive_policy`, not ACT — confirm that
# the file referenced above is still the right place to look.
policy:
  name: autoregressive_policy

  # Input / output structure.
  n_obs_steps: 1
  # Also referenced by the `training.delta_timestamps` expressions.
  chunk_size: 100
  n_action_steps: 100
  n_action_steps_eval: 100

  # The state / action sizes are interpolated from `env.state_dim` and
  # `env.action_dim`.
  input_shapes:
    observation.images.top: [3, 480, 640]
    observation.state: ["${env.state_dim}"]
  output_shapes:
    action: ["${env.action_dim}"]

  # Normalization / Unnormalization
  input_normalization_modes:
    observation.images.top: mean_std
    observation.state: mean_std
  output_normalization_modes:
    action: mean_std

  # Architecture.
  # Vision backbone.
  vision_backbone: resnet18
  pretrained_backbone_weights: ResNet18_Weights.IMAGENET1K_V1
  replace_final_stride_with_dilation: false
  # Transformer layers.
  pre_norm: false
  dim_model: 512
  n_heads: 8
  dim_feedforward: 3100
  feedforward_activation: relu
  n_encoder_layers: 4
  dropout: 0.1

  # Guide-point settings.
  num_guide_points: 10
  guide_pts_downsample: 1
  num_latents: 1
  guide_pts_corr_dim: 128
  guide_pts_heatmap_sigma: 1.5
  sample: false

  # NOTE(review): presumably the configuration of the autoregressive model
  # itself — confirm against the policy implementation.
  arp_cfg:
    n_embd: 512  # or 256, 128
    embd_pdrop: 0.1

    num_layers: 4
    layer_cfg:
      n_head: 8
      mlp_ratio: 4.0
      AdaLN: true
      mlp_dropout: 0.1
      attn_kwargs:
        attn_pdrop: 0.1
        resid_pdrop: 0.1
      cond_attn_kwargs:
        attn_pdrop: 0.1
        resid_pdrop: 0.1


