# Data generators, one per split, each instantiated via Hydra `_target_`.
# Only `split` and `samples_per_epoch` differ between the three splits.
# Integers are written without `_` separators: `32_000` is an int only under
# YAML 1.1 (PyYAML/OmegaConf) and would be read as a string by YAML 1.2 parsers.
# NOTE(review): N_c_* / N_t_* presumably bound the number of context / target
# points drawn per task — confirm in TemporalHadISDDataGenerator.
generators:
  train:
    _target_: tnp.data.hadISDTemporal.TemporalHadISDDataGenerator
    split: train
    data_root: ${params.data_root}
    N_c_min: 100
    N_c_max: 2100
    N_t_min: 250
    N_t_max: 250
    samples_per_epoch: 32000
    batch_size: 32
    delta_hours: ${params.delta_hours}
    h_window: ${params.h_window}
    ordering: ${params.ordering}
  val:
    _target_: tnp.data.hadISDTemporal.TemporalHadISDDataGenerator
    split: val
    data_root: ${params.data_root}
    N_c_min: 100
    N_c_max: 2100
    N_t_min: 250
    N_t_max: 250
    samples_per_epoch: 8000
    batch_size: 32
    delta_hours: ${params.delta_hours}
    h_window: ${params.h_window}
    ordering: ${params.ordering}
  test:
    _target_: tnp.data.hadISDTemporal.TemporalHadISDDataGenerator
    split: test
    data_root: ${params.data_root}
    N_c_min: 100
    N_c_max: 2100
    N_t_min: 250
    N_t_max: 250
    samples_per_epoch: 80000
    batch_size: 32
    delta_hours: ${params.delta_hours}
    h_window: ${params.h_window}
    ordering: ${params.ordering}


# Top-level model: TNPCausal assembled from the encoder, decoder and
# likelihood nodes defined elsewhere in this file (pulled in by interpolation).
model:
  _target_: tnp.models.castnp.TNPCausal
  encoder: ${tnp_encoder}
  decoder: ${tnp_decoder}
  likelihood: ${likelihood}

# Masked TNP encoder: combines the transformer stack with the joint (x, y)
# MLP encoder and the Fourier x encoder configured below.
tnp_encoder:
  _target_: tnp.models.castnp.TNPEncoderMasked
  transformer_encoder: ${transformer_encoder}
  xy_encoder: ${xy_encoder}
  x_encoder: ${x_encoder}

# Masked transformer encoder built from the self- and cross-attention layer
# configs. NOTE(review): presumably the layer configs are repeated
# `num_layers` times — confirm in TNPTransformerMaskedEncoder.
transformer_encoder:
  _target_: tnp.networks.transformer.TNPTransformerMaskedEncoder
  mhsa_layer: ${mhsa_layer}
  mhca_layer: ${mhca_layer}
  num_layers: ${params.num_layers}

# Multi-head self-attention layer; the feedforward width is set equal to
# embed_dim. NOTE(review): num_heads * head_dim equals embed_dim with the
# current params (8 * 16 = 128) — confirm whether the layer requires this.
mhsa_layer:
  _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
  embed_dim: ${params.embed_dim}
  num_heads: ${params.num_heads}
  head_dim: ${params.head_dim}
  feedforward_dim: ${params.embed_dim}
  norm_first: ${params.norm_first}

# Multi-head cross-attention layer; mirrors mhsa_layer's hyperparameters,
# including feedforward width equal to embed_dim.
mhca_layer:
  _target_: tnp.networks.attention_layers.MultiHeadCrossAttentionLayer
  embed_dim: ${params.embed_dim}
  num_heads: ${params.num_heads}
  head_dim: ${params.head_dim}
  feedforward_dim: ${params.embed_dim}
  norm_first: ${params.norm_first}

# Two-layer MLP mapping the concatenated (y, embedded x) features to embed_dim.
# in_dim = 1 + dim_y + (lat + lon + time + elev embedding widths).
# NOTE(review): the leading `1` is an extra scalar channel not explained here —
# presumably a flag/mask feature added by TNPEncoderMasked; confirm.
# `eval` is a custom resolver (not an OmegaConf built-in) and must be
# registered by the launching code before this config is resolved.
xy_encoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${eval:'1 + ${params.dim_y} + ${params.lat_embed_dims} + ${params.lon_embed_dims} + ${params.time_embed_dims} + ${params.elev_embed_dims}'}
  out_dim: ${params.embed_dim}
  num_layers: 2
  width: ${params.embed_dim}

