# @package _global_

# Compose this config together with the `amp` config; the InfoMax-specific
# keys below are layered onto it. No explicit `_self_` entry is given, so the
# merge order follows Hydra's default defaults-list rules.
defaults:
  - amp

# Agent definition: `_target_` names the class to build (InfoMax), and
# `config` holds its settings. Keys referenced below but not set in this file
# (e.g. `normalize_obs`, `obs_clamp_value`) presumably come from the `amp`
# defaults — verify there.
algo:
  _target_: phys_anim.agents.infomax.InfoMax
  config:
    # Network structure: latent inputs for the actor and critic, then the discriminator / MI encoder
    # Actor: registers `latents` as an extra input of the actor's `mu_model`,
    # handled by its own MLP_WithNorm module.
    actor:
      config:
        mu_model:
          config:
            # Same latent_dim list that the discriminator / MI encoder uses.
            latent_dim: ${algo.config.infomax_parameters.latent_dim}
            extra_inputs:
              latents:
                _target_: phys_anim.agents.models.mlp.MLP_WithNorm
                config:
                  # initializer / activation / use_layer_norm reuse the
                  # actor-level values; normalize_obs / obs_clamp_value reuse
                  # the algo-level values. None of these are set in this file.
                  initializer: ${algo.config.actor.config.initializer}
                  units: [512, 256]
                  activation: ${algo.config.actor.config.activation}
                  use_layer_norm: ${algo.config.actor.config.use_layer_norm}
                  normalize_obs: ${algo.config.normalize_obs}
                  obs_clamp_value: ${algo.config.obs_clamp_value}
                # `sum` is a custom resolver registered outside this file;
                # presumably it adds up the entries of latent_dim — confirm.
                num_in: ${sum:${algo.config.infomax_parameters.latent_dim}}
                # Relative interpolation to the sibling key: num_out == num_in.
                num_out: ${.num_in}

    # Critic: also takes `latents` as an extra input, but through a Flatten
    # module with observation normalization switched off (the actor branch
    # above instead follows algo.config.normalize_obs).
    critic:
      config:
        extra_inputs:
          latents:
            _target_: phys_anim.agents.models.common.Flatten
            config:
              normalize_obs: False
              obs_clamp_value: ${algo.config.obs_clamp_value}
            # Same sizing as the actor's latent input: `sum` resolver over
            # latent_dim, and num_out mirrors num_in via the sibling reference.
            num_in: ${sum:${algo.config.infomax_parameters.latent_dim}}
            num_out: ${.num_in}

    # Discriminator: swaps in the joint discriminator + mutual-information
    # encoder model named by `_target_`.
    discriminator:
      _target_: phys_anim.agents.models.infomax.JointDiscWithMutualInformationEncMLP
      config:
        # `${..x}` is a relative interpolation that resolves one level up,
        # i.e. to key `x` on this discriminator `config` node. Those keys
        # (initializer, units, activation, use_layer_norm) are not set in this
        # file, so they presumably come from the `amp` defaults — verify.
        # NOTE(review): `shared` looks like the trunk common to the
        # discriminator and the MI encoder — confirm in the model class.
        shared:
          initializer: ${..initializer}
          out_dim: 512  # 1024 -> 1024 -> 512
          units: ${..units}
          activation: ${..activation}
          normalize_obs: ${algo.config.normalize_obs}
          obs_clamp_value: ${algo.config.obs_clamp_value}
          use_layer_norm: ${..use_layer_norm}
        # MI encoder settings, forwarded from infomax_parameters.
        mi_enc:
          latent_dim: ${algo.config.infomax_parameters.latent_dim}
          latent_types: ${algo.config.infomax_parameters.latent_types}
        # The same latent_dim list is also exposed at the config level.
        latent_dim: ${algo.config.infomax_parameters.latent_dim}

    # InfoMax (ASE-style) latent and mutual-information parameters
    # Single source for the latent settings; the actor, critic, discriminator
    # and top-level extra_inputs sections all interpolate latent_dim from here.
    infomax_parameters:
      # List-valued settings. NOTE(review): presumably one entry per latent,
      # index-aligned with mi_reward_w below — confirm in the InfoMax agent.
      latent_dim: [64]
      latent_types: [hypersphere]

      # Mutual-information reward weight(s), plus a flag for shifting the
      # reward of hypersphere latents (exact shift is defined in the agent).
      mi_reward_w: [0.5]
      mi_hypersphere_reward_shift: True

      # MI-encoder regularization terms; both are 0 in this config.
      mi_enc_weight_decay: 0
      mi_enc_grad_penalty: 0

      # Diversity objective: `diversity_tar` is presumably the target value
      # and `diversity_bonus` its weight — confirm.
      diversity_tar: 1.
      diversity_bonus: 0.01

      # Latent sampling schedule. NOTE(review): looks like latents are drawn
      # at random and kept for between latent_steps_min and latent_steps_max
      # steps — confirm.
      random_latents: True
      latent_steps_min: 1
      latent_steps_max: 150

    # Declares `latents` as an extra agent input: not retrieved from the
    # environment, float dtype, sized by the `sum` resolver over latent_dim
    # (the same expression used for num_in in the actor and critic sections).
    extra_inputs:
      latents:
        retrieve_from_env: False
        dtype: float
        size: ${sum:${algo.config.infomax_parameters.latent_dim}}
