# @package _global_

# Usage (Hydra command-line override): +experiment/inference=scan

# Hydra defaults list: each `override` entry replaces a config-group choice
# that was already made elsewhere in the composition.
defaults:
  # `???` marks the datamodule choice as mandatory; it has to be supplied by the caller.
  - override /datamodule: ???
  # NOTE(review): presumably restores a trained model from `model.checkpoint_path` — confirm.
  - override /model: load_from_pretrained.yaml

# Run identification: `name` determines the folder name in logs.
name: 'test_${datamodule.key}_sigmae'
run_name: 'inference-${training_type}-${sequence_to_sequence_model_key}-${discretizer_key}'

# Mandatory values (`???`); all three are interpolated into `run_name`.
training_type: ???  # e.g. suponly, curriculum, ...
sequence_to_sequence_model_key: ???
discretizer_key: ???

# Written as canonical `true`: a bare `Yes` is a boolean only under YAML 1.1
# parsers, while YAML 1.2 parsers load it as the string "Yes".
track_gradients: true
overfit_batch: 0

trainer:
  # Accepted forms: 'auto', or numbers like 2, or an index list like [0].
  devices:
    - 0
  # Alternatives: cpu, tpu. Example combinations:
  #   (devices=4, accelerator="gpu", strategy="ddp")
  #   (devices="auto", accelerator="auto")
  accelerator: gpu

datamodule:
  dataset_parameters:
    # Two ratios, in this order: [r(xz), r(z|not xz)]
    supervision_ratio:
      - 0.02
      - 0.9
    batch_size: 128
    num_workers: 8
    remove_long_data_points: true


model:
  checkpoint_path: ???  # mandatory: has to be supplied by the caller
  substitute_config:
    model_params:
      use_pc_grad: false

      # Gradient inner-product logging in the validation step.
      log_gradient_stats: false
      num_steps_log_gradient_stats: 8
      log_gradient_stats_batch_size: 32

      num_bootstrap_tests: 10

      # Length and vocabulary-size limits.
      max_x_length: 60  # the longest sentence in the dataset has length 52
      max_x_vocab_size: 1615  # 1607
      max_z_length: 220  # the longest sentence in the dataset has length 218
      max_z_vocab_size: 1510  # 1497

      usex: true
      usez: true
      usexz: true
  
      
logger:
  wandb:
    # NOTE(review): the tag reads "supervised-training" although `run_name`
    # starts with "inference-" — confirm the tag is intended for this config.
    tags:
      - supervised-training
    # No notes are set here; the value is left empty.
    notes: