# Hydra defaults list: compose on top of the `base` config, then apply this
# file's own keys (`_self_` is listed last, so values here override `base`).
defaults: [base, _self_]

# Policy config, built with Hydra `instantiate` from the `_target_` class.
# Interpolations reference `task.*`, `algo.*` and the top-level `stage` key,
# none of which are defined in this file; they must come from other config
# groups, `base`, or the command line.
policy:
  _target_: algos.MSP.MeanFlowScalePolicy

  # Action autoencoder. `skill_block_size`, `downsample_factor` and
  # `latent_dim` come from the `algo` group; `latent_dim` shares
  # `algo.latent_action_dim` with `flowar.action_dim` below.
  autoencoder:
    _target_: algos.vae.vae.ActionVAE
    action_dim: ${task.shape_meta.action_dim}
    encoder_dim: 128
    decoder_dim: 128
    skill_block_size: ${algo.skill_block_size}
    downsample_factor: ${algo.downsample_factor}
    attn_pdrop: 0.1
    use_causal_encoder: true
    use_causal_decoder: true
    encoder_heads: 2
    encoder_layers: 2
    decoder_heads: 4
    decoder_layers: 4
    latent_dim: ${algo.latent_action_dim}
    # Written as `1.0e-6` rather than `1e-6`: YAML 1.1 loaders such as
    # PyYAML's safe_load only resolve exponent floats containing a `.`,
    # and would otherwise yield the string "1e-6".
    kl_weight: 1.0e-6

  # Flow model. Its `action_dim` is the autoencoder's latent width
  # (`algo.latent_action_dim`), not the raw task action dim.
  flowar:
    _target_: algos.flow.flow_ar.FlowAR
    encoder_embed_dim: 256
    encoder_depth: 6
    encoder_num_heads: 4
    decoder_embed_dim: 256
    decoder_depth: 6
    decoder_num_heads: 4
    mlp_ratio: 4.0
    # NOTE(review): Hydra instantiates nested `_target_` nodes recursively by
    # default, so FlowAR presumably receives a ready LayerNorm instance
    # here rather than a class/factory — confirm that is what it expects.
    norm_layer:
      _target_: torch.nn.LayerNorm
      eps: 1.0e-6
      # NOTE(review): hard-coded to equal encoder/decoder_embed_dim (256);
      # presumably must be kept in sync if those change — confirm in FlowAR.
      normalized_shape: 256
    action_dim: ${algo.latent_action_dim}
    attn_dropout: 0.1
    proj_dropout: 0.1
    # NOTE(review): interpretation of these values is defined by FlowAR.
    scale: [1, 2, 4, 8]
    # NOTE(review): magic number; presumably the width of the observation
    # features fed to FlowAR — confirm against the observation encoder.
    obs_in_features: 1033

  # NOTE(review): `latent_info_file` is a relative path; it resolves against
  # the process working directory, which Hydra may change per run depending
  # on version/`hydra.job.chdir` — confirm it is found at runtime.
  image_encoder:
    _target_: algos.MidPlanner.DnceLatentProj
    latent_info_file: "assets/libero.pkl"

  # Remaining keys are passed to MeanFlowScalePolicy as keyword arguments.
  latent_action_dim: ${algo.latent_action_dim}
  # skill_block_size // downsample_factor. `eval` is not a built-in OmegaConf
  # resolver; it must be registered (register_new_resolver) before this
  # config is resolved — presumably done elsewhere in the project.
  latent_action_chunk: ${eval:'${algo.skill_block_size} // ${algo.downsample_factor}'}
  stage: ${stage}
  action_dim: ${task.shape_meta.action_dim}
  action_chunk: ${algo.skill_block_size}

# Short name of this algorithm config (matches the `algos.MSP` policy module).
name: 'MSP'