# Per-model settings, keyed by model name (base, autoencoder, vae, ...).
reps:
  pretrain_steps: 0
  # Shared defaults; the model sections below read these through
  # ${reps.latent_dim} / ${reps.n_actions} / ${reps.dropout}.
  latent_dim: 2
  n_actions: 4
  # Interpolated from the top-level `dropout` key.
  dropout: ${dropout}
  # base: an MLP encoder and a residual encoder, both emitting latent_dim features.
  base:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}

    # MLP encoder. `normalize` is a quoted string on purpose: other sections
    # in this file set it to a name such as `rms`.
    encoder:
      type: mlp
      hidden_dims: [128]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.base.latent_dim}
      dropout: ${reps.dropout}

    # Residual encoder (presumably the image-input variant -- confirm in the model code).
    # The last mlp_layers entry and output_dim both track this section's latent_dim.
    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${reps.base.latent_dim}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      output_dim: ${reps.base.latent_dim}
      dropout: ${reps.dropout}
  
  # autoencoder: encoder/decoder pairs in MLP and residual ("pixel") variants.
  # Every width in this section is derived from reps.autoencoder.latent_dim so
  # that overriding it resizes all four networks together.
  autoencoder:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1

    # MLP encoder; emits latent_dim features.
    encoder:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.autoencoder.latent_dim}
      dropout: ${reps.dropout}

    # MLP decoder; consumes latent_dim features.
    decoder:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none
      input_dim: ${reps.autoencoder.latent_dim}

    # Residual encoder. Sized from this section's latent_dim (previously it
    # pointed at reps.base.latent_dim, which only matched by coincidence of
    # both defaulting to reps.latent_dim).
    encoder_pixel:
      type: residual_encoder
      mlp_layers:
        - 256
        - 256
        - ${reps.autoencoder.latent_dim}
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      output_dim: ${reps.autoencoder.latent_dim}
      dropout: ${reps.dropout}

    # Residual decoder; its first MLP layer takes latent_dim features.
    decoder_pixel:
      type: residual_decoder
      mlp_layers:
        - ${reps.autoencoder.latent_dim}
        - 256
        - 256
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2

    # recons_const: presumably the reconstruction-loss weight (the vae section
    # documents the same key that way) -- confirm.
    params:
      recons_const: 1.

  # vae: variational autoencoder settings (MLP and residual variants).
  vae:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1
    # KL weight; 1.0 corresponds to the standard VAE.
    beta: 1.0

    # No output_dim here: the VAE implementation sets it to 2 * latent_dim.
    encoder:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none
      dropout: ${reps.dropout}

    decoder:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none
      input_dim: ${reps.vae.latent_dim}

    # The final MLP layer is 2 * latent_dim wide: one half for the mean,
    # the other for log_std.
    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${eval:2*${reps.vae.latent_dim}}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      dropout: ${reps.dropout}
      final_activation: none

    decoder_pixel:
      type: residual_decoder
      mlp_layers: ['${reps.vae.latent_dim}', 256, 256]
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2

    params:
      # Weight for the reconstruction loss.
      recons_const: 1.0
      # Weight for the KL loss.
      kl_const: 1.0

  # markov: encoder plus `inverse` and `ratio` heads, with per-term loss weights.
  markov:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}
    smoothness_thresh: 0.01
    ratio_batch_size: 128
    vars_per_factor: 1

    # MLP encoder with tanh output; emits latent_dim features.
    encoder:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: tanh
      output_dim: ${reps.markov.latent_dim}
      dropout: ${reps.dropout}

    # Residual encoder. The last mlp_layers entry now follows this section's
    # latent_dim, matching output_dim below (it used to reference
    # reps.base.latent_dim, which would diverge if either were overridden).
    encoder_pixel:
      type: residual_encoder
      mlp_layers:
        - 256
        - 256
        - ${reps.markov.latent_dim}
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      output_dim: ${reps.markov.latent_dim}
      dropout: ${reps.dropout}

    # No output_dim is configured for this head in this file.
    inverse:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none

    # Single-output head.
    ratio:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: 1

    # Presumably the weights of the inverse / ratio / smoothness loss terms -- confirm.
    params:
      inverse_const: 10.
      ratio_const: 1.
      smoothness_const: 1.
  
  # acf: encoder/decoder plus energy, projector, inverse, dynamics and pi heads.
  # Also reused wholesale by multistep_acf via ${reps.acf}.
  acf:
    n_actions: ${reps.n_actions}
    use_action_weights: true
    per_factor: false
    noise_std: 5e-3
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1
    batch_size: 128
    info_nce: true

    encoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: tanh
      output_dim: ${reps.acf.latent_dim}
      dropout: ${reps.dropout}

    decoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none

    # One output per action.
    energy:
      type: mlp
      hidden_dims: [256]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.n_actions}

    # Output width: latent_dim * vars_per_factor.
    projector:
      type: mlp
      hidden_dims: [128]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${eval:${reps.acf.latent_dim}*${reps.acf.vars_per_factor}}

    # Input width: 2 * latent_dim * vars_per_factor; one output per action.
    inverse:
      type: mlp
      hidden_dims: [128]
      activation: silu
      normalize: 'false'
      outact: none
      input_dim: ${eval:2*${reps.acf.latent_dim}*${reps.acf.vars_per_factor}}
      output_dim: ${reps.n_actions}

    # Output width: n_actions * latent_dim * vars_per_factor.
    dynamics:
      type: mlp
      hidden_dims: [256]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${eval:${reps.n_actions}*${reps.acf.latent_dim}*${reps.acf.vars_per_factor}}

    # One output per action.
    pi:
      type: mlp
      hidden_dims: [256, 256]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.n_actions}

    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${reps.acf.latent_dim}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      dropout: ${reps.dropout}
      final_activation: tanh

    decoder_pixel:
      type: residual_decoder
      mlp_layers: ['${reps.acf.latent_dim}', 256, 256]
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2

    params:
      recons_const: 0.0
      # Factor loss. "inverse" is a misnomer, kept for compatibility.
      inverse_const: 1.0
      # The actual inverse-model loss.
      inverse_model_const: 1.0
      forward_const: 1.0
      policy_const: 1.0
      per_action_forward_const: 0.0
      grounding_const: 0.0

  # multistep_acf: embeds the whole acf section and adds a multistep classifier.
  multistep_acf:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1
    # Entire reps.acf node, pulled in by interpolation.
    acf_config: ${reps.acf}

    # One output per action.
    multistep_classifier:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.n_actions}

    embed:
      dim: 128
      max_offset: 16

    params:
      multistep_inv_const: 1.0

  # perfactor_acf:
  #   n_actions: ${reps.n_actions}
  #   use_action_weights: false
  #   per_factor: true
  #   noise_std: 5e-3
  #   latent_dim: ${reps.latent_dim}
  #   vars_per_factor: 1
  #   batch_size: 128
  #   info_nce: true
  #   encoder:
  #     type: 'mlp'
  #     hidden_dims: [512, 512]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'tanh'
  #     output_dim: ${reps.perfactor_acf.latent_dim}
  #     dropout: ${reps.dropout}
  #   decoder:
  #     type: 'mlp'
  #     hidden_dims: [512, 512]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #   energy:
  #     type: 'mlp'
  #     hidden_dims: [256, ]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #     output_dim: ${reps.n_actions}
  #   projector:
  #     type: 'mlp'
  #     hidden_dims: [128,]
  #     activation: silu
  #     normalize: "false"
  #     outact: none
  #     output_dim: ${eval:${reps.perfactor_acf.latent_dim}*${reps.perfactor_acf.vars_per_factor}}
  #   inverse:
  #     type: 'mlp'
  #     hidden_dims: [128,]
  #     activation: silu
  #     normalize: "false"
  #     outact: none
  #     input_dim: ${eval:2*${reps.perfactor_acf.latent_dim}*${reps.perfactor_acf.vars_per_factor}}
  #     output_dim: ${reps.n_actions}
  #   dynamics:
  #     type: 'mlp'
  #     hidden_dims: [256,]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #     output_dim: ${eval:${reps.n_actions}*${reps.perfactor_acf.latent_dim}*${reps.perfactor_acf.vars_per_factor}}
  #   pi:
  #     type: 'mlp'
  #     hidden_dims: [256, 256]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #     output_dim: ${reps.n_actions}
  #   encoder_pixel:
  #     type: residual_encoder
  #     mlp_layers:
  #       - 256
  #       - 256
  #       - ${reps.perfactor_acf.latent_dim}
  #     depth: 24
  #     mlp_activation: silu
  #     cnn_activation: silu
  #     min_resolution: 4
  #     cnn_blocks: 2
  #     dropout: ${reps.dropout}
  #     final_activation: tanh
  #   decoder_pixel:
  #     type: residual_decoder
  #     mlp_layers:
  #       - ${reps.perfactor_acf.latent_dim}
  #       - 256 
  #       - 256 
  #     depth: 24
  #     mlp_activation: silu
  #     cnn_activation: silu
  #     min_resolution: 4
  #     cnn_blocks: 2
  #   params:
  #     recons_const: 0.
  #     inverse_const: 1.
  #     policy_const: 0.
  
  # discreteacf:

  #   n_actions: ${reps.n_actions}
  #   use_action_weights: false
  #   per_factor: false
  #   noise_std: 5e-2
  #   latent_dim: ${reps.latent_dim}
  #   vars_per_factor: 1
  #   batch_size: 128
  #   info_nce: true
  #   n_values: 20
  #   encoder:
  #     type: 'mlp'
  #     hidden_dims: [512, 512]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'tanh'
  #     output_dim: ${reps.discreteacf.latent_dim}
  #     dropout: ${reps.dropout}
  #   decoder:
  #     type: 'mlp'
  #     hidden_dims: [512, 512]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #   energy:
  #     type: 'mlp'
  #     hidden_dims: [128, 128]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #     output_dim: ${eval:${reps.n_actions}*${reps.discreteacf.n_values}}

  #   projector:
  #     type: 'mlp'
  #     hidden_dims: [128,]
  #     activation: silu
  #     normalize: "false"
  #     outact: none
  #     output_dim: ${eval:${reps.discreteacf.latent_dim}*${reps.discreteacf.vars_per_factor}}

    
  #   inverse:
  #     type: 'mlp'
  #     hidden_dims: [128,]
  #     activation: silu
  #     normalize: "false"
  #     outact: none
  #     input_dim: ${eval:2*${reps.discreteacf.latent_dim}*${reps.discreteacf.vars_per_factor}}
  #     output_dim: ${reps.n_actions}

  #   pi:
  #     type: 'mlp'
  #     hidden_dims: [256, 256]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #     output_dim: ${reps.n_actions}

  #   encoder_pixel:
  #     type: residual_encoder
  #     mlp_layers:
  #       - 256
  #       - 256
  #       - ${reps.discreteacf.latent_dim}
  #     depth: 24
  #     mlp_activation: silu
  #     cnn_activation: silu
  #     min_resolution: 4
  #     cnn_blocks: 2
  #     dropout: ${reps.dropout}

  #   decoder_pixel:
  #     type: residual_decoder
  #     mlp_layers:
  #       - ${reps.discreteacf.latent_dim}
  #       - 256 
  #       - 256 
  #     depth: 24
  #     mlp_activation: silu
  #     cnn_activation: silu
  #     min_resolution: 4
  #     cnn_blocks: 2
  #   params:
  #       recons_const: 1e-1
  #       inverse_const: 1.
  #       forward_const: 1.
  #       policy_const: 0.1

  # detacf:

  #   batch_size: 128
  #   is_pixel: false
  #   use_action_weights: false
  #   per_factor: false
  #   noise_std: 1e-2
  #   latent_dim: ${reps.latent_dim}
  #   n_actions: ${reps.n_actions}
  #   vars_per_factor: 1
  #   info_nce: true
  #   encoder:
  #     type: 'mlp'
  #     hidden_dims: [512, 512]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'tanh'
  #     output_dim: ${eval:${reps.detacf.latent_dim}*${reps.detacf.vars_per_factor}}
  #   decoder:
  #     type: 'mlp'
  #     hidden_dims: [512, 512]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
    
  #   dynamics:
  #     type: 'mlp'
  #     hidden_dims: [128, 128]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #     output_dim: ${eval:${reps.n_actions}*${reps.detacf.latent_dim}*${reps.detacf.vars_per_factor}}
  #   pi:
  #     type: 'mlp'
  #     hidden_dims: [256, 256]
  #     activation: silu
  #     normalize: "false"
  #     outact: 'none'
  #     output_dim: ${reps.n_actions}

  #   encoder_pixel:
  #     type: residual_encoder
  #     mlp_layers:
  #       - 256
  #       - 256
  #       - ${reps.detacf.latent_dim}
  #     depth: 24
  #     mlp_activation: silu
  #     cnn_activation: silu
  #     min_resolution: 4
  #     cnn_blocks: 2

  #   decoder_pixel:
  #     type: residual_decoder
  #     mlp_layers:
  #       - ${reps.detacf.latent_dim}
  #       - 256 
  #       - 256 
  #     depth: 24
  #     mlp_activation: silu
  #     cnn_activation: silu
  #     min_resolution: 4
  #     cnn_blocks: 2
  #   params:
  #       recons_const: 1.
  #       inverse_const: 10.
  #       forward_const: 0.
  #       policy_const: 0.1

  # dms: encoder/decoder with a transition network.
  dms:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1

    encoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.dms.latent_dim}
      dropout: ${reps.dropout}

    decoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none

    transition:
      type: mlp
      hidden_dims: [256]
      activation: silu
      normalize: 'false'
      outact: none

    # Final MLP layer is 2 * latent_dim wide; note the MLP `encoder` above
    # emits latent_dim only.
    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${eval:${reps.dms.latent_dim}*2}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      dropout: ${reps.dropout}
      final_activation: none

    decoder_pixel:
      type: residual_decoder
      mlp_layers: ['${reps.dms.latent_dim}', 256, 256]
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2

    params:
      gumbel_temp: 1.0
      elbo_const: 1.0
      l2_reg_const: 1e-3
      g_action_const: 1.0
      g_time_const: 1.0

  # ivae: encoder/decoder operate at hidden_dim; `inference` maps to 2 * latent_dim.
  ivae:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1
    hidden_dim: 256

    encoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.ivae.hidden_dim}
      dropout: ${reps.dropout}

    decoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.ivae.hidden_dim}

    # Input width: latent_dim + n_actions + hidden_dim.
    # Output width: 2 * latent_dim.
    inference:
      type: mlp
      hidden_dims: [256]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${eval:${reps.ivae.latent_dim}*2}
      input_dim: ${eval:${reps.ivae.latent_dim}+${reps.n_actions}+${reps.ivae.hidden_dim}}

    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${reps.ivae.hidden_dim}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      dropout: ${reps.dropout}
      final_activation: none

    decoder_pixel:
      type: residual_decoder
      mlp_layers: ['${reps.ivae.hidden_dim}', 256, 256]
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2

    params:
      elbo_const: 1.0

  # recurrentacf: acf-like section with a GRU memory (`recurrent: true`).
  # Encoders emit hidden_dim features; the memory takes latent_dim inputs.
  recurrentacf:
    n_actions: ${reps.n_actions}
    use_action_weights: false
    per_factor: false
    noise_std: 5e-3
    latent_dim: ${reps.latent_dim}
    hidden_dim: 128
    vars_per_factor: 1
    batch_size: 128
    info_nce: true
    recurrent: true

    encoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.recurrentacf.hidden_dim}
      dropout: ${reps.dropout}

    decoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none

    # One output per action.
    energy:
      type: mlp
      hidden_dims: [256]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.n_actions}

    memory:
      type: gru
      input_dim: ${reps.recurrentacf.latent_dim}
      hidden_dim: ${reps.recurrentacf.hidden_dim}

    # Output width: n_actions * hidden_dim.
    memory_action:
      type: mlp
      hidden_dims: [256, 256]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${eval:${reps.n_actions}*${reps.recurrentacf.hidden_dim}}

    # One output per action.
    pi:
      type: mlp
      hidden_dims: [256, 256]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: ${reps.n_actions}

    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${reps.recurrentacf.hidden_dim}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      dropout: ${reps.dropout}
      final_activation: none

    # First MLP layer is latent_dim + hidden_dim wide.
    decoder_pixel:
      type: residual_decoder
      mlp_layers:
        - ${eval:${reps.recurrentacf.latent_dim}+${reps.recurrentacf.hidden_dim}}
        - 256
        - 256
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2

    posterior:
      type: mlp
      hidden_dims: [256, 256]
      activation: silu
      normalize: 'false'
      outact: tanh

    params:
      recons_const: 0.0
      # Factor loss. "inverse" is a misnomer, kept for compatibility.
      inverse_const: 1.0
      # The actual inverse-model loss.
      inverse_model_const: 1.0
      per_action_forward_const: 0.0
      forward_const: 1.0
      policy_const: 0.0

  # gcl: encoder/decoder with an rms-normalised energy head.
  gcl:
    n_actions: ${reps.n_actions}
    use_action_weights: false
    per_factor: false
    noise_std: 1e-2
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1
    batch_size: 128
    info_nce: true

    encoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: tanh
      output_dim: ${reps.gcl.latent_dim}
      dropout: ${reps.dropout}

    decoder:
      type: mlp
      hidden_dims: [512, 512]
      activation: silu
      normalize: 'false'
      outact: none

    # Unlike the other heads in this section, this one sets normalize to rms.
    energy:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: rms
      outact: none
      output_dim: ${reps.n_actions}

    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${reps.gcl.latent_dim}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      dropout: ${reps.dropout}

    decoder_pixel:
      type: residual_decoder
      mlp_layers: ['${reps.gcl.latent_dim}', 256, 256]
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2

    params:
      recons_const: 0.0
      energy_const: 10.0

  # spr: encoder with transition, projection and predictor MLPs.
  # No dropout is configured for the encoders in this section.
  spr:
    n_actions: ${reps.n_actions}
    latent_dim: ${reps.latent_dim}
    tau: 0.01
    vars_per_factor: 1

    encoder:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: tanh
      output_dim: ${reps.spr.latent_dim}

    encoder_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${reps.spr.latent_dim}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      output_dim: ${reps.spr.latent_dim}

    transition:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none

    # Single-output head.
    projection:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: 1

    # Single-output head.
    predictor:
      type: mlp
      hidden_dims: [128, 128]
      activation: silu
      normalize: 'false'
      outact: none
      output_dim: 1

    # Intentionally empty mapping.
    params: {}

  # dreamervae: observation embedders, latent encoder, dynamics model and decoders.
  dreamervae:
    hidden_dim: 256
    latent_dim: ${reps.latent_dim}
    vars_per_factor: 1
    state_dim: ${reps.dreamervae.latent_dim}
    n_actions: ${reps.n_actions}
    pixels: false
    # NOTE(review): reads reps.latent_dim directly rather than this section's
    # latent_dim -- confirm that is intended.
    categoricals: ${reps.latent_dim}
    n_values: 32
    type: gaussian

    obs_embed:
      type: mlp
      hidden_dims: ['${reps.dreamervae.hidden_dim}', '${reps.dreamervae.hidden_dim}']
      activation: silu
      normalize: rms
      outact: none
      output_dim: ${reps.dreamervae.hidden_dim}

    decoder:
      type: mlp
      hidden_dims: ['${reps.dreamervae.hidden_dim}', '${reps.dreamervae.hidden_dim}']
      activation: silu
      normalize: rms
      outact: tanh
      input_dim: ${reps.dreamervae.state_dim}

    obs_embed_pixel:
      type: residual_encoder
      mlp_layers: [256, 256, '${reps.dreamervae.hidden_dim}']
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      output_dim: ${reps.dreamervae.hidden_dim}

    # The first mlp_layers entry mirrors this node's own input_dim, which is a
    # hard-coded 1024 (the MLP decoder above uses state_dim instead).
    decoder_pixel:
      type: residual_decoder
      mlp_layers: ['${reps.dreamervae.decoder_pixel.input_dim}', 256, 256]
      depth: 24
      mlp_activation: silu
      cnn_activation: silu
      min_resolution: 4
      cnn_blocks: 2
      input_dim: 1024
      outact: sigmoid

    # hidden_dim -> latent_dim.
    encoder:
      type: mlp
      hidden_dims: ['${reps.dreamervae.hidden_dim}']
      activation: silu
      normalize: 'false'
      outact: none
      input_dim: ${reps.dreamervae.hidden_dim}
      output_dim: ${reps.dreamervae.latent_dim}

    embed:
      type: mlp
      hidden_dims: ['${reps.dreamervae.hidden_dim}']
      output_dim: ${reps.dreamervae.hidden_dim}
      activation: silu
      normalize: rms
      outact: none

    # (latent_dim + hidden_dim) -> state_dim.
    dynamics:
      type: mlp
      hidden_dims: ['${reps.dreamervae.hidden_dim}', '${reps.dreamervae.hidden_dim}']
      activation: silu
      normalize: 'false'
      outact: none
      input_dim: ${eval:${reps.dreamervae.latent_dim}+${reps.dreamervae.hidden_dim}}
      output_dim: ${reps.dreamervae.state_dim}

    params:
      free_nats: 1.0
      recons_const: 1.0
      dyn_const: 1.0
      rep_const: 0.1

