# @package _global_
#
# to execute this experiment, run:
# python train.py experiment=tcg

# Hydra defaults list: selects the config-group options used by this experiment.
defaults:
  # other option noted by the author: hyper_tcg
  - override /model: per_node_hyper_tcg
  # other option noted by the author: linear_velocity
  - override /datamodule: rna_velocity
  # two logger configs are selected together
  - override /logger: [csv, wandb]
  - override /trainer: gpu
# Experiment name; it is also interpolated into the wandb tags via ${name}.
name: 'hyper_per_node_rna_tcg_gfn'

# Seed for the run (presumably the global RNG seed; confirm in train.py).
seed: 0

# Overrides applied to the datamodule config.
datamodule:
  batch_size: 250

# Model overrides. The original note marks these values as "best"
# (presumably the best-performing settings found so far; confirm).
model:
  env_batch_size: 1024
  eval_batch_size: 1000
  full_posterior_eval: false
  uniform_backwards: true
  debug_use_shd_energy: false
  analytic_use_simple_mse_energy: false
  loss_fn: "detailed_balance"
  alpha: 0
  temperature: 0.1
  temper_period: 5
  prior_lambda: 10
  beta: 0.01
  confidence: 0.0
  hidden_dim: 128
  gfn_freq: 10  # may depend on batch-size
  energy_freq: 10  # may depend on batch-size
  load_pretrain: true
  pretraining_epochs: 0
  # Written with an explicit mantissa dot so that strict YAML 1.1 loaders
  # (e.g. stock PyYAML) also read a float; a bare 5e-6 is a string to them.
  lr: 5.0e-6
  hyper: "per_node_mlp"
  hyper_hidden_dim: [64, 64, 64]
  bias: true
  # Placeholder: replace with a real path before running. Quoted so the
  # angle brackets are always read as part of a plain string.
  # NOTE(review): presumably the checkpoint read when load_pretrain is
  # true; confirm against the model code.
  path: "<path-to-linsol-model>"
  test_mode: true

# Overrides applied to the trainer config.
trainer:
  # min and max epochs are set to the same value
  min_epochs: 1000
  max_epochs: 1000
  check_val_every_n_epoch: 5

# Tags for the wandb logger; "${name}" interpolates the top-level `name` key.
logger:
  wandb:
    tags:
      - kl
      - hyper
      - rna
      - per-node
      - gfn
      - "${name}"
      - v_final_3
