# @package _global_

# Hydra defaults list: for this experiment, replace the selected datamodule
# config with cogs.yaml and the selected model config with
# load_from_pretrained.yaml. Entry order is significant to Hydra.
defaults:
  - override /datamodule: cogs.yaml
  - override /model: load_from_pretrained.yaml

# The run name determines the folder name in logs.
name: "cogs_sigmae"
# NOTE(review): supervision_ratio is a list, so its string form (brackets,
# comma, space) is presumably embedded verbatim in run_name -- confirm this
# is the intended folder name.
run_name: "curriculum-${datamodule.dataset_parameters.supervision_ratio}-${sequence_to_sequence_model_key}-${discretizer_key}"

# `???` marks a mandatory value (OmegaConf MISSING); these must be supplied,
# e.g. as command-line overrides, or config resolution fails.
sequence_to_sequence_model_key: ???
discretizer_key: ???

# Canonical boolean: `Yes` is only a boolean under YAML 1.1 and is the
# string "Yes" under YAML 1.2, so `true` is unambiguous for every loader.
track_gradients: true
overfit_batch: 0

trainer:
  devices: [0]  # 'auto', or numbers like 2, [0]
  # Other options: cpu, tpu; e.g. (devices=4, accelerator="gpu", strategy="ddp")
  # or (devices="auto", accelerator="auto").
  accelerator: 'gpu'
  # min_epochs == max_epochs: presumably training always runs the full 1000
  # epochs (nothing can stop it earlier) -- confirm this is intended.
  max_epochs: 1000
  min_epochs: 1000


datamodule:
  dataset_parameters:
    supervision_ratio: [0.02, 0.9]  # [r(xz), r(z|not xz)]
    batch_size: 64
    num_workers: 1
    remove_long_data_points: true
    # Prints the max lengths of x and z under the current tokenizing
    # strategy/configs; set to false to disable.
    print_max_lengths: true

model:
  # Mandatory value (`???`, OmegaConf MISSING): must be supplied, e.g. on the
  # command line.
  checkpoint_path: ???
  substitute_config:
    model_params:
      use_pc_grad: false

      # for gradient inner-product logging in the validation step
      log_gradient_stats: false
      num_steps_log_gradient_stats: 8
      log_gradient_stats_batch_size: 32

      acc_grad_batch: 1
      num_bootstrap_tests: 10

      # Configured limits leave headroom above the dataset values in comments
      # (vocab-size comments presumably give the observed size -- confirm).
      max_x_length: 60  # 52 is the max x length in the dataset
      max_x_vocab_size: 1615  # 1607
      max_z_length: 220  # 218 is the max z length in the dataset
      max_z_vocab_size: 1510  # 1497

      usexz: true
      usex: false
      usez: false

      loss_coeff:
        # Set a coefficient to -1 to disable that loss during training
        # (validation is unaffected).
        xzx: 1.0
        zxz: 1.0
        supervised_seperated_x: 1.0
        supervised_seperated_z: 1.0
        quantization_zxz: 0.0
        quantization_xzx: 0.0
        quantization_supervised_seperated: 0.0

    optimizer:
      lr: 0.001

    # Keys mirror torch ReduceLROnPlateau arguments (presumably passed
    # through to it -- confirm).
    lr_scheduler:
      mode: "min"
      factor: 0.98
      patience: 1
      threshold: 0.01
      threshold_mode: "abs"
      cooldown: 1
      # Written with a decimal point so YAML 1.1 loaders (e.g. plain PyYAML)
      # resolve them as floats rather than strings.
      min_lr: 1.0e-6
      eps: 1.0e-8
      verbose: true


callbacks:
  supervision_scheduler:
    # Steps for both schedulers are counted in epochs.
    # Presumably schedules r(xz) -- confirm against the callback.
    scheduler_xz:
      num_warmup_steps: 1  # epoch at which the schedule starts
      num_training_steps: 50  # schedule length, in epochs
      hp_init: 1.0
      hp_end: 0.8
    # hp values lie between 0 and 1: given that a sample is not xz, how likely
    # it is to be z.
    scheduler_z:
      num_warmup_steps: 300
      num_training_steps: 400
      hp_init: 1.0
      hp_end: 1.0


