# Class path handed to Hydra's `_target_` key, followed by the name string
# stored alongside it in this config.
_target_: src.models.from_blocks_library.XZAutoencoder
name: XZAutoencoderFromBlocks

# Hydra defaults list: each entry picks one option from a config group.
# Entries are composed in the order written, so keep the order stable.
defaults:
  # Disabled entry, kept for reference:
  # - inference: default
  - collator: simple_text_collator
  - lr_scheduler: reduce_on_plateau # TODO: list the other lr_scheduler options here
  - optimizer: default # options: default (Adam), AdamW

# Reuses the token list declared under init_models.hydra_configs rather than
# repeating it. The interpolation path starts at `model`, so this file is
# expected to be mounted under the `model` key of the composed config.
special_tokens: ${model.init_models.hydra_configs.special_tokens}
# NOTE(review): presumably selects which discretizer implementation is used;
# confirm against the code that reads this key.
discretizer_key: 'softmax'

# Settings for building the sub-models: `_target_` names the builder class and
# `hydra_configs` holds the nested settings for the autoregressive wrapper, the
# two BartConfig objects and the two discrete bottlenecks.
init_models:
  _target_: src.models.modules.wrapped_models.UntrainedBart
  hydra_configs:
    # Also read by the top-level `special_tokens` key through interpolation.
    # NOTE(review): four tokens are listed here while both BartConfig blocks
    # below set vocab_size: 3 - presumably overridden at runtime; confirm.
    special_tokens: ["[pad]", "[bos]", "[eos]", "[unk]"]
    autoreg_wrapper_config:
      use_past_key_values: false
      use_last_step_states: true
      # NOTE(review): both limits resolve to model_params.max_z_length;
      # max_x_length is not referenced anywhere in this section - confirm
      # that this is intended.
      max_lengths:
        input: ${model.model_params.max_z_length}
        output: ${model.model_params.max_z_length}
      soft_average:
        p_eos_backward: true
        p_eos_forward: false
        word_embeds_with_scores_forward: true

    # transformers.BartConfig arguments for the x_to_z model. The values are
    # currently identical to config_z_to_x below.
    config_x_to_z:
      _target_: transformers.BartConfig
      vocab_size: 3
      max_position_embeddings: ${model.model_params.max_z_length}
      encoder_layers: 8
      encoder_ffn_dim: 4096
      encoder_attention_heads: 4
      decoder_layers: 8
      decoder_ffn_dim: 4096
      decoder_attention_heads: 4
      d_model: 1024
      use_cache: true

    # transformers.BartConfig arguments for the z_to_x model.
    config_z_to_x:
      _target_: transformers.BartConfig
      vocab_size: 3
      max_position_embeddings: ${model.model_params.max_z_length}
      encoder_layers: 8
      encoder_ffn_dim: 4096
      encoder_attention_heads: 4
      decoder_layers: 8
      decoder_ffn_dim: 4096
      decoder_attention_heads: 4
      d_model: 1024
      use_cache: true

    # Bottleneck class for x; its settings live in the sibling disc_x_config.
    disc_x:
      _target_: blocks.modules.discrete_bottleneck.softmax.SoftmaxDiscreteBottleneck
    disc_x_config:
      quantize_vector: true
      temperature: 5.0
      encoder_embedding_trainable: true
      decoder_embedding_trainable: true
      linear_head_trainable: true

    # Bottleneck class for z; its settings live in the sibling disc_z_config.
    disc_z:
      _target_: blocks.modules.discrete_bottleneck.softmax.SoftmaxDiscreteBottleneck
    disc_z_config:
      quantize_vector: true
      temperature: 5.0
      encoder_embedding_trainable: true
      decoder_embedding_trainable: true
      linear_head_trainable: true


# Model and training parameters. init_models reads max_z_length from here via
# ${model.model_params.max_z_length}.
model_params:
  use_pc_grad: false

  # for gradient inner product logging in val step
  log_gradient_stats: false
  num_steps_log_gradient_stats: 8
  log_gradient_stats_batch_size: 32

  acc_grad_batch: 1
  num_bootstrap_tests: 10

  max_x_length: 100
  max_x_vocab_size: 200
  max_z_length: 100
  max_z_vocab_size: 200
  decode_after_autoreg_step: true

  usexz: true
  usex: true
  usez: true

  # Coefficients keyed by loss-term name (all 1.0 here). The "seperated"
  # spelling is part of the key names and has to match the code that reads
  # them, so it is deliberately left unchanged.
  loss_coeff:
    xzx: 1.0
    zxz: 1.0
    supervised_seperated_x: 1.0
    supervised_seperated_z: 1.0
    quantization_supervised_seperated: 1.0
    quantization_zxz: 1.0
    quantization_xzx: 1.0

  use_tokenizer_vocab_len: true
  # NOTE(review): -1 looks like a "not set" sentinel, presumably filled in from
  # the tokenizer when use_tokenizer_vocab_len is true - confirm.
  disc_x_vocab_size: -1
  disc_z_vocab_size: -1
  