# Environment presets keyed by short name. Every entry takes its n_envs from
# the top-level `n_envs` key.
envs:
  cartpole:
    env_name: gymnax_CartPole-v1
    n_envs: ${n_envs}
    autoreset: true
    env_params: {}

  mountaincar:
    env_name: gymnax_MountainCar-v0
    n_envs: ${n_envs}
    autoreset: true
    env_params: {}

  # The four MinAtar entries enable rendering and pass noise_sigma through env_params.
  asterix:
    env_name: gymnax_Asterix-MinAtar
    n_envs: ${n_envs}
    autoreset: true
    render: true
    env_params:
      noise_sigma: 0.01

  spaceinvaders:
    env_name: gymnax_SpaceInvaders-MinAtar
    n_envs: ${n_envs}
    autoreset: true
    render: true
    env_params:
      noise_sigma: 0.01

  freeway:
    env_name: gymnax_Freeway-MinAtar
    n_envs: ${n_envs}
    autoreset: true
    render: true
    env_params:
      noise_sigma: 0.01

  breakout:
    env_name: gymnax_Breakout-MinAtar
    n_envs: ${n_envs}
    autoreset: true
    render: true
    env_params:
      noise_sigma: 0.01

  pong:
    env_name: gymnax_Pong-misc
    n_envs: ${n_envs}
    autoreset: true
    env_params: {}

  pinball:
    env_name: pinball_pinball
    n_envs: ${n_envs}
    autoreset: true
    env_params:
      level: easy

  # NOTE(review): noise_sigma sits beside env_params here, whereas the MinAtar
  # entries nest it inside env_params -- confirm which the consumer expects.
  navix:
    env_name: navix_FourRooms-v0
    n_envs: ${n_envs}
    env_params:
      max_steps: 500
    action_space: full
    observation_space: rgb
    autoreset: true
    img_size: ${datasets.img_size}
    noise_sigma: 0.01

  taxi_gymnax:
    env_name: taxi_gymnax
    n_envs: ${n_envs}
    env_params:
      max_steps: 500
    action_space: full
    observation_space: rgb
    autoreset: true

  navix_doorkey:
    env_name: navix_DoorKey-8x8-v0
    n_envs: ${n_envs}
    env_params:
      max_steps: 1000
    action_space: full
    observation_space: rgb
    autoreset: true
    img_size: ${datasets.img_size}

