# Grid decoder component.
# `_target_` names the class path to build; presumably consumed by a
# Hydra-style instantiate call -- confirm against the loading code.
# The `params` node is not defined in this section and must be supplied
# by whatever config composes this one.
grid_decoder:
  _target_: 'tnp.networks.grid_decoders.MHCAGridDecoder'
  # References the `mhca_layer` key defined in this file.
  mhca_layer: '${mhca_layer}'
  top_k_ctot: '${params.top_k_ctot}'
  roll_dims: '${params.roll_dims}'

# Cross-attention layer definition, referenced by `grid_decoder.mhca_layer`.
# Every value below is an interpolation into a `params` node that is not
# defined in this section.
mhca_layer:
  _target_: 'tnp.networks.attention_layers.MultiHeadCrossAttentionLayer'
  embed_dim: '${params.embed_dim}'
  num_heads: '${params.num_heads}'
  head_dim: '${params.head_dim}'
  # NOTE(review): bound to params.embed_dim, not a dedicated feed-forward
  # width -- confirm this is intentional.
  feedforward_dim: '${params.embed_dim}'
  norm_first: '${params.norm_first}'
  # The `linear` argument is fed from params.linear_attention (names differ).
  linear: '${params.linear_attention}'
