# @package _global_

algo:
  config:
    # Sub-models that consume the extra task observation (keyed `inversion_obs`
    # in this file) so it can be appended to the actor and critic inputs.
    models:
      # Size of the extra observation; defaults to the env's task obs size.
      # When /opt/inversion/pose_obs is used, the expected size is
      # task_obs + current_pose_size instead.
      extra_input_obs_size: ${env.config.task_obs_size}

      # Flatten model for the extra observation; normalization is switched off
      # and no clamp value is set.
      extra_input_model:
        _target_: phys_anim.agents.models.common.Flatten
        config:
          normalize_obs: false
          obs_clamp_value: null
        # Flattening does not change the feature count, so num_out == num_in.
        num_in: ${algo.config.models.extra_input_obs_size}
        num_out: ${algo.config.models.extra_input_obs_size}

      # MLP_WithNorm model that maps the extra observation to the transformer
      # token size. Initializer, activation and layer-norm settings are taken
      # from the actor config; hidden units from prior_preprocessor_units.
      extra_input_model_for_transformer:
        _target_: phys_anim.agents.models.mlp.MLP_WithNorm
        config:
          initializer: ${algo.config.actor.config.initializer}
          units: ${algo.config.models.prior_preprocessor_units}
          activation: ${algo.config.actor.config.activation}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          # Applied in the listed order: encode, then reshape to
          # (batch_size, 1, width) -- presumably one token per sample.
          operations:
            - type: encode
            - type: reshape
              # NOTE(review): the last dimension is read from
              # motion_text_embeddings_for_transformer_model.num_out rather
              # than from this model's own num_out; the two are presumably
              # equal -- confirm.
              new_shape:
                - batch_size
                - 1
                - ${algo.config.models.motion_text_embeddings_for_transformer_model.num_out}
          encoder_input_dim: ${algo.config.models.extra_input_obs_size}
        num_in: ${algo.config.models.extra_input_obs_size}
        num_out: ${algo.config.actor.config.mu_model.config.transformer_token_size}

      # Transformer inputs: add the model above as the extra input
      # `inversion_obs` of the prior pre-processor.
      prior_pre_processor:
        config:
          extra_inputs:
            inversion_obs: ${algo.config.models.extra_input_model_for_transformer}

    # Critic network: CriticMLP with three 1024-unit layers, relu activation
    # and no layer norm. Observation normalization and clamping follow the
    # algo-level settings.
    critic:
      _target_: phys_anim.agents.models.critic.CriticMLP
      # Hydra flag: nested `_target_` nodes are not instantiated automatically
      # when this node is instantiated.
      _recursive_: false
      config:
        initializer: default
        units: [1024, 1024, 1024]
        activation: relu
        normalize_obs: ${algo.config.normalize_obs}
        obs_clamp_value: ${algo.config.obs_clamp_value}
        use_layer_norm: false
        # The critic receives the extra observation through the Flatten model
        # defined under algo.config.models.
        extra_inputs:
          inversion_obs: ${algo.config.models.extra_input_model}

    # Declaration of the extra inputs. `inversion_obs` is flagged to be
    # retrieved from the env, has dtype float, and is sized to match the input
    # of algo.config.models.extra_input_model.
    extra_inputs:
      inversion_obs:
        retrieve_from_env: true
        dtype: float
        size: ${algo.config.models.extra_input_model.num_in}

# Environment overrides for this configuration.
env:
  config:
    state_init: Default
    enable_height_termination: true
    # Both reset couplings of the mimic track are switched off.
    mimic_reset_track:
      reset_on_episode_reset: false
      reset_episode_on_reset_track: false
    # No current-pose observation is configured here (type unset, size 0).
    current_pose_obs_type: null
    current_pose_obs_size: 0

# Weights & Biases tags for this configuration.
wandb:
  wandb_tags:
    - inversion