wandb_version: 1

misc:
  desc: null
  value:
    resume_from_checkpoint: null
    logging: true
    seed: 1
    plot_interval: 10
    lightning_eval: true
    num_plots: 10
    gradient_clip_val: 0.5
    only_plots: false
    savefig: true
    subplots: true
    loss_fn:
      _target_: tnp.utils.np_functions.np_loss_fn
      _partial_: true
    pred_fn:
      _target_: tnp.utils.np_functions.np_pred_fn
      _partial_: true
    num_workers: 0
    num_val_workers: 0
    log_interval: 10
    checkpoint_interval: 100
    check_val_every_n_epoch: 10
    project: incTNP-tab
    name: mask-TNP-LRSched-L5-H8-D128-LR0.005
    eval_name: test
    plot_fn: null
generators:
  desc: null
  value:
    train:
      _target_: tnp.data.tabular_data.TabularDataGeneratorUniqueMLPPerDataset
      dim: 20
      min_nc: 10
      max_nc: 1024
      min_nt: 128
      max_nt: 128
      samples_per_epoch: 32768
      batch_size: 128
    val:
      _target_: tnp.data.tabular_data.TabularDataGeneratorUniqueMLPPerDataset
      dim: 20
      min_nc: 10
      max_nc: 1024
      min_nt: 128
      max_nt: 128
      samples_per_epoch: 8192
      batch_size: 128
      deterministic: true
      deterministic_seed: 1
    test:
      _target_: tnp.data.tabular_data.TabularDataGeneratorUniqueMLPPerDataset
      dim: 20
      min_nc: 10
      max_nc: 1024
      min_nt: 128
      max_nt: 128
      samples_per_epoch: 80000
      batch_size: 128
      deterministic: true
      deterministic_seed: 1
params:
  desc: null
  value:
    dim_x: 20
    dim_y: 1
    max_tokens: 1024
    nt: 128
    epochs: 500
    embed_dim: 128
    num_heads: 8
    head_dim: 16
    norm_first: true
    num_layers: 5
defaults:
  desc: null
  value:
  - /generators/tabular_data
model:
  desc: null
  value:
    _target_: tnp.models.castnp.TNPCausal
    encoder:
      _target_: tnp.models.castnp.TNPEncoderMasked
      transformer_encoder:
        _target_: tnp.networks.transformer.TNPTransformerMaskedEncoder
        mhsa_layer:
          _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
          embed_dim: 128
          num_heads: 8
          head_dim: 16
          feedforward_dim: 128
          norm_first: true
        mhca_layer:
          _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
          embed_dim: 128
          num_heads: 8
          head_dim: 16
          feedforward_dim: 128
          norm_first: true
        num_layers: 5
      xy_encoder:
        _target_: tnp.networks.mlp.MLP
        in_dim: 22
        out_dim: 128
        num_layers: 2
        width: 128
    decoder:
      _target_: tnp.models.tnp.TNPDecoder
      z_decoder:
        _target_: tnp.networks.mlp.MLP
        in_dim: 128
        out_dim: 2
        num_layers: 2
        width: 128
    likelihood:
      _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood
      min_noise: 0.0001
tnp_encoder:
  desc: null
  value:
    _target_: tnp.models.castnp.TNPEncoderMasked
    transformer_encoder:
      _target_: tnp.networks.transformer.TNPTransformerMaskedEncoder
      mhsa_layer:
        _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
        embed_dim: 128
        num_heads: 8
        head_dim: 16
        feedforward_dim: 128
        norm_first: true
      mhca_layer:
        _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
        embed_dim: 128
        num_heads: 8
        head_dim: 16
        feedforward_dim: 128
        norm_first: true
      num_layers: 5
    xy_encoder:
      _target_: tnp.networks.mlp.MLP
      in_dim: 22
      out_dim: 128
      num_layers: 2
      width: 128
