# Top-level model node. `_target_` holds the dotted path of the class to
# build (presumably via Hydra's `instantiate` -- confirm against the training
# entry point); the remaining keys interpolate other top-level nodes.
model:
  _target_: tnp.models.gridded_tnp.MultiModalGriddedTNP
  encoder: ${swintnp_encoder}
  decoder: ${tnp_decoder}
  likelihood: ${likelihood}

# Encoder node referenced by `model.encoder`; wires together the transformer
# encoder, the multi-modal grid encoder and the three MLP encoders.
# NOTE(review): `params.data_vars` is not defined in the `params` block of
# this file, so it must be supplied by another config or an override --
# confirm.
swintnp_encoder:
  _target_: tnp.models.gridded_tnp.MultiModalGriddedTNPEncoder
  transformer_encoder: ${transformer_encoder}
  grid_encoder: ${mm_grid_encoder}
  xy_encoder: ${xy_encoder}
  zt_encoder: ${zt_encoder}
  mode_names: ${params.data_vars}
  x_encoder: ${x_encoder}

# Node for `GriddedTransformerEncoder`: combines the attention layer and
# patch encoder nodes with the layer count taken from `params`.
# NOTE(review): `grid_decoder` is not defined anywhere in this file; it must
# come from another config that is merged with this one -- confirm.
transformer_encoder:
  _target_: tnp.networks.transformer.GriddedTransformerEncoder
  grid_decoder: ${grid_decoder}
  mhsa_layer: ${mhsa_layer}
  num_layers: ${params.num_layers}
  patch_encoder: ${patch_encoder}

# Self-attention layer node used by `transformer_encoder`.
# With the defaults in `params`: num_heads * head_dim = 8 * 16 = 128, which
# equals embed_dim. `feedforward_dim` is tied to `embed_dim` rather than
# having its own parameter.
mhsa_layer:
  _target_: tnp.networks.attention_layers.MultiHeadSelfAttentionLayer
  embed_dim: ${params.embed_dim}
  num_heads: ${params.num_heads}
  head_dim: ${params.head_dim}
  feedforward_dim: ${params.embed_dim}
  norm_first: ${params.norm_first}
  linear: ${params.linear_attention}

# Patch encoder node used by `transformer_encoder`. With the defaults in
# `params`, input_grid is [24, 60] and output_grid is [12, 30].
patch_encoder:
  _target_: tnp.networks.patch_encoders.PatchEncoder
  input_grid: ${params.grid_encoder_grid_shape}
  output_grid: ${params.patch_encoder_grid_shape}
  embed_dim: ${params.embed_dim}
  # Single-element list; presumably the dimension(s) excluded from patching
  # -- confirm against PatchEncoder.
  ignore_dims: [-3]

# Multi-modal grid encoder node referenced by `swintnp_encoder.grid_encoder`.
# NOTE(review): `grid_encoder` is not defined anywhere in this file; it must
# come from another config that is merged with this one -- confirm.
mm_grid_encoder:
  _target_: tnp.networks.grid_encoders.MultiModalGridEncoder
  grid_encoder: ${grid_encoder}
  mode_names: ${params.data_vars}
  output_mixer: ${output_mixer}

# MLP node referenced by `mm_grid_encoder.output_mixer`.
# `in_dim` is embed_dim multiplied by the number of entries in
# `params.data_vars`, assuming the `eval` resolver evaluates its argument as
# a Python expression. `eval` is not a built-in OmegaConf resolver, so the
# project presumably registers it -- confirm.
output_mixer:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${eval:'${params.embed_dim} * len(${params.data_vars})'}
  out_dim: ${params.embed_dim}
  num_layers: 0
  width: ${params.embed_dim}

# MLP node referenced by `swintnp_encoder.xy_encoder`.
# `in_dim` is x_embed_dim + 1; the extra input is presumably a single
# observed value concatenated with the embedded location -- confirm.
xy_encoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${eval:'1 + ${params.x_embed_dim}'}
  out_dim: ${params.embed_dim}
  num_layers: 2
  width: ${params.embed_dim}

# MLP node referenced by `swintnp_encoder.zt_encoder`; maps an input of width
# x_embed_dim to embed_dim.
zt_encoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${params.x_embed_dim}
  out_dim: ${params.embed_dim}
  num_layers: 2
  width: ${params.embed_dim}

