# Hydra defaults list: configs composed into this one, in order.
# Later entries override earlier ones, so keeping `_self_` last lets the
# values defined in this file win over `base` and the `step` group.
defaults:
  - base
  # Leading `/` selects the `step` config group by absolute path.
  - /step: sequential
  - _self_

# Identifier for this model configuration (presumably short for the Galerkin
# transformer referenced by params._target_ — TODO confirm).
name: gkt

# Model settings. `_target_` names the class to instantiate; the remaining
# keys are presumably forwarded to it as keyword arguments (the consuming
# code is not visible here — confirm against SimpleTransformer's signature).
params:
  _target_: models.galerkin_transformer.model.SimpleTransformer
  n_layers: 4
  d_model: 32  # input dimension
  memory_augmented: false

  # NOTE(review): the commented-out keys below (and `norm` further down) are
  # not part of the active config; presumably leftovers from a different
  # attention implementation — confirm before deleting.
  #
  # dim_head: dimension in each attention head; expanded by kernel_multiplier
  #   when computing the kernel: d = dim_head * kernel_multiplier
  # dim_head: 128
  # latent_dim: output dimension of the projection operator
  # latent_dim: ${.d_model}
  # heads: number of attention heads
  # heads: 4
  # dim_out: output dimension
  # dim_out: ${.d_model}
  # kernel_multiplier: use more function bases to compute the kernel:
  #   k(x_i, x_j) = \sum_{c}^{d} q_c(x_i) k_c(x_j)
  # kernel_multiplier: 3
  # use_rope: use rotary positional encoding or not (default: true)
  # use_rope: true
  # scaling_factor: modulates the kernel, e.g. 1 / sqrt(d) as in scaled
  #   dot-product attention (default: 1)
  # scaling_factor: 1
  d_state: 1
  # norm: true

  # `${.key}` is an OmegaConf relative interpolation: it resolves to the
  # sibling key inside `params`, so these track d_model / n_layers above.
  n_hidden: ${.d_model}
  num_encoder_layers: ${.n_layers}
  num_regressor_layers: 2
  n_head: 1
  dim_feedforward: ${.d_model}
  layer_norm: true
  attention_type: galerkin
  batch_norm: false
  pos_dim: 1
  xavier_init: 0.001
  diagonal_weight: 0.01
  dropout: 0.0
  ffn_dropout: 0.0

  node_feats: 1
  edge_feats: null
  n_targets: 1
  num_feat_layers: 0
  pred_len: 0
  n_freq_targets: 0
  feat_extract_type: null
  symmetric_init: false
  attn_norm: true
  spacial_residual: false
  return_attn_weight: false
  return_latent: false
  residual_type: plus
  seq_len: null
  bulk_regression: false
  decoder_type: ifft
  freq_dim: 48
  # -1: per the original note, use all Fourier modes (TODO confirm in model).
  fourier_modes: -1
  spacial_dim: 1
  spacial_fc: false
  encoder_dropout: 0.0
  decoder_dropout: 0.0
  debug: false

    

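# NOTE(review): commented-out top-level value; presumably a leftover
# alternative for params.d_state (currently 1) — confirm and remove if unused.
# d_state: 128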

# Optimizer settings. `_target_` names the factory to call; how it uses the
# remaining keys is defined by that function (not visible here).
optimizer:
  _target_: optimizers.setup_s4_optimizer
  lr: 0.001
  weight_decay: 0.0

# Top-level training settings, read by code outside this file.
batch_size: 32

# Scheduler selection; previously tried alternatives are kept commented out.
# NOTE: a bare `None` in YAML is the string "None", not null.
# scheduler: None
# scheduler: cosine
scheduler: step
# NOTE(review): step_size and gamma are presumably the parameters of the
# `step` scheduler (decay interval and decay factor) — confirm against the
# training code that reads them.
step_size: 200
gamma: 0.5

warmup_epochs: 1