# Dataset presets keyed by name; each entry's `type` repeats its key.
datasets:
  # Shared image size, reused by every entry below and by some envs entries.
  img_size: 32
  grid_2d:
    type: grid_2d
    img_size: ${datasets.img_size}
    # Interpolated from the top-level `thickness` key.
    thickness: ${thickness}
  multi_object:
    type: multi_object
    n_objects: 1
    img_size: ${datasets.img_size}
  multi_object_selection:
    type: multi_object_selection
    n_objects: 3
    img_size: ${datasets.img_size}
  # The three taxi variants below share n_passengers and grid_size values.
  taxi:
    type: taxi
    img_size: ${datasets.img_size}
    n_passengers: 1
    grid_size: 10
  taxi_suff:
    type: taxi_suff
    img_size: ${datasets.img_size}
    n_passengers: 1
    grid_size: 10
  taxi_gymnax:
    type: taxi_gymnax
    img_size: ${datasets.img_size}
    n_passengers: 1
    grid_size: 10

# Evaluator settings: optimisation hyper-parameters plus the predictor network.
evaluator_config:
  lr: 1e-4
  batch_size: 128
  n_epochs: 50

  # rms-normalised MLP; no input/output widths are fixed in this file.
  predictor:
    type: mlp
    hidden_dims:
      - 512
      - 512
    activation: silu
    normalize: rms
    outact: none

