# Defaults for a 64-dim causal model.
# A 2-layer model solves 65k associative recall with vocab size 40 and should
# come to roughly 44936 parameters.
_name_: hedgehog
# model_dim: 64
# Computes Q, K, V from the input sequence while factoring in local dependence
# (instead of a plain projection as in Transformers).
conv_config:
  method: shift
  kwargs:
    # Number of SSMs. As in S5, the same SSM is applied to multiple dims, and
    # each dim is allocated to a different head (see n_heads, head_dim).
    n_kernels: 8
    # Dimension of the A, B, C matrices in the SSM.
    kernel_dim: 2
    # Number of heads per SSM (kernel).
    n_heads: 1
    # Head dim, used for computing dot products between heads.
    head_dim: 8
    # NOTE(review): 8 * 1 * 8 = 64 matches n_kernels * n_heads * head_dim;
    # presumably this relation is required -- confirm against the layer code.
    model_dim: 64
    kernel_weights: null
    kernel_init: normal
    kernel_train: true
    skip_connection: true
    norm_order: 0
    # Interpolated expression min(0.001, optimizer.lr); evaluation depends on
    # the `eval` resolver being registered by the consuming framework.
    lr: '${eval:"min(0.001, ${optimizer.lr})"}'
attention_config:
  method: hedgehog
  kwargs:
    # n_kernels, n_heads, head_dim and model_dim should be the same as in
    # conv_config.kwargs.
    n_kernels: 8
    n_heads: 1
    head_dim: 8
    model_dim: 64
    bidirectional: false
    # Another option noted for this field: spiked_relu_diff.
    attention: spiked_relu_lse
    affine_qkv: true
    temperature_qkv: 1.0
    context_len: null
    dropout: 0.2
    layernorm: false
    skip_connection: true
    linear_bias: false
    output_method: mixer
    # NOTE(review): these look like einops-style axis patterns
    # (nk = n_kernels, nh = n_heads, hd = head_dim) -- confirm with the consumer.
    output_mix_head_pattern: 'nk (nh hd)'
    output_mix_kernel_pattern: 'nh hd nk'
    qkv_mlp_kwargs:
      # input_dim and output_dim are the same as model_dim.
      input_dim: 64
      output_dim: 64
      # Bottleneck width; the heuristic is model_dim // 2.
      hidden_dim: 32
      activation: gelu
      bias: true
      dropout: 0.2
      layernorm: false
      n_layers: 2
      n_activations: 1
      pre_activation: false
# No feed-forward config by default. This must be spelled `null`: a bare
# `None` is loaded by YAML as the string "None", not as a null value.
# NOTE(review): confirm the consumer does not special-case the string "None".
ffn_config: null
# Commented-out MLP template for ffn_config (replaces the `null` entry above
# when enabled).
# ffn_config:
#   method: mlp
#   kwargs:
#     # input_dim, output_dim should be same as model_dim above
#     input_dim: 64
#     output_dim: 64
#     activation: gelu
#     dropout: 0.2
#     layernorm: false
#     n_layers: 2
#     n_activations: 1
#     pre_activation: false
#     input_shape: bld
#     skip_connection: true
#     average_pool: null