# Settings for the grid encoder component.
# `_target_` is the dotted path of the class these settings are meant for
# (presumably consumed by Hydra-style instantiation -- confirm with the loader).
# Every other value is an interpolation resolved when the config is read.
grid_encoder:
  _target_: tnp.networks.grid_encoders.PseudoTokenGridEncoder
  embed_dim: '${params.embed_dim}'
  # Points at the top-level `grid_encoder_mhca_layer` node.
  mhca_layer: '${grid_encoder_mhca_layer}'
  grid_range: '${params.grid_range}'
  # NOTE(review): the constructor argument is `points_per_dim`, but the value
  # comes from `params.grid_encoder_grid_shape` -- the names differ on purpose?
  points_per_dim: '${params.grid_encoder_grid_shape}'

# Cross-attention layer settings, referenced by `grid_encoder.mhca_layer`.
# All hyperparameters are read from the shared `params` node, which is not
# defined in this part of the configuration.
grid_encoder_mhca_layer:
  _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
  embed_dim: '${params.embed_dim}'
  num_heads: '${params.num_heads}'
  head_dim: '${params.head_dim}'
  # Reads the same source value as `embed_dim` above.
  # NOTE(review): confirm a dedicated feedforward width is not wanted here.
  feedforward_dim: '${params.embed_dim}'
  norm_first: '${params.norm_first}'
  linear: '${params.linear_attention}'