transformer_encoder:
  desc: null
  value:
    _target_: tnp.networks.transformer.TNPTransformerMaskedEncoder
    mhsa_layer:
      _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
      embed_dim: 128
      num_heads: 8
      head_dim: 16
      feedforward_dim: 128
      norm_first: true
    mhca_layer:
      _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
      embed_dim: 128
      num_heads: 8
      head_dim: 16
      feedforward_dim: 128
      norm_first: true
    num_layers: 5
mhsa_layer:
  desc: null
  value:
    _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
    embed_dim: 128
    num_heads: 8
    head_dim: 16
    feedforward_dim: 128
    norm_first: true
mhca_layer:
  desc: null
  value:
    _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
    embed_dim: 128
    num_heads: 8
    head_dim: 16
    feedforward_dim: 128
    norm_first: true
xy_encoder:
  desc: null
  value:
    _target_: tnp.networks.mlp.MLP
    in_dim: 22
    out_dim: 128
    num_layers: 2
    width: 128
tnp_decoder:
  desc: null
  value:
    _target_: tnp.models.tnp.TNPDecoder
    z_decoder:
      _target_: tnp.networks.mlp.MLP
      in_dim: 128
      out_dim: 2
      num_layers: 2
      width: 128
z_decoder:
  desc: null
  value:
    _target_: tnp.networks.mlp.MLP
    in_dim: 128
    out_dim: 2
    num_layers: 2
    width: 128
likelihood:
  desc: null
  value:
    _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood
    min_noise: 0.0001
optimiser:
  desc: null
  value: "AdamW (\nParameter Group 0\n    amsgrad: False\n    betas: [0.9, 0.999]\n\
    \    capturable: False\n    differentiable: False\n    eps: 1e-08\n    foreach:\
    \ None\n    fused: None\n    initial_lr: 0.005\n    lr: 0.0\n    maximize: False\n\
    \    weight_decay: 0.01\n)"
scheduler:
  desc: null
  value:
    type: warmup_cosine
    warmup:
      steps: null
      fraction: 0.1
    cosine:
      eta_min: 1.0e-06
      T_max: null
_wandb:
  desc: null
  value:
    python_version: 3.10.17
    cli_version: 0.17.0
    framework: lightning
    is_jupyter_run: false
    is_kaggle_kernel: false
    start_time: 1766412703
    t:
      1:
      - 1
      - 9
      - 41
      - 50
      - 55
      - 103
      2:
      - 1
      - 9
      - 41
      - 50
      - 55
      - 103
      3:
      - 7
      - 13
      - 16
      - 23
      - 66
      4: 3.10.17
      5: 0.17.0
      8:
      - 5
      13: linux-x86_64
    m:
    - 1: trainer/global_step
      6:
      - 3
    - 1: performance/between_step_time
      5: 1
      6:
      - 1
    - 1: train/loss_step
      5: 1
      6:
      - 1
    - 1: train/lr_step
      5: 1
      6:
      - 1
    - 1: performance/forward_time
      5: 1
      6:
      - 1
    - 1: performance/backward_time
      5: 1
      6:
      - 1
    - 1: performance/average_updates_per_second
      5: 1
      6:
      - 1
    - 1: performance/last_updates_per_second
      5: 1
      6:
      - 1
    - 1: epoch
      5: 1
      6:
      - 1
    - 1: train/loss_epoch
      5: 1
      6:
      - 1
    - 1: train/lr_epoch
      5: 1
      6:
      - 1
    - 1: val/loglik
      5: 1
      6:
      - 1
    - 1: val/rmse
      5: 1
      6:
      - 1
lr_scheduler:
  desc: null
  value: <torch.optim.lr_scheduler.SequentialLR object at 0x7ee1ff761de0>
loss_fn:
  desc: null
  value: null
pred_fn:
  desc: null
  value: null
plot_fn:
  desc: null
  value: null
plot_interval:
  desc: null
  value: 10
