# @package _global_

# Algorithm-side overrides: a transformer-based actor and an MLP critic that
# both receive the env-provided `mimic_target_poses` observation as an extra
# input. Several values use the `eval` resolver, which is not an OmegaConf
# built-in and therefore has to be registered by the project.
algo:
  config:
    models:
      # Hidden layer sizes shared by the per-pose encoder below and by the
      # actor's `obs_mlp` (both interpolate this key).
      transformer_obs_units:
        - 256
        - 256

      # Encoder for the future target poses; plugged into the actor through
      # `actor.config.mu_model.config.extra_inputs`.
      transformer_mimic_target_pose_model:
        _target_: phys_anim.agents.models.mlp.MLP_WithNorm
        config:
          initializer: ${algo.config.actor.config.initializer}
          units: ${algo.config.models.transformer_obs_units}
          activation: ${algo.config.actor.config.activation}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          # Per-pose feature width; matches the last dim of the first reshape.
          encoder_input_dim: ${env.config.mimic_target_pose.num_obs_per_target_pose}
          # Listed in order: reshape to (-1, obs_per_pose), encode, then
          # reshape to (batch_size, num_future_steps, num_out).
          # NOTE(review): presumably this encodes every future pose on its own
          # and yields one token per pose; `batch_size` looks like a symbolic
          # placeholder resolved by the model - confirm against MLP_WithNorm.
          operations:
            - type: reshape
              new_shape:
                - -1
                - ${env.config.mimic_target_pose.num_obs_per_target_pose}
            - type: encode
            - type: reshape
              new_shape:
                - batch_size
                - ${env.config.mimic_target_pose.num_future_steps}
                - ${algo.config.models.transformer_mimic_target_pose_model.num_out}
        # Flat input width: num_future_steps * num_obs_per_target_pose.
        num_in: ${eval:${env.config.mimic_target_pose.num_future_steps}*${env.config.mimic_target_pose.num_obs_per_target_pose}}
        # Output width follows the actor's transformer token size.
        num_out: ${algo.config.actor.config.mu_model.config.transformer_token_size}

      # Flatten-based variant of the target-pose input, used by the critic.
      mlp_mimic_target_pose_model:
        _target_: phys_anim.agents.models.common.Flatten
        config:
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
        # Same flat width as the transformer variant's `num_in`.
        num_in: ${eval:${env.config.mimic_target_pose.num_future_steps}*${env.config.mimic_target_pose.num_obs_per_target_pose}}
        # Output width equals input width (relative interpolation to sibling).
        num_out: ${.num_in}

    actor:
      config:
        mu_model:
          _target_: phys_anim.agents.models.transformer.TransformerWithNorm
          # Hydra does not recursively instantiate nested `_target_` nodes
          # when this is false; nested model configs are handed over as-is.
          _recursive_: false
          config:
            units:
              - 1024
              - 1024
              - 1024
            activation: ${algo.config.actor.config.activation}
            # Token size tracks `latent_dim` defined just below.
            transformer_token_size: ${.latent_dim}
            latent_dim: 512
            ff_size: 1024
            type_embedding_dim: 0
            num_layers: 4
            num_heads: 4
            dropout: 0
            num_obs_per_target_pose: ${env.config.mimic_target_pose.num_obs_per_target_pose}
            num_future_steps: ${env.config.mimic_target_pose.num_future_steps}
            # `terrain_models` is not defined in this file; it must come from
            # another config composed alongside this one.
            terrain_model: ${algo.config.models.terrain_models.transformer}
            # MLP settings for the observation input.
            obs_mlp:
              initializer: ${algo.config.actor.config.initializer}
              units: ${algo.config.models.transformer_obs_units}
              activation: ${algo.config.actor.config.activation}
              use_layer_norm: ${algo.config.actor.config.use_layer_norm}
              normalize_obs: ${algo.config.normalize_obs}
              obs_clamp_value: ${algo.config.obs_clamp_value}
            # Additional inputs, keyed by observation name.
            extra_inputs:
              mimic_target_poses: ${algo.config.models.transformer_mimic_target_pose_model}
            # MLP settings for the output head; observation normalization and
            # clamping are switched off here.
            output_mlp:
              initializer: ${algo.config.actor.config.initializer}
              units:
                - 1024
                - 1024
                - 1024
              activation: ${algo.config.actor.config.activation}
              use_layer_norm: ${algo.config.actor.config.use_layer_norm}
              normalize_obs: false
              obs_clamp_value: null

    critic:
      config:
        units:
          - 1024
          - 1024
          - 1024
          - 1024
        # The critic takes the flattened target poses, not the encoded ones.
        extra_inputs:
          mimic_target_poses: ${algo.config.models.mlp_mimic_target_pose_model}

    # Declares the extra observation: fetched from the env, with its size
    # tied to the flat target-pose width computed above.
    extra_inputs:
      mimic_target_poses:
        retrieve_from_env: true
        dtype: float
        size: ${algo.config.models.mlp_mimic_target_pose_model.num_in}

# Environment-side overrides for the target-pose observation.
env:
  config:
    max_episode_length: 300

    mimic_target_pose:
      enabled: true
      type: max-coords-future-rel
      with_time: true
      # One more than `base_num_obs_per_target_pose`, which is not defined in
      # this file and must be supplied by another composed config.
      # NOTE(review): the +1 presumably accounts for the time feature enabled
      # by `with_time` - confirm against the env implementation.
      num_obs_per_target_pose: ${eval:${.base_num_obs_per_target_pose}+1}
      num_future_steps: 15