# Data-collection settings. Note that n_envs here (1024) is separate from the
# top-level n_envs key.
data_collection:
  seed: 0
  policy: random
  n_envs: 1024
  n_samples: 100000

# --- Output / bookkeeping ---
outdir: rl_experiments/identification/
exp_id: reps
device: gpu

# --- Evaluation and logging cadence ---
eval_max_length: 10000
eval_n_episodes: 10
eval_every: 25000
log_every: 2500

# --- Training length ---
training_steps: 250000
epochs: 50

# Can be one of: grid_2d, multi_object, multi_object_selection.
# NOTE(review): `datasets` also defines taxi, taxi_suff and taxi_gymnax;
# presumably those are valid here too -- confirm.
env: multi_object
# Referenced by every envs.* entry.
n_envs: 512
# Matches a key under `reps` (presumably selects that section -- confirm).
rep: acf

seed: 0
# Referenced by reps.dropout.
dropout: 0
batch_size: 128
lr: 1e-4
weight_decay: 1e-3
horizon: 8
# Joint space coverage; referenced by datasets.grid_2d.thickness.
thickness: 1.0
eval_conditioning: null
use_ground_truth_states: false
prioritized: false
balanced_eval: false

# Importance-weight exponent schedule: annealed from `start` to `end`.
importance_weight_exp:
  # Initial importance-weight exponent (β₀).
  start: 0.6
  # Final importance-weight exponent (β₁).
  end: 1.0
