# Model config; `_target_` is the class path passed to Hydra's instantiate.
_target_: src.models.xz_autoencoder.XZAutoencoder
name: 'XZAutoencoder'

# Hydra config-group selections. The chosen groups are referenced below as
# ${model.<group>} (e.g. ${model.discretizer}).
defaults:
  # - inference: default  # currently disabled
  - discretizer: gumbel
  - collator: simple_text_collator
  - lr_scheduler: reduce_on_plateau  # alternatives: see the lr_scheduler group
  - optimizer: default  # options: default (Adam), AdamW
  - sequence_to_sequence_model: gpt2_gpt2

# Special tokens; quoted because a plain scalar starting with '[' would be
# parsed as a flow sequence.
special_tokens:
  - '[pad]'
  - '[bos]'
  - '[eos]'
  - '[unk]'

modules:
  # Both discretizers are taken from the selected `discretizer` config group.
  disc_x: ${model.discretizer}
  disc_z: ${model.discretizer}

  # NOTE(review): vocab_size is 20 for both encoder and decoder of
  # config_x_to_z and 60 for both halves of config_z_to_x, whereas n_positions
  # swaps (max_x_length <-> max_z_length) between encoder and decoder. Confirm
  # whether these vocab sizes are overridden at runtime (cf.
  # model_params.use_tokenizer_vocab_len) or should follow the same swap.

  # x -> z: encoder positions sized by max_x_length, decoder by max_z_length.
  model_x_to_z: ${model.sequence_to_sequence_model}
  config_x_to_z:
    config_encoder:
      _target_: transformers.GPT2Config
      vocab_size: 20
      n_embd: 256
      # Maximum sequence length that this model might ever be used with.
      n_positions: ${model.model_params.max_x_length}
      n_layer: 8
      n_head: 4
      activation_function: 'gelu_new'

    config_decoder:
      _target_: transformers.GPT2Config
      vocab_size: 20
      n_embd: 256
      # Maximum sequence length that this model might ever be used with.
      n_positions: ${model.model_params.max_z_length}
      n_layer: 8
      n_head: 4
      activation_function: 'gelu_new'

  # z -> x: encoder positions sized by max_z_length, decoder by max_x_length.
  model_z_to_x: ${model.sequence_to_sequence_model}
  config_z_to_x:
    config_encoder:
      _target_: transformers.GPT2Config
      vocab_size: 60
      n_embd: 256
      # Maximum sequence length that this model might ever be used with.
      n_positions: ${model.model_params.max_z_length}
      n_layer: 8
      n_head: 4
      activation_function: 'gelu_new'

    config_decoder:
      _target_: transformers.GPT2Config
      vocab_size: 60
      n_embd: 256
      # Maximum sequence length that this model might ever be used with.
      n_positions: ${model.model_params.max_x_length}
      n_layer: 8
      n_head: 4
      activation_function: 'gelu_new'

model_params:
  use_pc_grad: false

  # Gradient inner-product logging during the validation step.
  log_gradient_stats: false
  num_steps_log_gradient_stats: 8
  log_gradient_stats_batch_size: 32

  acc_grad_batch: 1
  num_bootstrap_tests: 10

  # max_x_length / max_z_length are interpolated into the n_positions of the
  # GPT2Config blocks under `modules`.
  max_x_length: 100
  max_x_vocab_size: 200
  max_z_length: 100
  max_z_vocab_size: 200
  decode_after_autoreg_step: true

  usexz: true
  usex: true
  usez: true

  # Coefficients for each loss term, keyed by term name.
  # NOTE(review): "seperated" is misspelled, but these keys are presumably
  # looked up by name in the model code; rename there first before fixing here.
  loss_coeff:
    xzx: 1.0
    zxz: 1.0
    supervised_seperated_x: 1.0
    supervised_seperated_z: 1.0
    quantization_supervised_seperated: 1.0
    quantization_zxz: 1.0
    quantization_xzx: 1.0

  use_tokenizer_vocab_len: true
  # NOTE(review): -1 looks like a "not set / derive elsewhere" sentinel,
  # presumably tied to use_tokenizer_vocab_len; confirm in XZAutoencoder.
  disc_x_vocab_size: -1
  disc_z_vocab_size: -1