# Data generators for the train / val / test splits. Every split builds a
# tnp.data.gp.RandomScaleGPGenerator (`_target_` names the class to build)
# with the same kernel list, noise level and point-count bounds; the splits
# differ only in `samples_per_epoch` and in whether `deterministic` is set.
generators:
  # Training split: 16000 samples per epoch, no `deterministic` key given.
  train:
    _target_: tnp.data.gp.RandomScaleGPGenerator
    dim: ${params.dim_x}
    # Kernel constructors, defined as top-level entries of this file.
    kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
    noise_std: 0.1
    # presumably nc / nt are the context / target point counts -- confirm
    # against the generator. min_nt == max_nt, so that bound is a single value.
    min_nc: 1
    max_nc: 64
    min_nt: 128
    max_nt: 128
    context_range: ${params.context_range}
    target_range: ${params.target_range}
    samples_per_epoch: 16000
    batch_size: 16

  # Validation split: 4096 samples per epoch, `deterministic` enabled.
  val:
    _target_: tnp.data.gp.RandomScaleGPGenerator
    dim: ${params.dim_x}
    kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
    noise_std: 0.1
    min_nc: 1
    max_nc: 64
    min_nt: 128
    max_nt: 128
    context_range: ${params.context_range}
    target_range: ${params.target_range}
    samples_per_epoch: 4096
    batch_size: 16
    deterministic: true

  # Test split: same settings as `val`.
  test:
    _target_: tnp.data.gp.RandomScaleGPGenerator
    dim: ${params.dim_x}
    kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
    noise_std: 0.1
    min_nc: 1
    max_nc: 64
    min_nt: 128
    max_nt: 128
    context_range: ${params.context_range}
    target_range: ${params.target_range}
    samples_per_epoch: 4096
    batch_size: 16
    deterministic: true

# RBF kernel spec, marked `_partial_` (so presumably built as a partial
# constructor rather than an instance -- confirm against the consumer).
# Lengthscale bounds are given in log10 units and come from `params`.
rbf_kernel:
  _target_: tnp.networks.gp.RBFKernel
  _partial_: true
  ard_num_dims: ${params.dim_x}
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}

# Matern kernel spec with nu = 0.5 (presumably the Matern smoothness
# parameter -- confirm against tnp.networks.gp.MaternKernel).
matern12_kernel:
  _target_: tnp.networks.gp.MaternKernel
  _partial_: true
  ard_num_dims: ${params.dim_x}
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}
  nu: 0.5

# Matern kernel spec with nu = 1.5 (presumably the Matern smoothness
# parameter -- confirm against tnp.networks.gp.MaternKernel).
matern32_kernel:
  _target_: tnp.networks.gp.MaternKernel
  _partial_: true
  ard_num_dims: ${params.dim_x}
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}
  nu: 1.5

# Matern kernel spec with nu = 2.5 (presumably the Matern smoothness
# parameter -- confirm against tnp.networks.gp.MaternKernel).
matern52_kernel:
  _target_: tnp.networks.gp.MaternKernel
  _partial_: true
  ard_num_dims: ${params.dim_x}
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}
  nu: 2.5

# Periodic kernel spec. In addition to the log10 lengthscale bounds shared
# with the other kernels, it takes log10 bounds on the period.
periodic_kernel:
  _target_: tnp.networks.gp.PeriodicKernel
  _partial_: true
  ard_num_dims: ${params.dim_x}
  # Lengthscale bounds (log10 units).
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}
  # Period bounds (log10 units).
  min_log10_period: ${params.min_log10_period}
  max_log10_period: ${params.max_log10_period}

# Top-level model: a tnp.models.cnp.CNP assembled from the encoder, decoder
# and likelihood entries defined further down in this file.
model:
  _target_: tnp.models.cnp.CNP
  encoder: ${cnp_encoder}
  # NOTE(review): the decoder entry targets tnp.models.tnp.TNPDecoder, not a
  # class from tnp.models.cnp -- presumably intentional reuse; confirm.
  decoder: ${tnp_decoder}
  likelihood: ${likelihood}

# Encoder for the CNP model; wraps the `deepset` entry below.
cnp_encoder:
  _target_: tnp.models.cnp.CNPEncoder
  deepset: ${deepset}

# DeepSet network: uses the `z_encoder` MLP and `mean` as its `agg` setting
# (presumably the aggregation applied across set elements -- confirm).
deepset:
  _target_: tnp.networks.deepset.DeepSet
  z_encoder: ${z_encoder}
  agg: mean

# MLP referenced by `deepset` as its z_encoder.
z_encoder:
  _target_: tnp.networks.mlp.MLP
  # Input size is dim_x + dim_y, computed through the `eval` resolver
  # (presumably registered by the project -- confirm).
  in_dim: "${eval:'${params.dim_x} + ${params.dim_y}'}"
  # Hidden width and output size are both the embedding dimension.
  width: ${params.embed_dim}
  out_dim: ${params.embed_dim}
  num_layers: ${params.num_layers}

# Decoder referenced by `model`; wraps the `z_decoder` MLP below.
tnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: ${z_decoder}

# MLP referenced by `tnp_decoder` as its z_decoder.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  # Input size is dim_x + embed_dim, computed through the `eval` resolver
  # (presumably registered by the project -- confirm).
  in_dim: "${eval:'${params.dim_x} + ${params.embed_dim}'}"
  width: ${params.embed_dim}
  # Output size is 2 * dim_y (presumably two likelihood parameters per
  # output dimension -- confirm against the likelihood).
  out_dim: "${eval:'${params.dim_y} * 2'}"
  num_layers: ${params.num_layers}

# Likelihood referenced by `model`; takes no arguments here.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood

# AdamW optimiser spec with learning rate 5.0e-4, marked `_partial_`
# (presumably so the model parameters can be supplied later -- confirm
# against the training script).
optimiser:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 5.0e-4

# Shared hyperparameters, referenced elsewhere in this file as ${params.*}.
params:
  # Model + training
  epochs: 200
  embed_dim: 128
  num_layers: 5
  # NOTE(review): num_heads, head_dim and norm_first are not referenced by any
  # interpolation in the visible part of this file -- confirm they are needed.
  num_heads: 8
  head_dim: 16
  norm_first: true

  # Synthetic 1D GP
  dim_x: 1
  dim_y: 1
  # presumably one [low, high] pair per input dimension -- confirm.
  context_range:
    - [-2.0, 2.0]
  target_range:
    - [-2.0, 2.0]
  # Lengthscale bounds in log10 units: 10^-0.602 ~= 0.25, 10^0.0 = 1.0.
  min_log10_lengthscale: -0.602
  max_log10_lengthscale: 0.0
  # Period bounds in log10 units; min == max, and 10^0.301 ~= 2.0.
  min_log10_period: 0.301
  max_log10_period: 0.301


# Run-level settings: experiment naming, checkpointing, training-loop knobs
# and plotting / evaluation options. Each key appears exactly once; duplicate
# mapping keys are invalid YAML and are rejected by strict loaders.
misc:
  # Experiment identity
  project: cnp-combined-rangesame
  # Run name is built from the depth and embedding width in `params`.
  name: CNP-L${params.num_layers}-D${params.embed_dim}

  # Training run
  resume_from_checkpoint: null
  gradient_clip_val: 0.5
  check_val_every_n_epoch: 1
  seed: 1
  logging: true

  # Plotting / evaluation
  # presumably `eval_name` selects the entry under `generators` -- confirm.
  eval_name: test
  plot_interval: 10
  only_plots: false
  num_plots: 10
  subplots: true
  savefig: true