# @package _global_

algo:
  config:
    models:
      # Shared layer widths / output sizes used by the model entries below.
      # Relative interpolation: same list as the sibling prior_preprocessor_units.
      transformer_obs_units: ${.prior_preprocessor_units}
      encoder_preprocessor_units: [1024, 1024, 1024]
      encoder_preprocessor_output_dim: 1024
      # Follows the latent_dim configured on prior_pre_processor.
      prior_preprocessor_output_dim: ${algo.config.models.prior_pre_processor.config.latent_dim}
      prior_output_units: [256, 128]
      encoder_output_units: [512]
      # Used in this file as `units` for masked_mimic_target_pose_model and
      # for prior_pre_processor.config.obs_mlp.
      prior_preprocessor_units: [256, 256]

      # Model entry for the `mimic_target_poses` input. In this file it is
      # referenced by encoder_pre_processor.config.extra_inputs and
      # critic.config.extra_inputs, and its num_in sizes the matching entry in
      # algo.config.extra_inputs.
      mimic_target_pose_model:
        _target_: phys_anim.agents.models.common.Flatten
        config:
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
        # Input width: num_future_steps * num_obs_per_target_pose
        # (num_obs_per_target_pose is not defined in this file).
        num_in: ${eval:${env.config.mimic_target_pose.num_future_steps}*${env.config.mimic_target_pose.num_obs_per_target_pose}}
        # Output width is set equal to the input width.
        num_out: ${.num_in}

      # Model entry for the boolean `masked_mimic_target_bodies_masks` input.
      # Normalization and clamping are switched off for this entry.
      masked_mimic_target_bodies_mask_model:
        _target_: phys_anim.agents.models.common.Flatten
        config:
          normalize_obs: false
          obs_clamp_value: null
          dtype: bool
        # Input width: (len(masked_mimic_conditionable_bodies) + 1) * 2 * num_future_steps.
        # NOTE(review): `eval` and `len` are resolvers registered outside this file.
        num_in: ${eval:${eval:${len:${env.config.masked_mimic_conditionable_bodies}}+1}*2*${env.config.mimic_target_pose.num_future_steps}}
        # Output width is set equal to the input width.
        num_out: ${.num_in}

      # Model entry for the boolean `masked_mimic_target_poses_masks` input.
      # Normalization and clamping are switched off for this entry.
      masked_mimic_target_poses_mask_model:
        _target_: phys_anim.agents.models.common.Flatten
        config:
          normalize_obs: false
          obs_clamp_value: null
          dtype: bool
        # One mask value per step; the step count is defined under
        # env.config.masked_mimic_obs in this file (num_future_steps + 1).
        num_in: ${env.config.masked_mimic_obs.masked_mimic_target_poses_num_steps}
        # Output width is set equal to the input width.
        num_out: ${.num_in}

      # MLP entry for the `masked_mimic_target_poses` input. Initializer,
      # activation and layer-norm settings are inherited from the actor config;
      # hidden widths come from prior_preprocessor_units.
      masked_mimic_target_pose_model:
        _target_: phys_anim.agents.models.mlp.MLP_WithNorm
        config:
          initializer: ${algo.config.actor.config.initializer}
          units: ${algo.config.models.prior_preprocessor_units}
          activation: ${algo.config.actor.config.activation}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          # Width of a single sparse target pose.
          encoder_input_dim: ${env.config.masked_mimic_obs.num_obs_per_sparse_target_pose}
          # Names the mask input declared in algo.config.extra_inputs.
          mask_key: masked_mimic_target_poses_masks
          mask_valid_as_zeros: false
          # NOTE(review): hard-coded value; its meaning is not visible in this file.
          obs_per_body_controller: 12
          # Ordered list: reshape to rows of one pose each, encode, then
          # reshape to (batch_size, num_steps, num_out).
          operations:
            - type: reshape
              new_shape:
                - -1
                - ${env.config.masked_mimic_obs.num_obs_per_sparse_target_pose}  # per-pose width, same value as encoder_input_dim
            - type: encode
            - type: reshape
              new_shape:
                - batch_size
                - ${env.config.masked_mimic_obs.masked_mimic_target_poses_num_steps}
                - ${algo.config.models.masked_mimic_target_pose_model.num_out}  # encoded obs per pose
        # Flat input width: num_steps * num_obs_per_sparse_target_pose.
        num_in: ${eval:${env.config.masked_mimic_obs.masked_mimic_target_poses_num_steps}*${env.config.masked_mimic_obs.num_obs_per_sparse_target_pose}}
        # Output width equals the transformer token size set under actor.config.mu_model.
        num_out: ${algo.config.actor.config.mu_model.config.transformer_token_size}

      # Sets two extra inputs on encoder_pre_processor, each pointing at a model
      # entry defined above. No `_target_` is set here.
      # NOTE(review): presumably the rest of this entry comes from another
      # config file — confirm.
      encoder_pre_processor:
        config:
          extra_inputs:
            mimic_target_poses: ${algo.config.models.mimic_target_pose_model}
            masked_mimic_target_bodies_masks: ${algo.config.models.masked_mimic_target_bodies_mask_model}

      # Transformer entry used as the prior pre-processor. Its latent_dim is
      # reused in this file for prior_preprocessor_output_dim, for this entry's
      # own num_out, and for actor.config.mu_model.config.transformer_token_size.
      prior_pre_processor:
        _target_: phys_anim.agents.models.transformer.TransformerWithNorm
        # Hydra instantiate flag: nested `_target_` configs (e.g. extra_inputs)
        # are passed through as config rather than instantiated recursively.
        _recursive_: false
        config:
          activation: ${algo.config.actor.config.activation}
          latent_dim: 512
          ff_size: 1024
          type_embedding_dim: 0
          num_layers: 4
          num_heads: 4
          dropout: 0
          num_obs_per_target_pose: ${env.config.masked_mimic_obs.num_obs_per_sparse_target_pose}
          num_future_steps: ${env.config.mimic_target_pose.num_future_steps}
          output_decoder: false
          # terrain_models is not defined in this file.
          terrain_model: ${algo.config.models.terrain_models.transformer}
          # MLP settings for observations; widths come from prior_preprocessor_units.
          obs_mlp:
            initializer: ${algo.config.actor.config.initializer}
            units: ${algo.config.models.prior_preprocessor_units}
            activation: ${algo.config.actor.config.activation}
            use_layer_norm: ${algo.config.actor.config.use_layer_norm}
            normalize_obs: ${algo.config.normalize_obs}
            obs_clamp_value: ${algo.config.obs_clamp_value}
          # Sparse target poses and their masks, using the model entries defined above.
          extra_inputs:
            masked_mimic_target_poses: ${algo.config.models.masked_mimic_target_pose_model}
            masked_mimic_target_poses_masks: ${algo.config.models.masked_mimic_target_poses_mask_model}
          # Output MLP settings; normalization and clamping are switched off here.
          output_mlp:
            initializer: ${algo.config.actor.config.initializer}
            # NOTE(review): a single hidden unit is unusually small — confirm this is intended.
            units: [1]
            activation: ${algo.config.actor.config.activation}
            use_layer_norm: ${algo.config.actor.config.use_layer_norm}
            normalize_obs: false
            obs_clamp_value: null
            # vae_latent_dim is not defined in this file.
            num_out: ${algo.config.actor.config.vae_latent_dim}
        # Output width equals latent_dim.
        num_out: ${algo.config.models.prior_pre_processor.config.latent_dim}

    # Actor settings set in this file: only the mu_model token size and layer widths.
    actor:
      config:
        mu_model:
          config:
            # Tied to prior_pre_processor's latent_dim; also read as num_out by
            # masked_mimic_target_pose_model.
            transformer_token_size: ${algo.config.models.prior_pre_processor.config.latent_dim}
            units: [1024, 1024, 1024]

    # Critic settings set in this file: layer widths and two extra inputs
    # (the same two entries as encoder_pre_processor.config.extra_inputs).
    critic:
      config:
        # NOTE(review): a single unit is unusually small — confirm this is intended.
        units: [1]
        extra_inputs:
          mimic_target_poses: ${algo.config.models.mimic_target_pose_model}
          masked_mimic_target_bodies_masks: ${algo.config.models.masked_mimic_target_bodies_mask_model}

    # Declares the extra inputs by name. Every entry is flagged
    # retrieve_from_env, and its size is the num_in of the matching model
    # entry under algo.config.models.
    extra_inputs:
      mimic_target_poses:
        retrieve_from_env: true
        dtype: float
        size: ${algo.config.models.mimic_target_pose_model.num_in}
      masked_mimic_target_poses:
        retrieve_from_env: true
        dtype: float
        size: ${algo.config.models.masked_mimic_target_pose_model.num_in}
      # The two mask inputs are declared with dtype bool.
      masked_mimic_target_poses_masks:
        retrieve_from_env: true
        dtype: bool
        size: ${algo.config.models.masked_mimic_target_poses_mask_model.num_in}
      masked_mimic_target_bodies_masks:
        retrieve_from_env: true
        dtype: bool
        size: ${algo.config.models.masked_mimic_target_bodies_mask_model.num_in}

# Environment settings that the algo section above reads through interpolation.
env:
  config:
    mimic_target_pose:
      enabled: true
      num_future_steps: 10
      target_pose_type: max-coords
      with_time: false
    masked_mimic_obs:
      # With n = len(masked_mimic_conditionable_bodies) + 1, the expression is
      #   num_obs_per_target_pose * n // robot.num_bodies + n * 2 + 2
      # `.num_obs_per_target_pose` resolves inside masked_mimic_obs and
      # `..masked_mimic_conditionable_bodies` inside env.config; neither key,
      # nor robot.num_bodies, is defined in this file.
      # NOTE(review): assumes the `eval` resolver applies Python operator
      # precedence, i.e. `*` and `//` left to right before `+` — confirm.
      num_obs_per_sparse_target_pose: ${eval:${.num_obs_per_target_pose}*${eval:${len:${..masked_mimic_conditionable_bodies}}+1}//${robot.num_bodies}+${eval:${len:${..masked_mimic_conditionable_bodies}}+1}*2+2}
      # num_future_steps + 1.
      masked_mimic_target_poses_num_steps: ${eval:${env.config.mimic_target_pose.num_future_steps}+1}
