# @package _global_
dataset:
  # NOTE(review): the meaning of max_number is not defined in this file
  # (presumably a cap on the dataset size) — confirm against the consumer.
  max_number: 1000

model:
  name: Causal_CPCPretained
  pretrain_rep_encoder: true

  # Hyperparameters of the representation encoder.
  rep_encoder:
    br_size: 14
    # Keep fc_hidden_units <= br_size; used in building the treatment and
    # outcome heads.
    fc_hidden_units: 12
    # Hidden dim of local features.
    genc_hidden: 14
    # NOTE(review): alpha_* look like per-term loss weights (reconstruction,
    # InfoNCE, MSE) — confirm against the model code.
    alpha_recons: 1
    alpha_infonce: 1
    alpha_mse: 10

    cpc_lb: infonce
    infomax_lb: infonce
    rnn_type: gru

    static_features_in_CPC: false

    batch_size: 256
    optimizer:
      non_treatment_head:
        learning_rate: 0.001  # instead of 0.005
        weight_decay: 0.0
        momentum: 0.0

      treatment_head:
        learning_rate: 0.001  # instead of 0.005
        momentum: 0.9
        weight_decay: 0.0

  train_head: true
  # Missing hyperparameters are to be filled in on the command line, with
  # tune_hparams = True, or selected with +backbone/crn_hparams=...
  est_head:
    finetune_rep_encoder: true
    retrain_rep_encoder: false
    fc_hidden_units: 12
    batch_size: 128  # 128 for cancer sim

    random_indices: true
    percentage_to_keep: 0.1
    rnn_type: gru

    alpha_recons: 0
    alpha_infonce: 0
    alpha_mse: 10  # 10; not used when likelihood_training

    optimizer:
      non_treatment_head:
        learning_rate: 0.001  # instead of 0.005
        weight_decay: 0.0
        # Separate learning-rate entry for the rep_encoder, nested under
        # non_treatment_head (presumably applied when the encoder is
        # fine-tuned — confirm against the optimizer setup code).
        rep_encoder:
          learning_rate: 0.0005

      treatment_head:
        learning_rate: 0.01  # instead of 0.005
        momentum: 0.9
        weight_decay: 0.0

  train_context_decoder: false
  context_decoder:
    activation: selu
    batch_size: 32
    # Resolved by interpolation from the rep_encoder config.
    # NOTE(review): context_latent_dim is not defined under
    # model.rep_encoder in this file — presumably supplied by another
    # config or a command-line override; confirm it is always set.
    enc_hidden_dim: ${model.rep_encoder.context_latent_dim}

    optimizer:
      lr_scheduler: false
      context_decoder:
        optimizer_cls: adamw
        learning_rate: 0.0005
        weight_decay: 0.0

    # Hparam tuning. hparams_grid and resources_per_trial are intentionally
    # left unset (null) in this file.
    tune_hparams: false
    tune_range: 30
    hparams_grid: null
    resources_per_trial: null
        
# Experiment-level settings, including early-stopping thresholds for each
# sub-model.
exp:
  alpha: 1
  update_alpha: false
  max_epochs: 2000

  rep_encoder:
    early_stopping:
      min_delta: 0.001
      patience: 100

  est_head:
    early_stopping:
      min_delta: 0.001
      patience: 50

  context_decoder:
    early_stopping:
      min_delta: 0.001
      patience: 50

