# @package _global_
#
# to execute this experiment run:
# python train.py experiment=tcg

# Hydra defaults list: each entry overrides one config group for this experiment.
# The trailing comments name another option for the same group (presumably an
# alternative used in other runs -- TODO confirm).
defaults:
  - override /model: per_node_hyper_tcg #hyper_tcg
  - override /datamodule: linear_unidentifiable_velocity #linear_velocity
  # The logger group is given a list, so both csv and wandb are selected.
  - override /logger:
      - csv
      - wandb
  - override /trainer: gpu
# Run name; also interpolated as ${name} into the wandb tags further down.
name: 'hyper_per_node_linear_tcg_gfn'

# Default seed. For experiments, use seed = 13, 29, 42, 73, 91.
seed: 0

datamodule:
  batch_size: 100  # 500
  T: 2
  p: 20
  vars_to_deidentify:
    - 0
    - 1
    - 2
  # sparsity 0.9 --> 1024 nodes for p=20 and vars_to_deidentify=[0, 1, 2]
  sparsity: 0.9
  system: 'linear'
  sigma: 0
  # Datamodule seed; set independently of the top-level `seed` key.
  seed: 13

# Model settings (labelled "best" by the original author).
# Trailing comments holding a bare number record another value noted by the
# author for that key.
model:
  env_batch_size: 256
  eval_batch_size: 2500
  full_posterior_eval: false
  uniform_backwards: true
  debug_use_shd_energy: false
  analytic_use_simple_mse_energy: false
  loss_fn: 'detailed_balance'
  alpha: 0
  temperature: 0.005
  temper_period: 5
  prior_lambda: 100  # 400
  beta: 0.01
  confidence: 0.0
  hidden_dim: 128
  n_steps: 0  # always set to zero in this version
  gfn_freq: 5  # may depend on batch-size
  energy_freq: 5  # may depend on batch-size
  load_pretrain: true
  pretraining_epochs: 0
  # Written with an explicit ".0" so that YAML 1.1 loaders resolve the value
  # as a float; a bare `1e-5` is read as the string "1e-5" by such loaders.
  lr: 1.0e-5
  hyper: 'per_node_mlp'
  hyper_hidden_dim: [64, 64, 64]
  bias: true
  # NOTE(review): placeholder value. Presumably it must be replaced with the
  # location of a pretrained model, since load_pretrain is true -- confirm.
  path: <path-to-linsol-model>

trainer:
  # Lower and upper epoch bounds are both set to 1000.
  min_epochs: 1000
  max_epochs: 1000
  # Validation interval, in epochs.
  check_val_every_n_epoch: 5

logger:
  wandb:
    # Tags for the wandb logger; '${name}' interpolates the top-level `name` key.
    tags:
      - 'kl'
      - 'hyper'
      - 'linear'
      - 'per-node'
      - 'gfn'
      - '${name}'
      - 'v_hyp_fix'
