# Root model node. `_target_` names the class to build (Hydra-style
# instantiation is assumed); its three arguments are filled by interpolation
# from the top-level nodes `tnp_encoder`, `tnp_decoder` and `likelihood`.
model:
  _target_: tnp.models.tnp.TNP
  encoder: "${tnp_encoder}"
  decoder: "${tnp_decoder}"
  likelihood: "${likelihood}"

# Encoder wrapper: combines the PerceiverEncoder configured under
# `lbanp_encoder` with the MLP configured under `xy_encoder`.
tnp_encoder:
  _target_: tnp.models.tnp.TNPEncoder
  transformer_encoder: "${lbanp_encoder}"
  xy_encoder: "${xy_encoder}"

# Perceiver-style encoder. The latent count and depth come from `params`;
# the three layer arguments point at the attention-layer nodes of this file.
# NOTE(review): `ctoq` / `qtot` presumably name the two cross-attention
# directions used by PerceiverEncoder - confirm against its signature.
lbanp_encoder:
  _target_: tnp.networks.transformer.PerceiverEncoder
  num_latents: "${params.num_latents}"
  mhsa_layer: "${mhsa_layer}"
  mhca_ctoq_layer: "${mhca_ctoq_layer}"
  mhca_qtot_layer: "${mhca_qtot_layer}"
  num_layers: "${params.num_layers}"

# Self-attention layer settings, all taken from `params`.
mhsa_layer:
  _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
  embed_dim: "${params.embed_dim}"
  num_heads: "${params.num_heads}"
  head_dim: "${params.head_dim}"
  # Tied to the same value as embed_dim (both read params.embed_dim).
  feedforward_dim: "${params.embed_dim}"
  norm_first: "${params.norm_first}"

# Cross-attention layer passed to PerceiverEncoder as `mhca_ctoq_layer`.
# Uses the same `params` values as the other attention layers in this file.
mhca_ctoq_layer:
  _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
  embed_dim: "${params.embed_dim}"
  num_heads: "${params.num_heads}"
  head_dim: "${params.head_dim}"
  # Tied to the same value as embed_dim (both read params.embed_dim).
  feedforward_dim: "${params.embed_dim}"
  norm_first: "${params.norm_first}"

# Cross-attention layer passed to PerceiverEncoder as `mhca_qtot_layer`.
# Uses the same `params` values as the other attention layers in this file.
mhca_qtot_layer:
  _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
  embed_dim: "${params.embed_dim}"
  num_heads: "${params.num_heads}"
  head_dim: "${params.head_dim}"
  # Tied to the same value as embed_dim (both read params.embed_dim).
  feedforward_dim: "${params.embed_dim}"
  norm_first: "${params.norm_first}"

# MLP handed to TNPEncoder as `xy_encoder`; its output width is embed_dim.
xy_encoder:
  _target_: tnp.networks.mlp.MLP
  # Input width is dim_x + dim_y + 1, computed through the `eval` resolver.
  # `eval` is not an OmegaConf built-in, so the application must register it
  # before this node is resolved. `params.dim_x` / `params.dim_y` are not
  # defined in this file - presumably merged in from another config.
  # NOTE(review): the extra +1 input looks like a flag channel - confirm
  # against TNPEncoder.
  in_dim: "${eval:'${params.dim_x} + ${params.dim_y} + 1'}"
  out_dim: "${params.embed_dim}"
  num_layers: 2
  width: "${params.embed_dim}"

# Decoder wrapper: its only argument is the MLP configured under `z_decoder`.
tnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: "${z_decoder}"

# MLP handed to TNPDecoder as `z_decoder`; it reads embed_dim-wide inputs.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: "${params.embed_dim}"
  # Output width is 2 * dim_y, computed through the custom `eval` resolver.
  # `params.dim_y` is not defined in this file - presumably supplied by
  # another config.
  # NOTE(review): the factor 2 presumably covers two likelihood parameters
  # per output dimension - confirm against the likelihood class.
  out_dim: "${eval:'2 * ${params.dim_y}'}"
  num_layers: 2
  width: "${params.embed_dim}"

# Likelihood object referenced by `model.likelihood`; it is constructed with
# no arguments from this file.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood

# Optimiser factory. With Hydra-style instantiation, `_partial_: true` yields
# a partially-applied AdamW constructor rather than an optimiser instance;
# the remaining arguments (presumably the model parameters) are supplied by
# the caller.
optimiser:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 5.0e-4

# Shared hyperparameters referenced through `${params.*}` interpolations.
# `dim_x` and `dim_y` are also referenced by this file but are not defined
# here - presumably provided by another config that is merged with this one.
params:
  # Not referenced elsewhere in this file; presumably read by the training
  # script.
  epochs: 200
  # With these values num_heads * head_dim (8 * 16) equals embed_dim (128).
  # Whether the attention layers require that is not visible from this file.
  embed_dim: 128
  num_heads: 8
  head_dim: 16
  norm_first: true
  num_layers: 5
  num_latents: 32


# Run-level settings. None of these keys is referenced elsewhere in this
# file; they are presumably consumed by the training script.
misc:
  # Run name assembled from the depth, head count, embedding width and
  # latent count in `params`.
  name: "LBANP-L${params.num_layers}-H${params.num_heads}-D${params.embed_dim}-M${params.num_latents}"
  resume_from_checkpoint: null
  logging: false
  plot_interval: 10
  gradient_clip_val: 0.5
