# @package _global_
model:
  # NOTE(review): "Pretained" looks like a typo for "Pretrained"; left as-is
  # because it is a runtime value other code/configs may match on - confirm
  # before renaming.
  name: Causal_CPCPretained
  # Model-level activation; rep_encoder, est_head and context_decoder each
  # also set their own `activation` key below.
  activation: selu
  # Left null here; presumably a flag supplied by an override - TODO confirm.
  pretrain_rep_encoder: null
  # Representation encoder. Keys left as null are not set here: supply them
  # on the command line, via hparam tuning (tune_hparams: true), or by
  # selecting an hparams config group (e.g. +backbone/crn_hparams=...).
  # NOTE(review): the "crn_hparams" group name may be inherited from a CRN
  # config - confirm it is the right group for this model.
  rep_encoder:
    _target_: src.models.causal_cpc_ed.Causal_CPC_Encoder
    br_size: 12
    # must be <= br_size; used in building the treatment and outcome heads
    fc_hidden_units: 12
    genc_hidden: 12  # hidden dim of the local features
    # dim of the features summary; tied to br_size by interpolation
    context_latent_dim: ${model.rep_encoder.br_size}
    downsampling_factor: 1
    subsample_win_ratio: 0.05
    # dropout of the RNN hidden layers + output layers (cell type: rnn_type)
    dropout_rate: 0.1
    num_layer: 1

    balancing: mutual_info
    alpha_recons: null
    alpha_infonce: null
    alpha_mse: null
    label_smoothing: 0
    use_spectral_norm: true
    activation: selu
    cpc_lb: null
    infomax_lb: null
    rnn_type: gru
    normalize_rep: false
    random_split: true

    dim_random_vitals: 0
    static_features_in_CPC: null

    batch_size: 32
    optimizer:
      lr_scheduler: false
      non_treatment_head:
        optimizer_cls: adamw
        learning_rate: null
        weight_decay: null

      treatment_head:
        optimizer_cls: sgd
        learning_rate: null
        momentum: null
        weight_decay: null

    use_attention: false

    tune_hparams: false  # hparam tuning
    tune_range: 50
    hparams_grid: null
    resources_per_trial: null

  # Left null here; presumably supplied by an override - TODO confirm type.
  train_head: null
  # Estimation head. Keys left as null are not set here: supply them on the
  # command line or by selecting an hparams config group
  # (e.g. +backbone/crn_hparams=...).
  # NOTE(review): unlike rep_encoder and context_decoder, this section has no
  # tune_hparams/tune_range keys - confirm how hparam tuning applies here.
  est_head:
    _target_: src.models.causal_cpc_ed.RNNEstHead
    # presumably flags controlling how rep_encoder is reused - TODO confirm
    finetune_rep_encoder: null
    retrain_rep_encoder: null
    # "rnn_hidden_units" in the original terminology; should equal the
    # encoder's br_size
    seq_hidden_units: ${model.rep_encoder.br_size}
    # preferably smaller than the encoder's br_size; defaults to equal here
    br_size: ${model.rep_encoder.br_size}
    # must be <= br_size; used in building the treatment and outcome heads,
    # only in teacher_forcing mode
    fc_hidden_units: 12
    # dropout of the RNN hidden layers + output layers (cell type: rnn_type)
    dropout_rate: 0.1
    num_layer: 1
    batch_size: 64  # 128 for the cancer simulation
    y_dist_type: "continuous"
    teacher_forcing: false
    treat_hidden_dim: 6

    likelihood_training: true
    historical_avg: false
    random_indices: null
    percentage_to_keep: null
    rnn_type: gru

    balancing: mutual_info
    label_smoothing: 0
    use_spectral_norm: true
    activation: selu
    alpha_recons: null
    alpha_infonce: null
    alpha_mse: null

    step_mse_loss_weights_type: avg

    optimizer:
      lr_scheduler: false
      non_treatment_head:
        optimizer_cls: adamw
        learning_rate: null
        weight_decay: null
        # separate learning rate for rep_encoder parameters; presumably used
        # when the encoder is fine-tuned - TODO confirm
        rep_encoder:
          learning_rate: null

      treatment_head:
        optimizer_cls: sgd
        learning_rate: null
        momentum: null
        weight_decay: null

  train_context_decoder: false
  # Context decoder and its own optimizer / tuning settings.
  context_decoder:
    _target_: src.models.causal_cpc_ed.ContextDecoder
    dropout_rate: 0.1
    num_layer: 1
    activation: selu
    # left null here; presumably supplied by an override or at runtime -
    # TODO confirm
    enc_hidden_dim: null

    optimizer:
      lr_scheduler: false
      context_decoder:
        optimizer_cls: adamw
        learning_rate: null
        weight_decay: null

    tune_hparams: false  # hparam tuning
    tune_range: 30
    hparams_grid: null
    resources_per_trial: null

exp:
  alpha: null
  update_alpha: null
  eval_only: true
  weights_ema: false

  max_epochs: null
  # mirrors the balancing method configured for model.rep_encoder
  balancing: ${model.rep_encoder.balancing}
  # Early-stopping settings for the rep_encoder training stage.
  rep_encoder:
    early_stopping:
      monitor: "val/loss"
      min_delta: null
      patience: null
      verbose: false
      mode: "min"

  # Early-stopping settings for the est_head training stage.
  est_head:
    early_stopping:
      monitor: "val/loss"
      min_delta: null
      patience: null
      verbose: false
      mode: "min"

  # Early-stopping settings for the context_decoder training stage
  # (indentation normalized to 2 spaces like the sibling sections).
  context_decoder:
    early_stopping:
      monitor: "val/loss"
      min_delta: null
      patience: null
      verbose: false
      mode: "min"

