# Hydra defaults list: compose the `base` config, then this file.
# `_self_` is listed last so that, per Hydra's defaults-list ordering,
# values defined in this file override values coming from `base`.
defaults:
  - base
  - _self_

# Policy built via Hydra `_target_` instantiation. The `${task.*}`,
# `${algo.*}` and `${stage}` interpolations refer to keys defined elsewhere
# in the composed config, not in this file.
policy:
  _target_: algos.MSP_DP.MeanFlowScalePolicy

  # Action autoencoder (algos.vae.vae.ActionVAE).
  autoencoder:
    _target_: algos.vae.vae.ActionVAE
    action_dim: ${task.shape_meta.action.shape[0]}
    encoder_dim: 128
    decoder_dim: 128
    skill_block_size: ${algo.skill_block_size}
    downsample_factor: ${algo.downsample_factor}
    attn_pdrop: 0.1
    use_causal_encoder: true
    use_causal_decoder: true
    encoder_heads: 2
    encoder_layers: 2
    decoder_heads: 4
    decoder_layers: 4
    latent_dim: ${algo.latent_action_dim}
    # Written with a mantissa dot: plain PyYAML (YAML 1.1) resolves `1e-6`
    # as a string, whereas `1.0e-6` is a float under every loader.
    kl_weight: 1.0e-6

  # Flow model (algos.flow.flow_ar.FlowAR).
  flowar:
    _target_: algos.flow.flow_ar.FlowAR
    encoder_embed_dim: 256
    encoder_depth: 6
    encoder_num_heads: 4
    decoder_embed_dim: 256
    decoder_depth: 6
    decoder_num_heads: 4
    mlp_ratio: 4.0
    norm_layer:
      _target_: torch.nn.LayerNorm
      eps: 1.0e-6
      # Tied to flowar.encoder_embed_dim via OmegaConf relative interpolation
      # (`..` is the parent of `norm_layer`) instead of repeating 256.
      # NOTE(review): only one width is configured here; confirm how FlowAR
      # applies norm_layer (instance vs. factory needing `_partial_: true`)
      # and whether it must match decoder_embed_dim if the two ever differ.
      normalized_shape: ${..encoder_embed_dim}
    action_dim: ${algo.latent_action_dim}
    attn_dropout: 0.1
    proj_dropout: 0.1
    scale: [ 1, 2, 4, 8 ]
    # obs_embed + (agent_pos dim * n_obs_steps). `eval` is a custom resolver,
    # presumably registered in code outside this file - confirm.
    obs_in_features: ${eval:'${algo.obs_embed}+${task.shape_meta.obs.agent_pos.shape[0]}*${task.n_obs_steps}'}

  # Multi-camera image encoder with a per-camera (unshared) ResNet-18.
  image_encoder:
    _target_: diffusion_policy.model.vision.multi_image_obs_encoder.MultiImageObsEncoder
    shape_meta: ${task.shape_meta}
    rgb_model:
      _target_: diffusion_policy.model.vision.model_getter.get_resnet
      name: resnet18
      weights: null
    resize_shape: null
    crop_shape: [ 76, 76 ]
    random_crop: true
    use_group_norm: true
    share_rgb_model: false
    imagenet_norm: true

  latent_action_dim: ${algo.latent_action_dim}
  # skill_block_size // downsample_factor, computed by the custom `eval`
  # resolver.
  latent_action_chunk: ${eval:'${algo.skill_block_size} // ${algo.downsample_factor}'}
  stage: ${stage}
  action_dim: ${task.shape_meta.action.shape[0]}
  action_chunk: ${algo.skill_block_size}
  n_action_steps: ${task.n_action_steps}
  n_obs_steps: ${task.n_obs_steps}

# Config name; matches the `algos.MSP_DP` module prefix of policy._target_.
name: 'MSP_DP'