# Fourier-feature embedder for the x inputs. Each entry is a
# [embed_dims, lambda_min, lambda_max] triple from `params`.
# NOTE(review): entry order here is lat, lon, elev, time — presumably it must
# match the column order of x produced by the data generator; confirm.
x_encoder:
  _target_: tnp.networks.fourier_embed.FourierEmbedderHadISD
  embed_dim_lambdamin_lambda_max:
    - ${params.lat_emb_lmin_lmax}
    - ${params.lon_emb_lmin_lmax}
    - ${params.elev_emb_lmin_lmax}
    - ${params.time_emb_lmin_lmax}

# Decoder wrapper around the z_decoder MLP.
tnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: ${z_decoder}

# Two-layer MLP from the embed_dim representation to 2 * dim_y outputs.
# NOTE(review): presumably one location and one scale parameter per output
# dimension for HeteroscedasticNormalLikelihood — confirm.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${params.embed_dim}
  out_dim: ${eval:'2 * ${params.dim_y}'}
  num_layers: 2
  width: ${params.embed_dim}


# Output likelihood: Gaussian with input-dependent (heteroscedastic) noise.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood

# AdamW optimiser. `_partial_: true` makes Hydra return a functools.partial,
# presumably so the training code can supply the model parameters later.
# Floats keep the `.` and signed-exponent form: plain PyYAML (YAML 1.1) reads
# e.g. `1e-8` as a string.
optimiser:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 5.0e-04
  betas: [0.9, 0.999]
  eps: 1.0e-8
  weight_decay: 0.01

# Learning-rate schedule: warmup followed by cosine decay, selected by `type`.
# NOTE(review): the `null` entries are presumably filled in by the training
# code (warmup steps from `fraction`, T_max from the total step count).
scheduler:
  type: 'warmup_cosine'
  warmup:
    steps: null
    fraction: 0.1
  cosine:
    eta_min: 1.0e-6
    T_max: null

params:
  # Model + training params
  epochs: 200
  embed_dim: 128
  num_heads: 8
  head_dim: 16
  norm_first: true
  num_layers: 5

  # Embedding dimensions for the x encoder
  lat_embed_dims: 32
  lon_embed_dims: 32
  time_embed_dims: 32
  elev_embed_dims: 32

  # Each *_emb_lmin_lmax entry is [embed_dims, lambda_min, lambda_max],
  # consumed by x_encoder.embed_dim_lambdamin_lambda_max.
  lat_emb_lmin_lmax:  # lat is in [-1, 1]
    - ${params.lat_embed_dims}
    - 0.001
    - 2.0
  lon_emb_lmin_lmax:  # lon is in [-1, 1]
    - ${params.lon_embed_dims}
    - 0.001
    - 2.0
  time_emb_lmin_lmax:  # 8760 = 24 * 365, i.e. hours in a year
    - ${params.time_embed_dims}
    - 1.0
    - 8760.0
  # Elevation is z-normalised; for the current patch, 7.7 sigma is the most
  # extreme elevation, so lambda_max = 8.0 covers it.
  elev_emb_lmin_lmax:
    - ${params.elev_embed_dims}
    - 0.1
    - 8.0

  # Fixed, known problem dimensions
  dim_x: 4
  dim_y: 1

  # Data generation params shared between splits
  delta_hours: 6  # hours between consecutive time steps for infilling
  h_window: 8
  ordering: "ctx_time"

  data_root: /REPLACE_WITH_MACHINE_SPECIFIC_PATH_REDACTED_FOR_ANONYMITY

misc:
  project: mask-tnp-hadTime
  name: mask-TNP-L${params.num_layers}-H${params.num_heads}-D${params.embed_dim}
  resume_from_checkpoint: null
  gradient_clip_val: 0.5

  # Plotting / evaluation options
  eval_name: test_eval
  num_grid_points_plot: 100  # number of points used for gridded prediction plots
  dem_path: "/REPLACE_WITH_MACHINE_SPECIFIC_PATH_REDACTED_FOR_ANONYMITY"  # DEM file
  cache_dem_dir: "/REPLACE_WITH_MACHINE_SPECIFIC_PATH_REDACTED_FOR_ANONYMITY"  # where to cache DEM processing results
  seed: 1
  only_plots: false
  num_plots: 5
  subplots: true
  savefig: true
  logging: true
  plot_interval: 10
  check_val_every_n_epoch: 1
  checkpoint_interval: 20
  num_workers: 5
  num_val_workers: 2