# Data generators, one entry per split. Every entry is a `_target_` spec for
# tnp.data.gp.RandomScaleGPGenerator (Hydra-style instantiation, presumably).
# The three splits share kernels, noise level, point-count bounds and ranges;
# they differ only in samples_per_epoch and the `deterministic` flag.
generators:
  train:
    _target_: tnp.data.gp.RandomScaleGPGenerator
    dim: ${params.dim_x}
    # Kernel specs defined as top-level keys further down in this file.
    kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
    noise_std: 0.1
    # nc / nt bounds: presumably numbers of context / target points per task
    # -- confirm against RandomScaleGPGenerator.
    min_nc: 1
    max_nc: 64
    min_nt: 128
    max_nt: 128
    context_range: ${params.context_range}
    target_range: ${params.target_range}
    # Written without a digit separator so it is an integer under both
    # YAML 1.1 and YAML 1.2 parsers.
    samples_per_epoch: 16000
    batch_size: 16
  val:
    _target_: tnp.data.gp.RandomScaleGPGenerator
    dim: ${params.dim_x}
    kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
    noise_std: 0.1
    min_nc: 1
    max_nc: 64
    min_nt: 128
    max_nt: 128
    context_range: ${params.context_range}
    target_range: ${params.target_range}
    samples_per_epoch: 4096
    batch_size: 16
    # Only set for val/test; exact semantics are defined by the generator.
    deterministic: true
  test:
    _target_: tnp.data.gp.RandomScaleGPGenerator
    dim: ${params.dim_x}
    kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
    noise_std: 0.1
    min_nc: 1
    max_nc: 64
    min_nt: 128
    max_nt: 128
    context_range: ${params.context_range}
    target_range: ${params.target_range}
    samples_per_epoch: 4096
    batch_size: 16
    # Only set for val/test; exact semantics are defined by the generator.
    deterministic: true

# RBF kernel spec; referenced from each generators.*.kernel list as
# ${rbf_kernel}.
rbf_kernel:
  _target_: tnp.networks.gp.RBFKernel
  # Partial spec (Hydra convention, presumably): the consumer completes the
  # construction later.
  _partial_: true
  ard_num_dims: ${params.dim_x}
  # Lengthscale bounds are shared with the other kernels through params.
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}

# Matern kernel spec with nu = 0.5; referenced from each generators.*.kernel
# list as ${matern12_kernel}.
matern12_kernel:
  _target_: tnp.networks.gp.MaternKernel
  # Partial spec (Hydra convention, presumably): the consumer completes the
  # construction later.
  _partial_: true
  nu: 0.5
  ard_num_dims: ${params.dim_x}
  # Lengthscale bounds are shared with the other kernels through params.
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}

# Matern kernel spec with nu = 1.5; referenced from each generators.*.kernel
# list as ${matern32_kernel}.
matern32_kernel:
  _target_: tnp.networks.gp.MaternKernel
  # Partial spec (Hydra convention, presumably): the consumer completes the
  # construction later.
  _partial_: true
  nu: 1.5
  ard_num_dims: ${params.dim_x}
  # Lengthscale bounds are shared with the other kernels through params.
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}

# Matern kernel spec with nu = 2.5; referenced from each generators.*.kernel
# list as ${matern52_kernel}.
matern52_kernel:
  _target_: tnp.networks.gp.MaternKernel
  # Partial spec (Hydra convention, presumably): the consumer completes the
  # construction later.
  _partial_: true
  nu: 2.5
  ard_num_dims: ${params.dim_x}
  # Lengthscale bounds are shared with the other kernels through params.
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}

# Periodic kernel spec; referenced from each generators.*.kernel list as
# ${periodic_kernel}. The only kernel here that also takes period bounds.
periodic_kernel:
  _target_: tnp.networks.gp.PeriodicKernel
  # Partial spec (Hydra convention, presumably): the consumer completes the
  # construction later.
  _partial_: true
  ard_num_dims: ${params.dim_x}
  # Lengthscale bounds are shared with the other kernels through params.
  min_log10_lengthscale: ${params.min_log10_lengthscale}
  max_log10_lengthscale: ${params.max_log10_lengthscale}
  # Period bounds, also taken from params.
  min_log10_period: ${params.min_log10_period}
  max_log10_period: ${params.max_log10_period}

# Top-level model spec: a tnp.models.convcnp.ConvCNP assembled from the
# encoder, decoder and likelihood specs defined as top-level keys below.
model:
  _target_: tnp.models.convcnp.ConvCNP
  encoder: ${convcnp_encoder}
  # Note: the decoder spec is the TNP decoder (tnp.models.tnp.TNPDecoder).
  decoder: ${tnp_decoder}
  likelihood: ${likelihood}

