# Top-level model entry. `_target_` is the dotted path of the class to build;
# the remaining keys are filled by interpolation from the sections of the same
# name defined further down in this file.
model:
  _target_: tnp.models.cnp.CNP
  encoder: '${cnp_encoder}'
  decoder: '${cnp_decoder}'
  likelihood: '${likelihood}'

# Encoder handed to `model.encoder`. It is given the `deepset` section and the
# `x_encoder` section (input-location embedding) defined below.
cnp_encoder:
  _target_: tnp.models.cnp.CNPEncoder
  deepset: '${deepset}'
  x_encoder: '${x_encoder}'

# DeepSet used by the encoder: per-element network from `z_encoder`, combined
# with the `mean` aggregation option.
deepset:
  _target_: tnp.networks.deepset.DeepSet
  z_encoder: '${z_encoder}'
  agg: mean

# Input-location encoder: applies the two per-dimension embeddings listed
# under `embeddings` (order: x1 then x2) and maps the result to
# `params.x_embed_dim` features.
x_encoder:
  _target_: tnp.networks.mlp.MLPWithEmbedding
  embeddings:
    - '${x1_encoder}'
    - '${x2_encoder}'
  # Sum of the two embeddings' `num_wavelengths`, computed by the `eval`
  # resolver.
  # NOTE(review): assumes each FourierEmbedding emits exactly
  # `num_wavelengths` features - confirm against the embedding class.
  in_dim: "${eval:'${x1_encoder.num_wavelengths} + ${x2_encoder.num_wavelengths}'}"
  out_dim: '${params.x_embed_dim}'
  # NOTE(review): presumably 0 means no hidden layers (a single linear map);
  # verify against MLPWithEmbedding.
  num_layers: 0
  width: '${params.x_embed_dim}'

# Fourier embedding applied to input dimension 0 (`active_dim: 0`).
x1_encoder:
  _target_: tnp.networks.embeddings.FourierEmbedding
  lower: 0.01
  # Second entry minus first entry of `params.grid_range[0]`.
  # `params.grid_range` is not defined in the `params` section visible in this
  # file, so it has to be supplied by another config.
  upper: "${eval:'${params.grid_range}[0][1] - ${params.grid_range}[0][0]'}"
  assert_range: false
  num_wavelengths: '${params.num_wavelengths}'
  active_dim: 0

# Fourier embedding applied to input dimension 1 (`active_dim: 1`).
x2_encoder:
  _target_: tnp.networks.embeddings.FourierEmbedding
  lower: 0.01
  # Second entry minus first entry of `params.grid_range[1]`.
  # `params.grid_range` is not defined in the `params` section visible in this
  # file, so it has to be supplied by another config.
  upper: "${eval:'${params.grid_range}[1][1] - ${params.grid_range}[1][0]'}"
  assert_range: false
  num_wavelengths: '${params.num_wavelengths}'
  active_dim: 1

# MLP used inside `deepset`.
z_encoder:
  _target_: tnp.networks.mlp.MLP
  # Embedded-input width plus output dimensionality.
  # `params.dim_y` is not defined in the `params` section visible in this
  # file, so it has to be supplied by another config.
  in_dim: "${eval:'${params.x_embed_dim} + ${params.dim_y}'}"
  out_dim: '${params.embed_dim}'
  num_layers: '${params.num_layers}'
  width: '${params.embed_dim}'

# Decoder handed to `model.decoder`; wraps the `z_decoder` MLP defined below.
# NOTE(review): the target lives in `tnp.models.tnp` (TNPDecoder) while the
# model and encoder come from `tnp.models.cnp` - confirm this is intentional
# and not a leftover from a TNP config.
cnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: '${z_decoder}'

# MLP used by the decoder.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  # Embedded-input width plus representation width.
  in_dim: "${eval:'${params.x_embed_dim} + ${params.embed_dim}'}"
  # Twice `params.dim_y`.
  # NOTE(review): presumably two likelihood parameters per output dimension
  # (e.g. mean and noise) - confirm against the likelihood class.
  # `params.dim_y` is not defined in the `params` section visible in this
  # file, so it has to be supplied by another config.
  out_dim: "${eval:'${params.dim_y} * 2'}"
  num_layers: '${params.num_layers}'
  width: '${params.embed_dim}'

# Likelihood handed to `model.likelihood`; built with no extra arguments.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood

# Optimiser specification.
optimiser:
  _target_: torch.optim.AdamW
  # Built as a partial rather than called immediately.
  # NOTE(review): presumably so the caller can supply the model parameters
  # later - confirm against the training script.
  _partial_: true
  lr: 5.0e-4

# Shared hyperparameters, referenced elsewhere in this file as `${params.*}`.
# `params.grid_range` and `params.dim_y` are also referenced by other sections
# but are not defined here, so they have to come from another config.
params:
  # Not referenced in this file; presumably read by the training script.
  epochs: 500
  # Output size and hidden width of `z_encoder`; hidden width of `z_decoder`.
  embed_dim: 128
  # Layer count for both `z_encoder` and `z_decoder`.
  num_layers: 5
  # Output size and width of `x_encoder`.
  x_embed_dim: 128
  # Used by both `x1_encoder` and `x2_encoder`.
  num_wavelengths: 64


# Run-level settings. None of these keys are referenced elsewhere in this
# file; presumably they are consumed by the training script.
misc:
  # Run name built from the layer count and embedding width in `params`.
  name: 'CNP-L${params.num_layers}-D${params.embed_dim}'
  resume_from_checkpoint: null
  gradient_clip_val: 0.5
  plot_interval: 10
  logging: true
