# Top-level model node. `_target_` names the class to construct (presumably
# via Hydra's instantiate mechanism -- confirm against the training entry
# point). Every other key interpolates another top-level node of this file.
model:
  _target_: tnp.models.convcnp.MultiModalConvCNP
  encoder: ${convcnp_encoder}
  decoder: ${tnp_decoder}
  likelihood: ${likelihood}

# Encoder node consumed by `model.encoder`. Wires together the CNN, the grid
# encoder/decoder pair and the per-point `z_encoder` MLP.
convcnp_encoder:
  _target_: tnp.models.convcnp.MultiModalConvCNPEncoder
  # NOTE(review): no top-level `cnn` node appears in the part of the file
  # under review; presumably it is supplied by another composed config --
  # confirm.
  conv_net: ${cnn}
  grid_encoder: ${setconv_encoder}
  grid_decoder: ${setconv_decoder}
  z_encoder: ${z_encoder}
  # NOTE(review): `params.data_vars` is not set in the `params` node shown
  # here; presumably it comes from a data config -- confirm.
  mode_names: ${params.data_vars}

# Grid encoder node consumed by `convcnp_encoder.grid_encoder`. All settings
# except `time_dim` are pulled from the shared `params` node.
setconv_encoder:
  _target_: tnp.networks.grid_encoders.SetConvThroughTime
  dims: ${params.dim_distance}
  points_per_dim: ${params.grid_encoder_grid_shape}
  # NOTE(review): `params.grid_range` is not set in the `params` node shown
  # here; presumably it comes from another composed config -- confirm.
  grid_range: ${params.grid_range}
  dist_fn: ${params.dist_fn}
  use_nn: ${params.use_nn}
  # Hard-coded to 0. Presumably the index of the time axis in the input
  # coordinates, matching the `[0,]` entry of `params.dist_fn.dist_fn_dims`
  # -- TODO confirm.
  time_dim: 0

# Grid decoder node consumed by `convcnp_encoder.grid_decoder`. It shares
# `dim_distance` and `dist_fn` with `setconv_encoder` through `params`.
setconv_decoder:
  _target_: tnp.networks.grid_decoders.SetConvGridDecoder
  dim: ${params.dim_distance}
  dist_fn: ${params.dist_fn}
  top_k_ctot: ${params.top_k_ctot}

# MLP node consumed by `convcnp_encoder.z_encoder`.
z_encoder:
  _target_: tnp.networks.mlp.MLP
  # Input width is twice the number of entries in `params.data_vars`. The
  # expression is evaluated by an `eval` resolver, which is not one of
  # OmegaConf's built-in resolvers and must be registered by the application
  # before this config is resolved.
  in_dim: ${eval:'2 * len(${params.data_vars})'}
  # NOTE(review): `params.num_channels` is not set in the `params` node shown
  # here; presumably it comes from another composed config -- confirm.
  out_dim: ${params.num_channels}
  num_layers: 2
  width: ${params.num_channels}

# Decoder node consumed by `model.decoder`; wraps the `z_decoder` MLP node.
tnp_decoder:
  _target_: tnp.models.tnp.TNPDecoder
  z_decoder: ${z_decoder}

# MLP node consumed by `tnp_decoder.z_decoder`.
z_decoder:
  _target_: tnp.networks.mlp.MLP
  # NOTE(review): `params.num_channels` is not set in the `params` node shown
  # here; presumably it comes from another composed config -- confirm.
  in_dim: ${params.num_channels}
  # Output width is twice the number of entries in `params.data_vars`,
  # evaluated by the application-registered `eval` resolver. Presumably two
  # likelihood parameters (e.g. mean and noise) per variable -- TODO confirm.
  out_dim: ${eval:'2 * len(${params.data_vars})'}
  num_layers: 2
  width: ${params.num_channels}

# Likelihood node consumed by `model.likelihood`.
likelihood:
  _target_: tnp.likelihoods.gaussian.HeteroscedasticNormalLikelihood
  # Passed to the likelihood as `min_noise`; presumably a lower bound on the
  # predicted noise level -- confirm against the class.
  min_noise: 1.0e-3

# Optimiser node. `_partial_` is enabled, so the target is presumably built
# as a partial and completed later (e.g. with the model parameters) by the
# training code -- confirm against the entry point.
optimiser:
  _target_: torch.optim.AdamW
  _partial_: true
  lr: 5.0e-4

# Shared hyper-parameters, referenced from the other nodes as `${params.*}`.
# NOTE(review): several keys referenced elsewhere in this file (`data_vars`,
# `grid_range`, `num_channels`, `num_blocks`, `kernel_size`) are not set
# here; presumably they are merged in from other composed configs -- confirm.
params:
  epochs: 300
  grid_encoder_grid_shape: [48, 120]
  use_nn: true
  top_k_ctot: 27
  dim_distance: 2
  # Composite distance function shared by the grid encoder and decoder.
  # Presumably each entry of `dist_fns` is applied to the coordinate
  # dimensions listed at the same position in `dist_fn_dims` -- TODO confirm.
  dist_fn:
    _target_: tnp.utils.distances.dist_composition
    _partial_: true
    dist_fns:
      - ${sq_dist}
      - ${haversine_dist}
    dist_fn_dims:
      - [0]
      - [-2, -1]

# Haversine distance node, referenced from `params.dist_fn.dist_fns`.
haversine_dist:
  _target_: tnp.utils.distances.haversine_dist
  _partial_: true
  # NOTE(review): listed as [-1, -2] while `params.dist_fn.dist_fn_dims`
  # selects [-2, -1]; presumably intentional (lon/lat ordering) -- confirm.
  lonlat_dims: [-1, -2]

# Squared-distance node, referenced from `params.dist_fn.dist_fns`.
sq_dist:
  _target_: tnp.utils.distances.sq_dist
  _partial_: true


# Run-level settings: experiment name, checkpointing, clipping and logging.
misc:
  # Name template built from the grid shape, block count, channel count and
  # kernel size. NOTE(review): `params.num_blocks`, `params.num_channels` and
  # `params.kernel_size` are not set in the `params` node shown in this file;
  # presumably they come from other composed configs -- confirm.
  name: MM-ConvCNPXL-GS${params.grid_encoder_grid_shape}-L${params.num_blocks}-C${params.num_channels}-K${params.kernel_size}
  resume_from_checkpoint: null
  gradient_clip_val: 0.5
  plot_interval: 10
  logging: true
  pl_logging: true