# Encoder spec used by model.encoder; wires together the CNN, the two
# set-conv grid modules and the z_encoder MLP defined below.
convcnp_encoder:
  _target_: tnp.models.convcnp.ConvCNPEncoder
  conv_net: ${cnn}
  grid_encoder: ${grid_encoder}
  grid_decoder: ${grid_decoder}
  z_encoder: ${z_encoder}

# CNN spec used as convcnp_encoder.conv_net; all sizes come from params.
cnn:
  _target_: tnp.networks.cnn.CNN
  dim: ${params.dim_x}
  num_channels: ${params.num_channels}
  num_blocks: ${params.num_blocks}
  kernel_size: ${params.kernel_size}

# Set-conv grid encoder spec used as convcnp_encoder.grid_encoder.
grid_encoder:
  _target_: tnp.networks.setconv.SetConvGridEncoder
  # Key is `dims` here (the CNN spec uses `dim`); both read params.dim_x.
  dims: ${params.dim_x}
  # params.grid_range itself resolves to params.target_range.
  grid_range: ${params.grid_range}
  grid_shape: ${params.grid_shape}
  init_lengthscale: ${params.init_lengthscale}

# Set-conv grid decoder spec used as convcnp_encoder.grid_decoder.
grid_decoder:
  _target_: tnp.networks.setconv.SetConvGridDecoder
  dims: ${params.dim_x}

# MLP spec used as convcnp_encoder.z_encoder.
z_encoder:
  _target_: tnp.networks.mlp.MLP
  # Evaluates to params.dim_y + 1 through the `eval` resolver, which the
  # application must register; what the extra input is for is not shown here.
  in_dim: "${eval:'${params.dim_y} + 1'}"
  out_dim: ${params.num_channels}
  num_layers: 2
  # Hidden width matches the output size (both params.num_channels).
  width: ${params.num_channels}

# Decoder spec used by model.decoder; wraps the z_decoder MLP defined below.
tnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: ${z_decoder}

# MLP spec used as tnp_decoder.z_decoder.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${params.num_channels}
  # Evaluates to 2 * params.dim_y through the `eval` resolver, which the
  # application must register. Presumably two values per output dimension
  # for the heteroscedastic likelihood -- confirm.
  out_dim: "${eval:'2 * ${params.dim_y}'}"
  num_layers: 2
  # Hidden width matches the input size (both params.num_channels).
  width: ${params.num_channels}

# Likelihood spec used by model.likelihood; takes no arguments here.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood

# Optimiser spec: torch.optim.AdamW with only the learning rate fixed here.
optimiser:
  _target_: torch.optim.AdamW
  # Partial spec (Hydra convention, presumably): the parameters to optimise
  # are supplied later by the consumer.
  _partial_: true
  lr: 5.0e-4


# Shared hyperparameters; other sections of this file read them through
# ${params.*} interpolations.
params:
  # Model + Training Params
  epochs: 200
  num_channels: 32
  num_blocks: 5
  kernel_size: 9
  # Not referenced anywhere in the visible part of this file.
  num_decoder_kernels: 5
  # The grid extent follows the target range defined below.
  grid_range: ${params.target_range}
  grid_shape: [128]
  init_lengthscale: 0.1

  # Synthetic 1D GP Params
  dim_x: 1
  dim_y: 1
  # One [low, high] pair per input dimension (dim_x is 1).
  context_range: [[-2.0, 2.0]]
  target_range: [[-2.0, 2.0]]
  # log10 bounds: 10^-0.602 is about 0.25 and 10^0.0 is 1.
  min_log10_lengthscale: -0.602
  max_log10_lengthscale: 0.0
  # min equals max, so a single value: 10^0.301 is about 2.
  min_log10_period: 0.301
  max_log10_period: 0.301


# Run-level settings: project/run naming, checkpoint resume, gradient
# clipping and plotting/evaluation options.
misc:
  project: convcnp-combined-rangesame
  # Run name assembled from the CNN hyperparameters and grid shape in params.
  name: ConvCNP-L${params.num_blocks}-C${params.num_channels}-K${params.kernel_size}-GS${params.grid_shape}
  resume_from_checkpoint: null
  gradient_clip_val: 0.5
  # Must appear exactly once in this mapping: YAML mapping keys are unique.
  plot_interval: 10

  # Plot misc
  # Presumably selects which generators.* split is evaluated -- confirm.
  eval_name: test
  seed: 1
  only_plots: false
  num_plots: 10
  subplots: true
  savefig: true
  logging: true
  check_val_every_n_epoch: 1