# @package _global_
#
# To run this experiment:
#   python train.py experiment=tcg

# Hydra defaults list: each `override` entry replaces the option selected for
# that config group (model, datamodule, logger, trainer) by the primary config.
# NOTE(review): the trailing comments name other options -- presumably
# alternatives the author switched between; confirm before relying on them.
defaults:
  - override /model: per_node_linear_tcg # hyper_tcg
  - override /datamodule: rna_velocity # linear_velocity
  # Two logger configs are selected at once: csv and wandb.
  - override /logger:
      - csv
      - wandb
  - override /trainer: gpu
# Experiment name; also referenced as "${name}" in the wandb tags in this file.
name: "per_node_rna_tcg_gfn"

# Top-level seed value; how it is consumed is decided by the training code,
# which is not visible from this file.
seed: 0

# Overrides merged into the datamodule config selected in `defaults`
# (rna_velocity).
datamodule:
  batch_size: 500

# Model overrides, merged into the model config selected in `defaults`
# (per_node_linear_tcg).
# NOTE(review): the original comment here read only "best" -- presumably the
# best-performing setting found so far; confirm with the author.
model:
  env_batch_size: 1024
  eval_batch_size: 1000
  full_posterior_eval: false
  uniform_backwards: true
  # Energy switches: only the simple-MSE variant is enabled in this experiment.
  # How the three flags interact is defined by the model code, not shown here.
  debug_use_shd_energy: false
  analytic_use_simple_mse_energy: true
  analytic_use_bayesian_mse_energy: false
  loss_fn: "detailed_balance"
  alpha: 0
  temperature: 0.01
  temper_period: 5
  prior_lambda: 45
  beta: 0.01
  confidence: 0.0
  hidden_dim: 128
  gfn_freq: 1
  energy_freq: 1
  pretraining_epochs: 0
  # Written with an explicit ".0": YAML 1.1 loaders such as plain PyYAML
  # resolve a bare `5e-5` as a string rather than a float.
  lr: 5.0e-5
  hyper: "mlp"
  bias: true
  test_mode: false

# Trainer overrides, merged into the trainer config selected in `defaults`
# (gpu). min_epochs equals max_epochs, so both bounds are 1000 epochs.
# NOTE(review): assumes Lightning-style Trainer semantics for these keys, where
# equal bounds pin the run length -- confirm against the trainer config.
trainer:
  max_epochs: 1000
  min_epochs: 1000
  check_val_every_n_epoch: 5

# Tags set on the wandb logger config. "${name}" is an interpolation of the
# top-level `name` key, so the experiment name is included as a tag.
logger:
  wandb:
    tags:
      - "kl"
      - "analytic"
      - "rna"
      - "per-node"
      - "gfn"
      - "${name}"
      - "v_final_3"
