wandb_version: 1

misc:
  desc: null
  value:
    resume_from_checkpoint: null
    logging: true
    seed: 1
    plot_interval: 10
    lightning_eval: true
    num_plots: 10
    gradient_clip_val: 0.5
    only_plots: false
    savefig: true
    subplots: true
    loss_fn:
      _target_: tnp.utils.np_functions.np_loss_fn
      _partial_: true
    pred_fn:
      _target_: tnp.utils.np_functions.np_pred_fn
      _partial_: true
    num_workers: 0
    num_val_workers: 0
    log_interval: 10
    checkpoint_interval: 50
    check_val_every_n_epoch: 10
    project: incTNP-Combined
    name: plain-TNP-RS-LRSched-L5-H8-D128-B16-Ncmax512-LR0.0003
    eval_name: test
generators:
  desc: null
  value:
    train:
      _target_: tnp.data.gp.RandomScaleGPGenerator
      dim: 1
      kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
      noise_std: 0.1
      min_nc: 1
      max_nc: 512
      min_nt: 128
      max_nt: 128
      context_range:
      - - -2.0
        - 2.0
      target_range:
      - - -2.0
        - 2.0
      samples_per_epoch: 16000
      batch_size: 16
    val:
      _target_: tnp.data.gp.RandomScaleGPGenerator
      dim: 1
      kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
      noise_std: 0.1
      min_nc: 1
      max_nc: 512
      min_nt: 128
      max_nt: 128
      context_range:
      - - -2.0
        - 2.0
      target_range:
      - - -2.0
        - 2.0
      samples_per_epoch: 4096
      batch_size: 16
      deterministic: true
    test:
      _target_: tnp.data.gp.RandomScaleGPGenerator
      dim: 1
      kernel:
      - ${rbf_kernel}
      - ${matern12_kernel}
      - ${matern32_kernel}
      - ${matern52_kernel}
      - ${periodic_kernel}
      noise_std: 0.1
      min_nc: 1
      max_nc: 512
      min_nt: 128
      max_nt: 128
      context_range:
      - - -2.0
        - 2.0
      target_range:
      - - -2.0
        - 2.0
      samples_per_epoch: 80000
      batch_size: 16
      deterministic: true
rbf_kernel:
  desc: null
  value:
    _target_: tnp.networks.gp.RBFKernel
    _partial_: true
    ard_num_dims: 1
    min_log10_lengthscale: -0.602
    max_log10_lengthscale: 0.0
matern12_kernel:
  desc: null
  value:
    _target_: tnp.networks.gp.MaternKernel
    _partial_: true
    nu: 0.5
    ard_num_dims: 1
    min_log10_lengthscale: -0.602
    max_log10_lengthscale: 0.0
matern32_kernel:
  desc: null
  value:
    _target_: tnp.networks.gp.MaternKernel
    _partial_: true
    nu: 1.5
    ard_num_dims: 1
    min_log10_lengthscale: -0.602
    max_log10_lengthscale: 0.0
matern52_kernel:
  desc: null
  value:
    _target_: tnp.networks.gp.MaternKernel
    _partial_: true
    nu: 2.5
    ard_num_dims: 1
    min_log10_lengthscale: -0.602
    max_log10_lengthscale: 0.0
periodic_kernel:
  desc: null
  value:
    _target_: tnp.networks.gp.PeriodicKernel
    _partial_: true
    ard_num_dims: 1
    min_log10_lengthscale: -0.602
    max_log10_lengthscale: 0.0
    min_log10_period: 0.301
    max_log10_period: 0.301
model:
  desc: null
  value:
    _target_: tnp.models.tnp.TNP
    encoder:
      _target_: tnp.models.tnp.TNPEncoder
      transformer_encoder:
        _target_: tnp.networks.transformer.TNPTransformerEncoder
        mhsa_layer:
          _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
          embed_dim: 128
          num_heads: 8
          head_dim: 16
          feedforward_dim: 128
          norm_first: true
        mhca_layer:
          _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
          embed_dim: 128
          num_heads: 8
          head_dim: 16
          feedforward_dim: 128
          norm_first: true
        num_layers: 5
      xy_encoder:
        _target_: tnp.networks.mlp.MLP
        in_dim: 3
        out_dim: 128
        num_layers: 2
        width: 128
    decoder:
      _target_: tnp.models.tnp.TNPDecoder
      z_decoder:
        _target_: tnp.networks.mlp.MLP
        in_dim: 128
        out_dim: 2
        num_layers: 2
        width: 128
    likelihood:
      _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood
tnp_encoder:
  desc: null
  value:
    _target_: tnp.models.tnp.TNPEncoder
    transformer_encoder:
      _target_: tnp.networks.transformer.TNPTransformerEncoder
      mhsa_layer:
        _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
        embed_dim: 128
        num_heads: 8
        head_dim: 16
        feedforward_dim: 128
        norm_first: true
      mhca_layer:
        _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
        embed_dim: 128
        num_heads: 8
        head_dim: 16
        feedforward_dim: 128
        norm_first: true
      num_layers: 5
    xy_encoder:
      _target_: tnp.networks.mlp.MLP
      in_dim: 3
      out_dim: 128
      num_layers: 2
      width: 128
transformer_encoder:
  desc: null
  value:
    _target_: tnp.networks.transformer.TNPTransformerEncoder
    mhsa_layer:
      _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
      embed_dim: 128
      num_heads: 8
      head_dim: 16
      feedforward_dim: 128
      norm_first: true
    mhca_layer:
      _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
      embed_dim: 128
      num_heads: 8
      head_dim: 16
      feedforward_dim: 128
      norm_first: true
    num_layers: 5
mhsa_layer:
  desc: null
  value:
    _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
    embed_dim: 128
    num_heads: 8
    head_dim: 16
    feedforward_dim: 128
    norm_first: true
mhca_layer:
  desc: null
  value:
    _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
    embed_dim: 128
    num_heads: 8
    head_dim: 16
    feedforward_dim: 128
    norm_first: true
xy_encoder:
  desc: null
  value:
    _target_: tnp.networks.mlp.MLP
    in_dim: 3
    out_dim: 128
    num_layers: 2
    width: 128
tnp_decoder:
  desc: null
  value:
    _target_: tnp.models.tnp.TNPDecoder
    z_decoder:
      _target_: tnp.networks.mlp.MLP
      in_dim: 128
      out_dim: 2
      num_layers: 2
      width: 128
z_decoder:
  desc: null
  value:
    _target_: tnp.networks.mlp.MLP
    in_dim: 128
    out_dim: 2
    num_layers: 2
    width: 128
likelihood:
  desc: null
  value:
    _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood
optimiser:
  desc: null
  value: "AdamW (\nParameter Group 0\n    amsgrad: False\n    betas: [0.9, 0.999]\n\
    \    capturable: False\n    differentiable: False\n    eps: 1e-08\n    foreach:\
    \ None\n    fused: None\n    initial_lr: 0.0003\n    lr: 0.0\n    maximize: False\n\
    \    weight_decay: 0.01\n)"
scheduler:
  desc: null
  value:
    type: warmup_cosine
    warmup:
      steps: null
      fraction: 0.1
    cosine:
      eta_min: 1.0e-06
      T_max: null
params:
  desc: null
  value:
    epochs: 500
    embed_dim: 128
    num_heads: 8
    head_dim: 16
    norm_first: true
    num_layers: 5
    dim_x: 1
    dim_y: 1
    context_range:
    - - -2.0
      - 2.0
    target_range:
    - - -2.0
      - 2.0
    min_log10_lengthscale: -0.602
    max_log10_lengthscale: 0.0
    min_log10_period: 0.301
    max_log10_period: 0.301
    max_nc: 512
    batch_size: 16
_wandb:
  desc: null
  value:
    python_version: 3.10.17
    cli_version: 0.17.0
    framework: lightning
    is_jupyter_run: false
    is_kaggle_kernel: false
    start_time: 1767275260
    t:
      1:
      - 1
      - 9
      - 41
      - 50
      - 55
      - 103
      2:
      - 1
      - 9
      - 41
      - 50
      - 55
      - 103
      3:
      - 7
      - 13
      - 16
      - 23
      - 66
      4: 3.10.17
      5: 0.17.0
      8:
      - 5
      13: linux-x86_64
    m:
    - 1: trainer/global_step
      6:
      - 3
    - 1: performance/between_step_time
      5: 1
      6:
      - 1
    - 1: train/loss_step
      5: 1
      6:
      - 1
    - 1: train/lr_step
      5: 1
      6:
      - 1
    - 1: performance/forward_time
      5: 1
      6:
      - 1
    - 1: performance/backward_time
      5: 1
      6:
      - 1
    - 1: performance/average_updates_per_second
      5: 1
      6:
      - 1
    - 1: performance/last_updates_per_second
      5: 1
      6:
      - 1
    - 1: epoch
      5: 1
      6:
      - 1
    - 1: train/loss_epoch
      5: 1
      6:
      - 1
    - 1: train/lr_epoch
      5: 1
      6:
      - 1
    - 1: val/loglik
      5: 1
      6:
      - 1
    - 1: val/rmse
      5: 1
      6:
      - 1
    - 1: val/gt_loglik
      5: 1
      6:
      - 1
lr_scheduler:
  desc: null
  value: <torch.optim.lr_scheduler.SequentialLR object at 0x71391427ab60>
loss_fn:
  desc: null
  value: null
pred_fn:
  desc: null
  value: null
plot_fn:
  desc: null
  value: null
plot_interval:
  desc: null
  value: 10
