# @package _global_

algo:
  config:
    models:
      # Model entry for the historical-pose mask input. Instantiates `Flatten`
      # with a bool dtype and with observation normalization and clamping
      # turned off.
      historical_poses_mask_model:
        _target_: phys_anim.agents.models.common.Flatten
        config:
          normalize_obs: False
          obs_clamp_value: null
          dtype: bool
        # Input width equals the number of conditioned historical steps.
        num_in: ${env.config.masked_mimic_obs.num_historical_conditioned_steps}
        # Relative interpolation: resolves to the sibling `num_in`, so the
        # output width equals the input width.
        num_out: ${.num_in}

      # Model entry that encodes the historical poses for the transformer.
      # Instantiates `MLP_WithNorm`; the flat input is reshaped to one row per
      # pose, encoded, then reshaped to (batch, steps, num_out).
      historical_poses_for_transformer_model:
        _target_: phys_anim.agents.models.mlp.MLP_WithNorm
        config:
          # Initializer, activation and layer-norm settings are taken from the
          # actor config; layer sizes come from `prior_preprocessor_units`.
          initializer: ${algo.config.actor.config.initializer}
          units: ${algo.config.models.prior_preprocessor_units}
          activation: ${algo.config.actor.config.activation}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          # Observation normalization and clamping follow the algo-level settings.
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          # The encoder sees a single pose at a time (see the first reshape below).
          encoder_input_dim: ${env.config.masked_mimic_obs.historical_obs_size_per_pose}
          # Name of the companion mask input; it matches the
          # `historical_pose_obs_mask` entry declared under `extra_inputs`.
          mask_key: historical_pose_obs_mask
          # NOTE(review): the exact meaning of this flag is defined by
          # MLP_WithNorm, which is not visible here - confirm before changing.
          mask_valid_as_zeros: False
          # Applied in list order.
          operations:
            - type: reshape
              new_shape:
                - -1  # batch * num historical steps
                - ${env.config.masked_mimic_obs.historical_obs_size_per_pose}
            - type: encode
            - type: reshape
              new_shape:
                # `batch_size` is a literal string here, not an interpolation;
                # presumably replaced with the runtime batch size by the
                # model - TODO confirm.
                - batch_size
                - ${env.config.masked_mimic_obs.num_historical_conditioned_steps}
                - ${algo.config.models.historical_poses_for_transformer_model.num_out}
        # Flat input width: conditioned steps * per-pose observation size.
        num_in: ${eval:${env.config.masked_mimic_obs.num_historical_conditioned_steps}*${env.config.masked_mimic_obs.historical_obs_size_per_pose}}
        # Per-step output width equals the transformer token size of the actor's mu_model.
        num_out: ${algo.config.actor.config.mu_model.config.transformer_token_size}

      # Adds the two historical-pose models to the prior pre-processor's extra
      # inputs. Each value interpolates the whole model config node under
      # `algo.config.models`; the keys match the input names declared under
      # `algo.config.extra_inputs`.
      prior_pre_processor:
        config:
          extra_inputs:
            historical_pose_obs: ${algo.config.models.historical_poses_for_transformer_model}
            historical_pose_obs_mask: ${algo.config.models.historical_poses_mask_model}

    # Declares the two extra inputs used for historical-pose conditioning.
    # Both are flagged `retrieve_from_env` (presumably fetched from the
    # environment by the agent - confirm against the consumer).
    extra_inputs:
      # Flattened historical poses: per-pose observation size * conditioned steps.
      historical_pose_obs:
        retrieve_from_env: True
        dtype: float
        size: ${eval:${env.config.masked_mimic_obs.historical_obs_size_per_pose}*${env.config.masked_mimic_obs.num_historical_conditioned_steps}}
      # Boolean mask with one entry per conditioned historical step.
      historical_pose_obs_mask:
        retrieve_from_env: True
        dtype: bool
        size: ${env.config.masked_mimic_obs.num_historical_conditioned_steps}

# Environment-side settings for the historical-pose observation.
# NOTE(review): `historical_obs_size_per_pose` is referenced by the algo
# section but is not set in this file as shown; it must be supplied by
# another config.
env:
  config:
    masked_mimic_obs:
      num_historical_stored_steps: 120  # Number of past steps kept in the stored history
      num_historical_conditioned_steps: 15  # Number of steps subsampled from the stored history to condition on
