# @package _global_

# Hydra defaults list: composes the `mimic` config with this file. The keys
# below extend/override it (exact merge order follows Hydra's `_self_` rules).
defaults:
  - mimic

# Agent definition. `_target_` is the dotted path of the class Hydra
# instantiates for this node.
algo:
  _target_: phys_anim.agents.mimic_vae.MimicVAE
  # Settings for the agent. Keys referenced by interpolation but not defined in
  # this file (e.g. `normalize_obs`, `obs_clamp_value`, `actor.config.initializer`)
  # presumably come from the `mimic` defaults entry - confirm there.
  config:
    # Mimic-VAE parameters
    # Sub-network configs. The six scalar/list keys directly below are shared
    # hyperparameters that the model configs in this mapping pull in through
    # `${algo.config.models.*}` interpolations.
    models:
      # `units` list and output width used by `encoder_pre_processor`
      # (presumably hidden-layer widths - confirm against the MLP implementation).
      encoder_preprocessor_units: [1024, 1024, 1024]
      encoder_preprocessor_output_dim: 1024
      # Output width and `units` list used by `prior_pre_processor`.
      prior_preprocessor_output_dim: 1024
      prior_preprocessor_units: [1024, 1024, 1024]
      # `units` for the prior's mu/logvar heads; also reused by `vae_latent_model`.
      prior_output_units: [256, 128]
      # `units` for the encoder's mu/logvar heads.
      encoder_output_units: [512]

      # Trunk config for the VAE encoder; referenced as `trunk` by
      # `vae_encoder_model`.
      encoder_pre_processor:
        _target_: phys_anim.agents.models.mlp.MultiHeadedMLP
        # Hydra instantiate flag: nested `_target_` nodes are not instantiated
        # recursively; they are passed through as configs.
        _recursive_: False
        config:
          # Initializer, activation and layer-norm settings mirror the actor's
          # config; observation normalization/clamping mirror the algo-level
          # settings. None of these source keys are defined in this file.
          initializer: ${algo.config.actor.config.initializer}
          units: ${algo.config.models.encoder_preprocessor_units}
          activation: ${algo.config.actor.config.activation}
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          # Additional input named `terrain`, configured by
          # `models.terrain_models.mlp` (not defined in this file; presumably
          # provided by the `mimic` defaults - confirm).
          extra_inputs:
            terrain: ${algo.config.models.terrain_models.mlp}

        # Output width; the encoder heads use the same value as their `num_in`.
        num_out: ${algo.config.models.encoder_preprocessor_output_dim}

      # VAE encoder config: a shared trunk plus two output heads (`mu` and
      # `logvar`), each producing `vae_latent_dim` values. Referenced by
      # `actor.config.vae_encoder`.
      vae_encoder_model:
        _target_: phys_anim.agents.models.mlp.MultiOutputNetwork
        # Hydra instantiate flag: nested `_target_` nodes are not instantiated
        # recursively; they are passed through as configs.
        _recursive_: False
        config:
          # Normalization/clamping are switched off at this level; the trunk
          # config carries its own `normalize_obs` / `obs_clamp_value` values.
          normalize_obs: False
          obs_clamp_value: null
          # Reuses the `encoder_pre_processor` config from this mapping.
          trunk: ${algo.config.models.encoder_pre_processor}

          outputs:
            # `mu` head (presumably the mean of the latent distribution - confirm).
            mu:
              _target_: phys_anim.agents.models.mlp.MLP_WithNorm
              config:
                initializer: ${algo.config.actor.config.initializer}
                units: ${algo.config.models.encoder_output_units}
                activation: ${algo.config.actor.config.activation}
                normalize_obs: False
                obs_clamp_value: null
                use_layer_norm: ${algo.config.actor.config.use_layer_norm}
              # Input width equals the trunk's `num_out`
              # (both read `encoder_preprocessor_output_dim`).
              num_in: ${algo.config.models.encoder_preprocessor_output_dim}
              num_out: ${algo.config.actor.config.vae_latent_dim}

            # `logvar` head (presumably the log-variance - confirm). Same
            # settings as `mu`; only the key order inside `config` differs.
            logvar:
              _target_: phys_anim.agents.models.mlp.MLP_WithNorm
              config:
                activation: ${algo.config.actor.config.activation}
                use_layer_norm: ${algo.config.actor.config.use_layer_norm}
                initializer: ${algo.config.actor.config.initializer}
                units: ${algo.config.models.encoder_output_units}
                normalize_obs: False
                obs_clamp_value: null
              num_in: ${algo.config.models.encoder_preprocessor_output_dim}
              num_out: ${algo.config.actor.config.vae_latent_dim}
        # Declared output width of the whole network: the latent dimension.
        num_out: ${algo.config.actor.config.vae_latent_dim}

      # Trunk config for the VAE prior; referenced as `trunk` by
      # `vae_prior_model`. Mirrors `encoder_pre_processor` but reads the
      # `prior_preprocessor_*` hyperparameters.
      prior_pre_processor:
        _target_: phys_anim.agents.models.mlp.MultiHeadedMLP
        # Hydra instantiate flag: nested `_target_` nodes are not instantiated
        # recursively; they are passed through as configs.
        _recursive_: False
        config:
          initializer: ${algo.config.actor.config.initializer}
          units: ${algo.config.models.prior_preprocessor_units}
          activation: ${algo.config.actor.config.activation}
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          # Local alias for the terrain model config, consumed by `extra_inputs`
          # below (unlike the encoder trunk, which references it directly).
          terrain_model: ${algo.config.models.terrain_models.mlp}
          extra_inputs:
            # OmegaConf relative interpolation: `..` steps up from `extra_inputs`
            # to this `config` mapping, so this resolves to `terrain_model` above.
            terrain: ${..terrain_model}

        # Output width; the prior heads read this key as their `num_in`.
        num_out: ${algo.config.models.prior_preprocessor_output_dim}

      # VAE prior config: same layout as `vae_encoder_model` (shared trunk plus
      # `mu` / `logvar` heads), built on `prior_pre_processor`. Referenced by
      # `actor.config.vae_prior`.
      vae_prior_model:
        _target_: phys_anim.agents.models.mlp.MultiOutputNetwork
        # Hydra instantiate flag: nested `_target_` nodes are not instantiated
        # recursively; they are passed through as configs.
        _recursive_: False
        config:
          # Normalization/clamping are switched off at this level; the trunk
          # config carries its own `normalize_obs` / `obs_clamp_value` values.
          normalize_obs: False
          obs_clamp_value: null
          # Reuses the `prior_pre_processor` config from this mapping.
          trunk: ${algo.config.models.prior_pre_processor}

          outputs:
            # `mu` head (presumably the mean of the prior distribution - confirm).
            mu:
              _target_: phys_anim.agents.models.mlp.MLP_WithNorm
              config:
                initializer: ${algo.config.actor.config.initializer}
                units: ${algo.config.models.prior_output_units}
                activation: ${algo.config.actor.config.activation}
                normalize_obs: False
                obs_clamp_value: null
                use_layer_norm: ${algo.config.actor.config.use_layer_norm}
              # Input width is read from the trunk's own `num_out` key (the
              # encoder heads read the shared `*_output_dim` key instead).
              num_in: ${algo.config.models.prior_pre_processor.num_out}
              num_out: ${algo.config.actor.config.vae_latent_dim}

            # `logvar` head (presumably the log-variance - confirm); same
            # settings as `mu`.
            logvar:
              _target_: phys_anim.agents.models.mlp.MLP_WithNorm
              config:
                initializer: ${algo.config.actor.config.initializer}
                units: ${algo.config.models.prior_output_units}
                activation: ${algo.config.actor.config.activation}
                use_layer_norm: ${algo.config.actor.config.use_layer_norm}
                normalize_obs: False
                obs_clamp_value: null
              num_in: ${algo.config.models.prior_pre_processor.num_out}
              num_out: ${algo.config.actor.config.vae_latent_dim}
        # Declared output width of the whole network: the latent dimension.
        num_out: ${algo.config.actor.config.vae_latent_dim}

      # Model config for the `vae_latent` extra input; wired into the actor's
      # `mu_model` via `actor.config.mu_model.config.extra_inputs.vae_latent`.
      vae_latent_model:
        _target_: phys_anim.agents.models.mlp.MLP_WithNorm
        config:
          activation: ${algo.config.actor.config.activation}
          use_layer_norm: ${algo.config.actor.config.use_layer_norm}
          initializer: ${algo.config.actor.config.initializer}
          # NOTE(review): reuses `prior_output_units` rather than a dedicated
          # key - confirm this sharing is intentional.
          units: ${algo.config.models.prior_output_units}
          normalize_obs: False
          obs_clamp_value: null
        # Input and output widths are both the latent dimension.
        num_in: ${algo.config.actor.config.vae_latent_dim}
        num_out: ${algo.config.actor.config.vae_latent_dim}

    # Actor definition. Only VAE-specific keys are set here; other actor keys
    # referenced in this file (`initializer`, `activation`, `use_layer_norm`)
    # are not defined here and presumably come from the `mimic` defaults.
    actor:
      _target_: phys_anim.agents.models.actor.ActorFixedSigmaVAE
      config:
        # Adds a `vae_latent` entry to the extra inputs of `mu_model`, using the
        # `vae_latent_model` config.
        mu_model:
          config:
            extra_inputs:
              vae_latent: ${algo.config.models.vae_latent_model}
        # Encoder and prior network configs defined under `algo.config.models`.
        vae_encoder: ${algo.config.models.vae_encoder_model}
        vae_prior: ${algo.config.models.vae_prior_model}
        # Bound to the top-level `vae_latent_from_prior` key of this file.
        vae_latent_from_prior: ${vae_latent_from_prior}
        # Latent size; drives the heads' `num_out`, `vae_latent_model`'s
        # `num_in`/`num_out`, and `extra_inputs.vae_latent.size`.
        vae_latent_dim: 64
        # NOTE(review): semantics are defined by the actor implementation, not
        # visible here.
        residual_encoder: True

    # VAE training settings.
    vae:
      # KLD coefficient schedule: start/end coefficient values and the epoch
      # window they apply to. Presumably the coefficient is ramped from
      # `init_kld_coeff` to `end_kld_coeff` between `start_epoch` and
      # `end_epoch` - the interpolation rule lives in the agent code; confirm.
      vae_kld_schedule:
        init_kld_coeff: 0.0001
        end_kld_coeff: 0.01
        start_epoch: 3000
        end_epoch: 6000

      # Bound to the top-level `vae_noise_type` key of this file.
      vae_noise_type: ${vae_noise_type}

    # Declares the `vae_latent` extra input: a float entry whose size equals the
    # actor's `vae_latent_dim`.
    extra_inputs:
      vae_latent:
        # Presumably means the value is not fetched from the environment -
        # confirm against the consumer of `retrieve_from_env`.
        retrieve_from_env: False
        dtype: float
        size: ${algo.config.actor.config.vae_latent_dim}

# Globally accessible parameters
# Top-level keys referenced above through `${vae_latent_from_prior}` (actor
# config) and `${vae_noise_type}` (vae config).
vae_latent_from_prior: False
vae_noise_type: normal

# Alternative values for the two global VAE parameters above. Presumably
# applied when running evaluation - the mechanism is outside this file; confirm.
eval_overrides:
  vae_latent_from_prior: True
  vae_noise_type: zeros
