
# Root directory for run artifacts (relative to the working directory).
exp_path: "./runs"
seed: 0

# Weights & Biases settings.
wandb:
  use: true
  project: "pinnsem2"
  mode: "online"

# Top-level run mode; distinct from `wandb.mode` above.
mode: "train"

# PDE problem definitions. Each entry carries an `active` flag; presumably
# only entries with `active: true` are used by the loader (TODO confirm).
pdes:
  - name: "convection"
    active: true
    config:
      epsilon: 50.0
      num_points_per_dim: 256
      # x spans [0, 2*pi], t spans [0, 1]. The `_literal` wrapper is
      # interpreted by the config loader — TODO confirm its semantics.
      domain: {_literal: {x: [0.0, 6.283185307179586], t: [0.0, 1.0]}}

  - name: "sconv"
    active: false
    config:
      epsilon: 50.0
      num_points_per_dim: 256
      domain: {_literal: {x: [0.0, 6.283185307179586], t: [0.0, 1.0]}}

  - name: "allencahn"
    active: false
    config:
      fourier_embeddings:
        scale: 2.0
        dims: 256
      epsilon: 0.0001
      # num_points_per_dim: 256
      # domain: {_literal: {x: [0.0, 1.0], t: [0.0, 1.0]}}
      ref_path: "../data/allencahn/allen_cahn.mat"
      batch_size: 8192

objectives:
  # Per-term objective identifiers, keyed by term name. The same term names
  # (ics / bcs / res) appear under weighting.init_weights.
  terms:
    ics: mse
    bcs: mse
    res: wls
    # spec: mse
  student_t:
    # Initial per-term parameter values (only the `res` term is listed).
    init:
      nu:
        res: 2.0
      lam:
        res: 1.0
    update_freq: 1000
    newton_steps: 5
    # NOTE(review): init nu.res (2.0) lies below nu_clip.min (20.0). If the
    # clip is also applied to the initial value, the effective starting nu
    # is 20.0 — confirm which of the two values is intended.
    nu_clip: {min: 20.0, max: 50.0}
    priors: {a_lam: 0.0, b_lam: 0.0, a_nu: 5.0, b_nu: 20.0}

# Model architectures; each entry carries its own `active` flag.
models:
  - name: "mlp"
    active: true
    activation: "tanh"
    config:
      hidden_dim: 50
      num_layers: 4

  - name: "mmlp"
    active: false
    activation: "tanh"
    config:
      hidden_dim: 256
      num_layers: 4
      fourier_embeddings:
        scale: 2.0
        dims: 256
      # Parallel single-element lists: one entry per periodic input axis.
      periodicity:
        # NOTE(review): value is pi; the active convection domain spans
        # x in [0, 2*pi] — confirm this matches the PDE paired with this model.
        period: {_literal: [3.141592653589793]}  # period in x-axis
        axis: {_literal: [1]}  # apply periodicity on x-axis
        trainable: {_literal: [false]}  # do not train the period
      reparam:
        type: "weight_fact"
        mean: 1.0
        std: 0.1

# Top-level `init` section; separate from objectives.student_t.init.
# NOTE(review): its consumer is not visible here — confirm what this batch
# size controls.
init:
  batch_size: 4

optimizers:
  - name: "adam"
    active: true
    config:
      learning_rate: 0.001
      beta1: 0.9
      beta2: 0.999
      eps: 1.0e-8
      # With scheduler "none", decay_steps / decay_rate / staircase are
      # presumably ignored; decay_rate 1.0 would not decay the rate anyway.
      scheduler: "none"
      decay_steps: 5000
      decay_rate: 1.0
      staircase: false
      warmup_steps: 0
      schedule_free: false
      grad_accum_steps: 0

weighting:
  # Starting weight for each loss term; keys mirror objectives.terms.
  init_weights:
    ics: 1.0
    bcs: 1.0
    res: 1.0
    # spec: 0.001  # alternatives: [0.05, 0.07, 0.09, 0.11, 0.15, 0.2]
  momentum: 0.9
  use_causal: false
  scheme: null  # no scheme selected

training:
  batch_size: 4096
  num_epochs: 3000
  # Accepted values: "uniform" | "fixed".
  sampler: "fixed"

logging:
  # Toggles for what gets logged.
  log_losses: true
  log_weights: false
  log_grads: false
  log_ntk: false
  # Cadence settings.
  log_every: 1000
  # NOTE(review): save_every (10000) exceeds training.num_epochs (3000), so
  # no periodic checkpoint fires unless the trainer also saves at the end.
  save_every: 10000
  num_keep_ckpts: 1
  log_stats: true

# Miscellaneous flags.
extra:
  seed_spec: true