# MLP-with-embedding node referenced by `swintnp_encoder.x_encoder`; applies
# the time and lat/lon embedding nodes listed under `embeddings`.
# `in_dim` is num_wavelengths + num_legendre_polys ** 2, i.e. 10 + 10 ** 2 =
# 110 with the defaults in `params` (assuming the `eval` resolver evaluates
# the Python expression).
# NOTE(review): confirm this matches the actual output widths of the two
# embedding classes.
x_encoder:
  _target_: tnp.networks.mlp.MLPWithEmbedding
  embeddings:
    - ${time_embedding}
    - ${latlon_embedding}
  in_dim: ${eval:'${time_embedding.num_wavelengths} + ${latlon_embedding.num_legendre_polys}**2'}
  out_dim: ${params.x_embed_dim}
  num_layers: 0
  width: ${params.x_embed_dim}

# Fourier embedding node listed first in `x_encoder.embeddings`; applied to
# input dimension 0.
time_embedding:
  _target_: tnp.networks.embeddings.FourierEmbedding
  lower: 1.0
  # 8766 = 365.25 * 24, presumably hours in a year -- confirm the time unit.
  upper: 8766.0
  assert_range: false
  num_wavelengths: ${params.num_wavelengths}
  active_dim: 0

# Spherical-harmonics embedding node listed second in `x_encoder.embeddings`;
# applied to the last two input dimensions.
# `lonlat_dims: [-1, -2]` is the same value used by `haversine_dist`;
# presumably longitude is the last dimension and latitude the second-to-last
# -- confirm.
latlon_embedding:
  _target_: tnp.networks.embeddings.SphericalHarmonicsEmbedding
  num_legendre_polys: ${params.num_legendre_polys}
  active_dims: [-2, -1]
  lonlat_dims: [-1, -2]

# Decoder node referenced by `model.decoder`; wraps the `z_decoder` MLP node.
tnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: ${z_decoder}

# MLP node referenced by `tnp_decoder.z_decoder`.
# `out_dim` is twice the number of entries in `params.data_vars`; presumably
# one mean and one scale parameter per variable for the heteroscedastic
# normal likelihood -- confirm.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  in_dim: ${params.embed_dim}
  out_dim: ${eval:'2 * len(${params.data_vars})'}
  num_layers: 2
  width: ${params.embed_dim}

# Likelihood node referenced by `model.likelihood`; takes no arguments here.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood

# AdamW optimiser node. `_partial_` is set, so the model parameters are
# presumably supplied later by the training code -- confirm.
optimiser:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 5.0e-4

# Shared hyperparameters, referenced from other nodes as `${params.<key>}`.
# NOTE(review): `data_vars` is referenced elsewhere in this file as
# `${params.data_vars}` but is not defined here; it must be supplied by
# another config or an override -- confirm.
# NOTE(review): `top_k_ctot`, `roll_dims`, `use_nn`, `dim_distance` and
# `dist_fn` are not referenced by any node in this file; presumably they are
# consumed by the `grid_encoder` / `grid_decoder` nodes defined elsewhere --
# confirm.
params:
  epochs: 500
  embed_dim: 128
  num_heads: 8
  head_dim: 16
  norm_first: true
  num_layers: 5
  linear_attention: false
  patch_encoder_grid_shape: [12, 30]
  grid_encoder_grid_shape: [24, 60]
  top_k_ctot: 9
  roll_dims: null
  x_embed_dim: 128
  num_legendre_polys: 10
  num_wavelengths: 10
  use_nn: true
  dim_distance: 2
  # Partially-applied composition of two distance functions. Each entry of
  # `dist_fn_dims` lists the input dimensions paired with the `dist_fns`
  # entry at the same position (squared distance on dim 0, haversine on the
  # last two dims) -- pairing by position is assumed; confirm against
  # dist_composition.
  dist_fn:
    _target_: tnp.utils.distances.dist_composition
    _partial_: true
    dist_fns:
      - ${sq_dist}
      - ${haversine_dist}
    dist_fn_dims:
      - [0]
      - [-2, -1]

# Partially-applied haversine distance, referenced from
# `params.dist_fn.dist_fns`. `lonlat_dims: [-1, -2]` is the same value used
# by `latlon_embedding`.
haversine_dist:
  _target_: tnp.utils.distances.haversine_dist
  _partial_: true
  lonlat_dims: [-1, -2]

# Partially-applied squared-distance function, referenced from
# `params.dist_fn.dist_fns`.
sq_dist:
  _target_: tnp.utils.distances.sq_dist
  _partial_: true


# Run-level settings (run name, checkpointing, clipping, plotting, logging).
misc:
  # The two grid-shape lists are interpolated into the run name as strings.
  name: MM-VITTNP-PE-GS1${params.grid_encoder_grid_shape}-GS2${params.patch_encoder_grid_shape}
  resume_from_checkpoint: null
  gradient_clip_val: 0.5
  plot_interval: 10
  logging: true
  pl_logging: true
