# @package _global_
# Hydra defaults list: each `override /<group>` entry replaces the option
# selected for that config group elsewhere in the composition.
defaults:
  - override /model: trellis_gnn_mlp_mfm.yaml
  - override /datamodule: trellis_dataloader.yaml
  # Two logger options are selected at once (csv and wandb).
  - override /logger:
      - csv
      - wandb
  - override /trainer: gpu

# Launcher name for this experiment.
# NOTE(review): presumably used as the job name by the configured Hydra
# launcher plugin -- confirm which launcher is active.
hydra:
  launcher:
    name: "mfm_trellis"

# Top-level seed. `datamodule.seed` and `model.seed` below are set
# separately to the same value; they are not interpolated from this key.
seed: 0

# Overrides applied to the datamodule chosen in the defaults list
# (trellis_dataloader.yaml).
datamodule:
  batch_size: 1
  # NOTE(review): "ivp" presumably stands for initial value problem -- confirm.
  ivp_batch_size: 1024
  split: patients
  # null here means `model.pca_space_eval` is not used (see the note on
  # that key in the model section).
  num_components: null
  plot_pca: false
  use_small_exp_num: false
  # Set independently of the top-level `seed` key (same value in this file).
  seed: 0

# Overrides applied to the model chosen in the defaults list
# (trellis_gnn_mlp_mfm.yaml).
model:
  name: mfm_trellis
  # Separate learning rates for the flow and GNN parts; equal in this file.
  flow_lr: 1e-4
  gnn_lr: 1e-4
  dim: 43
  num_hidden: 512
  num_layers_decoder: 7
  num_hidden_gnn: 128
  num_layers_gnn: 2
  knn_k: 100
  num_treat_conditions: 11
  num_cell_conditions: 2
  base: source
  pca_space_eval: true  # only used if num_components is not null
  run_validation: true
  integrate_time_steps: 500
  # Set independently of the top-level `seed` key (same value in this file).
  seed: 0

# Overrides applied on top of the `gpu` trainer option chosen in the
# defaults list.
trainer:
  # min_epochs equals max_epochs, so the run length is pinned at 1500 epochs.
  max_epochs: 1500
  min_epochs: 1500
  # Same value as max_epochs: the validation interval spans the whole run.
  # NOTE(review): assumes these keys map to PyTorch Lightning Trainer
  # arguments -- confirm against the trainer config.
  check_val_every_n_epoch: 1500
  accelerator: gpu
  devices: 1
  # log_every_n_steps: 5

# File name used for saved checkpoints.
# NOTE(review): the consumer of this key is not visible here -- confirm
# whether an extension is appended automatically.
checkpoint:
  filename: "ckpt"

# NOTE(review): presumably lets a preempted or timed-out SLURM job be
# requeued -- confirm how this flag is consumed.
slurm:
  requeue: true

# NOTE: early stopping is NOT being used.
# The early_stopping entry below is still configured, but trainer.min_epochs
# equals trainer.max_epochs in this file, so it is not expected to end a
# run early (TODO confirm the callback is not dropped elsewhere).
callbacks:
  model_checkpoint:
    monitor: "val/2-Wasserstein-PDO"

  early_stopping:
    # Same metric as model_checkpoint; lower is treated as better.
    monitor: "val/2-Wasserstein-PDO"
    mode: "min"
    patience: 100
    min_delta: 0

# Tags passed to the wandb logger selected in the defaults list.
logger:
  wandb:
    tags: ["trellis", "mfm"]