# @package _global_

defaults:
  # Hydra defaults list: entries are composed in the order listed.
  - override /datamodule: cogs
  - override /model: xz_autoencoder
  # `???` marks a mandatory choice: a discretizer must be selected on the
  # command line (e.g. `model/discretizer=<option>`) or composition fails.
  - override /model/discretizer: ???
  - override /model/sequence_to_sequence_model: bart
  - override /logger: wandb
  - override /trainer: default
# All parameters below are merged with those from the default configurations
# selected above, so you only need to specify the parameters you want to override.

# name of the run determines folder name in logs
name: "cogs_sigmae"
# Run label built from the supervision ratio and the selected
# sequence-to-sequence model / discretizer keys.
# NOTE(review): supervision_ratio is a list, so it is interpolated here as
# text (presumably "[0.2, 0.9]", with brackets, comma and space) — confirm
# that is acceptable wherever run_name is used (e.g. directory names).
run_name: "suponly-${datamodule.dataset_parameters.supervision_ratio}-${model.sequence_to_sequence_model.key}-${model.discretizer.key}"

# Canonical boolean: `Yes` is a YAML 1.1-only truthy literal that YAML 1.2
# parsers load as the string "Yes".
track_gradients: true
# NOTE(review): presumably 0 disables overfitting on a fixed batch — confirm
# against the code that reads this key.
overfit_batch: 0

trainer:
  # Either "auto" or device ids/counts such as 2 or [0].
  devices: [0]
  # Other options: cpu, tpu. Example combinations:
  #   devices=4, accelerator="gpu", strategy="ddp"
  #   devices="auto", accelerator="auto"
  accelerator: "gpu"
  # With min_epochs equal to max_epochs, training runs for exactly 1000
  # epochs; Lightning's min_epochs keeps early stopping from ending it sooner.
  max_epochs: 1000
  min_epochs: 1000

datamodule:
  dataset_parameters:
    # [r(xz), r(z | not xz)]: ratio of samples supervised with paired (x, z),
    # then, among samples that are not xz, the ratio that are z-supervised.
    supervision_ratio: [0.2, 0.9]
    batch_size: 256
    num_workers: 8
    remove_long_data_points: true
    # Prints the max lengths of x and z under the current tokenizing
    # strategy/configs; set to false to disable.
    print_max_lengths: true


model:
  model_params:
    # for gradient inner product logging in val step
    log_gradient_stats: false
    num_steps_log_gradient_stats: 4
    log_gradient_stats_batch_size: 32

    acc_grad_batch: 8
    num_bootstrap_tests: 10

    # Length / vocabulary caps, set slightly above the values observed in the
    # dataset. NOTE(review): the bare numbers after the vocab caps look like
    # the observed vocab sizes — confirm.
    max_x_length: 60  # 52 is the max length of an x sentence in the dataset
    max_x_vocab_size: 1615  # 1607
    max_z_length: 220  # 218 is the max length of a z sequence in the dataset
    max_z_vocab_size: 1510  # 1497

    usexz: true
    usex: false
    usez: false

    # Per-term loss weights. "seperated" is misspelled, but these keys are
    # presumably read verbatim by the model code — do not rename them here only.
    loss_coeff:
      xzx: 0.0
      zxz: 0.0
      supervised_seperated_x: 1.0
      supervised_seperated_z: 1.0
      quantization_zxz: 0.0
      quantization_xzx: 0.0
      quantization_supervised_seperated: 0.0

  optimizer:
    lr: 0.005

  # Keys match torch.optim.lr_scheduler.ReduceLROnPlateau arguments.
  lr_scheduler:
    mode: "min"
    factor: 0.98
    patience: 5
    threshold: 0.01
    threshold_mode: "abs"
    cooldown: 5
    # Written with a decimal point: YAML 1.1 loaders (e.g. PyYAML) load
    # `1e-6` as the string "1e-6", but `1.0e-6` as a float.
    min_lr: 1.0e-6
    eps: 1.0e-8
    verbose: true

  collator:
    tokenizer_x: ${model.collator.tokenizers.bpe_tokenizer}
    tokenizer_z: ${model.collator.tokenizers.bpe_tokenizer}

callbacks:
  supervision_scheduler:
    # hp_init equals hp_end in both schedulers, so the supervised
    # hyperparameter presumably stays constant at 1.0 for the whole run.
    scheduler_xz:
      num_warmup_steps: 1
      num_training_steps: 100
      hp_init: 1.0
      hp_end: 1.0
    scheduler_z:
      num_warmup_steps: 1
      num_training_steps: 100
      # hp lies in [0, 1]: given that a sample is not xz, how likely it is
      # to be z.
      hp_init: 1.0
      hp_end: 1.0

logger:
  wandb:
    # Tags attached to the W&B run.
    tags: ["supervised-training"]
    # NOTE(review): `notes:` with no value on this line loads as null unless
    # an indented value follows it.
    notes: