---
# Model to run. Every section below (`models`, `trainings`, `datas`,
# `loggings`) is keyed by model name, so this value is expected to match a
# key in each of them (presumably looked up by the training script — confirm
# against the config loader).
selected_model: 'MultiConv'

models:
  # Hyper-parameters for each model, keyed by model name.
  # NOTE(review): the `parameters` mappings are presumably passed as keyword
  # arguments to the corresponding model class — confirm against the factory.
  MultiConv:
    parameters:
      # Shape layout (T, C, H=lat, W=lon), adjusted to match the data shape.
      # This model's key is `in_shape`, unlike the `shape_in` used by the
      # other models; it presumably mirrors the constructor argument name, so
      # do not rename it without checking the model code.
      # NOTE(review): 360x720 would be a 0.5-degree global grid, while the
      # MultiConv backbone name under `loggings` ends in `_0.25` — confirm
      # the intended resolution.
      in_shape: [10, 2, 360, 720]

  SimVP:
    parameters:
      # NOTE(review): presumably the same (T, C, H, W) layout as MultiConv's
      # `in_shape`, on a 240x300 grid — confirm against the dataset.
      shape_in: [10, 2, 240, 300]
      shape_out: [10, 2, 240, 300]
      hid_S: 32
      hid_T: 64
      N_S: 4
      N_T: 8

  TurbL1:
    parameters:
      shape_in: [10, 2, 240, 300]
      spatial_hidden_dim: 256
      output_channels: 2
      temporal_hidden_dim: 512
      num_spatial_layers: 4
      num_temporal_layers: 8

  # Triton currently uses exactly the same hyper-parameters as TurbL1; update
  # both entries together if they are meant to stay identical.
  Triton:
    parameters:
      shape_in: [10, 2, 240, 300]
      spatial_hidden_dim: 256
      output_channels: 2
      temporal_hidden_dim: 512
      num_spatial_layers: 4
      num_temporal_layers: 8
  # Add more models as needed, e.g.
  # UNet:
  #   parameters:
  #     in_channels: 2
  #     out_channels: 2
  #     hidden_channels: 64

trainings:
  # Training settings per model, keyed by model name. Entries are listed in
  # the same order as `models`, `datas` and `loggings`
  # (MultiConv, SimVP, TurbL1, Triton).
  # NOTE(review): lr_step_size / lr_gamma look like a step-decay schedule
  # (multiply the LR by lr_gamma every lr_step_size epochs, as in PyTorch's
  # StepLR) — confirm in the trainer.
  MultiConv:
    batch_size: 1
    num_epochs: 100
    learning_rate: 0.001
    lr_step_size: 10
    lr_gamma: 0.2
    seed: 42
    parallel_method: 'DistributedDataParallel'

  SimVP:
    batch_size: 1
    num_epochs: 100
    learning_rate: 0.001
    lr_step_size: 10
    lr_gamma: 0.2
    seed: 42
    parallel_method: 'DistributedDataParallel'

  TurbL1:
    batch_size: 1
    num_epochs: 100
    learning_rate: 0.001
    lr_step_size: 10
    lr_gamma: 0.2
    seed: 42
    parallel_method: 'DistributedDataParallel'

  # Triton deliberately differs from the other models: it trains for 200
  # epochs (instead of 100) with a 40-epoch LR step (instead of 10).
  Triton:
    batch_size: 1
    num_epochs: 200
    learning_rate: 0.001
    lr_step_size: 40
    lr_gamma: 0.2
    seed: 42
    parallel_method: 'DistributedDataParallel'

  # Add an entry here for every additional model defined under `models`.

datas:
  # Dataset location per model, keyed by model name. All models currently
  # read from the same directory.

  # Restored to the data path used by the original code.
  MultiConv:
    data_path: './data'

  SimVP:
    data_path: './data'

  TurbL1:
    data_path: './data'

  Triton:
    data_path: './data'

loggings:
  # Logging, checkpoint and output locations per model, keyed by model name.
  MultiConv:
    # Same backbone name as the checkpoint produced by the original code.
    backbone: 'Global_ssv_multi_new_uv_0.25'
    log_dir: './logs'
    # Original checkpoint directory (no `pretrained` subdirectory, unlike the
    # other models).
    checkpoint_dir: './checkpoints'
    # Current working directory: uv_pred.nc is written here, as in the
    # original code.
    result_dir: './'

  SimVP:
    backbone: 'SimVP'
    log_dir: './logs'
    checkpoint_dir: './checkpoints/pretrained'
    result_dir: './data/GS'

  TurbL1:
    backbone: 'TurbL1'
    log_dir: './logs'
    checkpoint_dir: './checkpoints/pretrained'
    result_dir: './data/GS'

  Triton:
    backbone: 'Triton'
    log_dir: './logs'
    checkpoint_dir: './checkpoints/pretrained'
    result_dir: './data/GS'