# @package _global_

algo:
  config:
    models:
      # Model entry for the text-embedding mask input. `_target_` points at
      # phys_anim.agents.models.common.Flatten, configured with one input and
      # one output element.
      motion_text_embeddings_mask_model:
        _target_: phys_anim.agents.models.common.Flatten
        config:
          # Observation normalization is disabled and no clamp value is set.
          normalize_obs: False
          obs_clamp_value: null
          # Same dtype as declared for `motion_text_embeddings_mask` under
          # `algo.config.extra_inputs` in this file.
          dtype: bool
        num_in: 1
        num_out: 1

      # Model entry for the text-embedding input. `_target_` points at
      # phys_anim.agents.models.mlp.MLP_WithNorm; the input width is the env's
      # text embedding dimension and the output width is the transformer token
      # size taken from the actor's `mu_model` config.
      motion_text_embeddings_for_transformer_model:
        _target_: phys_anim.agents.models.mlp.MLP_WithNorm
        config:
          # Initializer, activation and layer-norm setting are shared with the
          # actor config through interpolation.
          initializer: ${algo.config.actor.config.initializer}
          # Layer sizes come from `prior_preprocessor_units`, which is not
          # defined in this file.
          units: ${algo.config.models.prior_preprocessor_units}
          activation: ${algo.config.actor.config.activation}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          # Normalization and clamping follow the algorithm-wide settings.
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          encoder_input_dim: ${env.config.masked_mimic_obs.text_embedding_dim}
          # Names the mask input declared under `algo.config.extra_inputs`.
          mask_key: motion_text_embeddings_mask
          # NOTE(review): the polarity handling lives in the model code, which
          # is not visible here -- presumably False means a True mask value
          # marks a valid entry; confirm against MLP_WithNorm.
          mask_valid_as_zeros: False
          # Two operations, listed in order: `encode`, then a `reshape` to
          # (batch_size, 1, num_out).
          operations:
            - type: encode
            - type: reshape
              new_shape:
                - batch_size
                - 1
                # Resolves to this model's own `num_out` below.
                - ${algo.config.models.motion_text_embeddings_for_transformer_model.num_out}
        num_in: ${env.config.masked_mimic_obs.text_embedding_dim}
        num_out: ${algo.config.actor.config.mu_model.config.transformer_token_size}

      # Adds two entries to the `extra_inputs` of `prior_pre_processor`. Only
      # these keys are set here; the rest of `prior_pre_processor` is not
      # defined in this file (presumably provided by another config -- confirm).
      prior_pre_processor:
        config:
          extra_inputs:
            # Each value is an interpolation to a model node defined under
            # `algo.config.models` in this file.
            motion_text_embeddings: ${algo.config.models.motion_text_embeddings_for_transformer_model}
            motion_text_embeddings_mask: ${algo.config.models.motion_text_embeddings_mask_model}

    # Declares the two extra inputs used by the models above. Both are flagged
    # `retrieve_from_env: True`.
    extra_inputs:
      # Float input whose size is the env's text embedding dimension.
      motion_text_embeddings:
        retrieve_from_env: True
        dtype: float
        size: ${env.config.masked_mimic_obs.text_embedding_dim}
      # Boolean input of size 1; referenced as `mask_key` by
      # `motion_text_embeddings_for_transformer_model`.
      motion_text_embeddings_mask:
        retrieve_from_env: True
        dtype: bool
        size: 1

# Environment-side masking setting for the text embeddings.
env:
  config:
    masked_mimic_masking:
      # NOTE(review): presumably the probability that the text embedding is
      # left visible (unmasked) for a sample -- confirm against the env code.
      motion_text_embeddings_visible_prob: 0.5