# Hydra defaults list: compose the tabular-data generator config into this
# one (the leading `/` makes the path absolute from the config root).
# Presumably the generator supplies params.dim_x / params.dim_y, which are
# interpolated below but not defined in this file — confirm.
# NOTE(review): no `_self_` entry is listed, so whether this file's keys
# override or are overridden by the generator config depends on Hydra's
# implicit `_self_` placement for the version in use — confirm.
defaults:
  - /generators/tabular_data

# Top-level model: a TNPA assembled from the `tnp_encoder`, `tnp_decoder`
# and `likelihood` stanzas below via OmegaConf interpolation.
model:
  _target_: tnp.models.tnpa.TNPA
  encoder: ${tnp_encoder}
  decoder: ${tnp_decoder}
  likelihood: ${likelihood}
  # NOTE(review): params.num_samples_pred is not defined in this file's
  # `params` stanza; presumably it comes from the generator defaults or a
  # command-line override — resolution fails if neither provides it.
  no_samples_rollout_pred: ${params.num_samples_pred}

# Encoder for the TNPA model, combining the `xy_encoder` embedding MLP and
# the `transformer_encoder` stack defined below.
tnp_encoder:
  _target_: tnp.models.tnpa.ARTNPEncoder
  transformer_encoder: ${transformer_encoder}
  xy_encoder: ${xy_encoder}

# Transformer encoder with `params.num_layers` layers; presumably each layer
# is a copy of the `mhsa_layer` stanza — confirm in TransformerEncoder.
transformer_encoder:
  _target_: tnp.networks.transformer.TransformerEncoder
  mhsa_layer: ${mhsa_layer}
  num_layers: ${params.num_layers}

# Multi-head self-attention layer used by `transformer_encoder`.
# NOTE(review): with the current params, num_heads * head_dim (8 * 16)
# equals embed_dim (128); presumably the layer relies on this, so keep the
# three in sync when changing any of them.
mhsa_layer:
  _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
  embed_dim: ${params.embed_dim}
  num_heads: ${params.num_heads}
  head_dim: ${params.head_dim}
  # Feed-forward width is tied to the embedding width.
  feedforward_dim: ${params.embed_dim}
  norm_first: ${params.norm_first}

# Two-layer MLP that embeds each (x, y) input into the embed_dim token space.
xy_encoder:
  _target_: tnp.networks.mlp.MLP
  # `eval` is a custom resolver registered by the project (not built into
  # OmegaConf). NOTE(review): the extra `1 +` is presumably a per-point
  # indicator/flag channel added by ARTNPEncoder — confirm.
  in_dim: ${eval:'1 + ${params.dim_y} + ${params.dim_x}'}
  out_dim: ${params.embed_dim}
  num_layers: 2
  width: ${params.embed_dim}

# Decoder wrapping the `z_decoder` MLP defined below.
tnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: ${z_decoder}

# Two-layer MLP mapping embed_dim tokens to 2 * dim_y outputs.
# NOTE(review): presumably one mean and one noise parameter per output
# dimension, as consumed by the heteroscedastic likelihood — confirm.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${params.embed_dim}
  out_dim: ${eval:'2 * ${params.dim_y}'}
  num_layers: 2
  width: ${params.embed_dim}

# Gaussian likelihood with input-dependent (heteroscedastic) noise.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood
  # Presumably a lower bound on the predicted noise level — confirm.
  min_noise: 1.0e-04

# AdamW optimiser. `_partial_: true` makes Hydra's instantiate return a
# functools.partial (presumably so model parameters can be bound later).
optimiser:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 5.0e-04  # PyTorch default is 1.0e-3
  betas: [0.9, 0.999]  # PyTorch default
  eps: 1.0e-8  # PyTorch default
  weight_decay: 0.01  # PyTorch default

# Learning-rate schedule settings.
scheduler:
  # One of: "constant", "warmup", "cosine", "warmup_cosine".
  type: 'warmup_cosine'
  warmup:
    # Explicit number of warmup steps; when set, it overrides `fraction`.
    steps: null
    # Fraction of the total training steps spent warming up.
    fraction: 0.1
  cosine:
    # Minimum learning rate reached by cosine annealing.
    eta_min: 1.0e-6
    # Maximum steps for the cosine phase (null = calculated automatically).
    T_max: null

# Model and training hyperparameters, shared with the stanzas above via
# `${params.*}` interpolation.
params:
  # Training
  epochs: 680
  # Transformer architecture
  embed_dim: 128
  num_heads: 8
  head_dim: 16  # num_heads * head_dim == embed_dim (8 * 16 = 128)
  norm_first: true
  num_layers: 5
  # NOTE(review): num_latents is not interpolated anywhere in this file;
  # it may be read directly by code elsewhere — confirm before removing.
  num_latents: 128

# Run bookkeeping: experiment naming, checkpointing, evaluation, plotting,
# logging and data-loader options.
misc:
  project: incTNP-tab
  # Run name encodes the main hyperparameters via interpolation.
  name: TNPA-LRSched-L${params.num_layers}-H${params.num_heads}-D${params.embed_dim}-LR${optimiser.lr}-E${params.epochs}
  resume_from_checkpoint: null
  gradient_clip_val: 0.5
  plot_interval: 10

  # Evaluation, plotting, logging, checkpointing and data-loader options
  eval_name: test
  seed: 1
  only_plots: false
  num_plots: 10
  subplots: true
  savefig: true
  logging: true
  check_val_every_n_epoch: 10
  checkpoint_interval: 100
  num_workers: 0
  num_val_workers: 0
  plot_fn: null
