
# Run-level settings.
exp_path: "./runs3"
seed: 0
wandb:
  use: true
  project: "pinnsem3"
  mode: "online"
# NOTE(review): this top-level `mode` is a separate key from `wandb.mode` above.
mode: "train"
# Disabled; presumably only consulted for evaluation runs — confirm against the loader.
# eval_dir: "./runs3/rt_mmlp_soap_tanh/seed_0-6"

# PDE problems; a single entry named "rt" is configured
# (presumably Rayleigh-Taylor, judging by ref_path — confirm).
pdes:
  - name: "rt"
    config:
      epsilon: 3.0
      num_points_per_dim: 256
      # Disabled domain override, kept for reference.
      # domain: {_literal: {x: [0.0, 1.0], t: [0.0, 1.0]}}
      ref_path: "../data/rayleigh_taylor/rayleigh_taylor.npy"
      # NOTE(review): same value as training.batch_size (8192) — confirm which one the loader reads.
      batch_size: 8192

# Loss configuration.
objectives:
  # One loss type per named term. The names presumably denote initial-condition (*ics),
  # boundary-condition (*bcs) and residual (res*) terms — confirm against the PDE definition.
  terms:
    uics: mse
    vics: mse
    pics: mse
    Tics: mse
    ubcs: mse
    vbcs: mse
    Tbcs: mse
    resu: mse
    resv: mse
    resc: mse
    rese: mse
  # Settings for a Student-t objective (initial values, update cadence, clipping, priors).
  student_t:
    init:
      nu:
        # NOTE(review): 5 lies below nu_clip.min (20.0) — confirm the initial value is
        # meant to start outside the clip range.
        # Candidate values noted by the author: [50, 45, 40, 35, 30, 25, 20, 15, 10, 5]
        res: 5
      lam:
        # Candidate values noted by the author: [0.002, 0.02, 0.08, 0.5, 0.15, 0.4, 1.0]
        res: 0.0002
    update_freq: 1000
    newton_steps: 5
    nu_clip:
      min: 20.0
      max: 50.0
    priors:
      a_lam: 0.0
      b_lam: 0.0
      a_nu: 5.0
      b_nu: 20.0
# Network architectures; each list entry carries its own `active` flag.
models:
  - name: "mlp"
    active: true
    activation: "tanh"
    input_dim: 3
    config:
      hidden_dim: 256
      num_layers: 3
      # NOTE(review): presumably one output per predicted field — confirm against the PDE definition.
      out_dim: 4


  # Second architecture entry; inactive in this configuration.
  - name: "mmlp"
    active: false
    activation: "swish"
    input_dim: 3
    config:
      out_dim: 4
      hidden_dim: 256
      num_layers: 3
      fourier_embeddings:
        scale: 2.0
        dims: 256
      # Each value is wrapped in a `_literal` mapping holding a one-element list.
      periodicity:
        period:
          _literal: [6.283185307179586]  # period in x-axis (numerically 2*pi)
        axis:
          _literal: [1]  # apply periodicity on x-axis
        trainable:
          _literal: [false]  # do not train the period
      reparam:
        type: "weight_fact"
        mean: 1.0
        std: 0.1
# Initialisation settings.
# NOTE(review): this batch_size (4) is separate from pdes[].config.batch_size and
# training.batch_size (both 8192) — confirm what it is used for.
init:
  batch_size: 4

# Optimizer candidates; each list entry carries its own `active` flag.
optimizers:
  - name: "adamw"
    active: false
    config:
      learning_rate: 0.001
      beta1: 0.9
      beta2: 0.999
      eps: 1.0e-8
      scheduler: "exponential"
      decay_steps: 5000
      decay_rate: 0.9
      staircase: false
      warmup_steps: 5000
      schedule_free: false
      grad_accum_steps: 0
      # Regularizer: weight decay coefficient.
      # Candidate values noted by the author: [0.01, 0.001, 0.0005, 0.00001]
      weight_decay: 0.0005
      # Disabled option; accepted values per the original note: "l1" or "l2".
      # reg_type: "l2"

  # The only optimizer entry with `active: true` in this list.
  - name: "adam"
    active: true
    config:
      learning_rate: 0.001
      beta1: 0.9
      beta2: 0.99
      eps: 1.0e-8
      # Exponential schedule parameters (decay_steps, decay_rate, staircase, warmup_steps).
      scheduler: "exponential"
      decay_steps: 2000
      decay_rate: 0.9
      staircase: false
      warmup_steps: 2000
      schedule_free: false
      grad_accum_steps: 0
  # Inactive optimizer entry.
  - name: "soap"
    active: false
    config:
      learning_rate: 0.001
      beta1: 0.9
      beta2: 0.999
      eps: 1.0e-8
      scheduler: "exponential"
      decay_steps: 5000
      decay_rate: 0.9
      staircase: false
      warmup_steps: 5000
      # NOTE(review): schedule_free is true while a scheduler is also configured —
      # confirm how the two settings interact.
      schedule_free: true
      grad_accum_steps: 0
  # Inactive optimizer entry.
  - name: "sgd"
    active: false
    config:
      learning_rate: 0.00001
      momentum: 0.0
      # No scheduler selected; "exponential" is the alternative noted by the author.
      scheduler: null
      # NOTE(review): the decay/warmup values below are presumably ignored while
      # scheduler is null — confirm.
      decay_steps: 2000
      decay_rate: 0.9
      staircase: false
      warmup_steps: 2000
      schedule_free: true
      grad_accum_steps: 0

# Loss-term weighting.
weighting:
  # One initial weight per loss term; the keys mirror those under objectives.terms.
  init_weights:
    uics: 1.0
    vics: 1.0
    pics: 1.0
    Tics: 1.0
    ubcs: 1.0
    vbcs: 1.0
    Tbcs: 1.0
    resu: 1.0
    resv: 1.0
    resc: 1.0
    rese: 1.0
  momentum: 0.9
  use_causal: false
  # No weighting scheme is selected here.
  scheme: null # alternative noted by the author: groupdro
  # NOTE(review): presumably only relevant when a scheme is set — confirm.
  update_freq: 1000

# Training-loop settings.
training:
  # NOTE(review): same value as pdes[].config.batch_size (8192) — confirm which one the loader reads.
  batch_size: 8192
  num_epochs: 500000
  sampler: "uniform"  # "uniform" or "fixed"
  num_time_windows: 4

# Logging and checkpointing switches.
logging:
  log_losses: true
  log_weights: true
  log_grads: true
  log_ntk: false
  # NOTE(review): log_every / save_every are presumably counted in training iterations — confirm.
  log_every: 1000
  save_every: 10000
  num_keep_ckpts: 20
  log_stats: true

# Rotation settings.
# NOTE(review): the key names suggest a preconditioner-related feature; their exact
# meaning is not evident from this file — TODO confirm against the consumer.
rotation:
  rot_rho: 0.99
  rot_precond_freq: 2
  rot_eps: 1.0e-6
  rot_max_dim: 512

# NOTE(review): purpose of this flag is not evident from this file — TODO document.
hetero: false

# Subspace settings; disabled in this configuration.
subspace:
  enabled: false
  rank: 8
  oja_lr: 0.05
  normalize_grads: true
  log_matrices: false
  # NOTE(review): separate from the top-level `seed` key (also 0) — confirm both are needed.
  seed: 0
# Regularisation settings.
# NOTE(review): presumably `lam` is a regularisation coefficient — confirm against the consumer.
reg:
  lam: 0.0001

# Miscellaneous flags.
# NOTE(review): meaning of `seed_spec` is not evident from this file — TODO document.
extra:
  seed_spec: true