# Required global parameters (`num_heads` may be omitted when no attention
# blocks are used).

# Latent dimension of the VAE, which is the output of the encoder.
# NOTE(review): "encoder" here presumably means the VAE's encoder rather than
# the `encoder:` section of this file -- confirm.
latent_dim: 4
# Dimension of the time embedding (for the residual blocks).
d_time: 1280
# Dimension of the context embedding (for the attention blocks).
d_context: 768
# Number of attention heads for the attention blocks.
num_heads: 8
# Activation function used in the residual blocks and at the end of the
# encoder.
activation: SiLU
# Number of groups for group normalization.
# NOTE(review): every `out` width used in this file (80, 160, 320) is a
# multiple of 8; group normalization usually requires that -- confirm before
# changing either the widths or this value.
groups: 8
# Dropout rate for regularization.
dropout: 0.2

# Blocks are templates for layers that can be reused in the encoder,
# bottleneck, and decoder sections.
#
# Values starting with `$` ($in, $out, $n) are placeholders. They appear to be
# filled from the matching key (`in`, `out`, `n`) of the `with:` mapping of
# each entry that references a template through `use:` -- confirm against the
# config loader.

blocks:
  # Simple convolutional layer: 3x3 kernel, padding 1, no stride specified.
  simple_conv:
    type: conv
    out_channels: $out
    kernel_size: 3
    padding: 1
    padding_mode: 'zeros'

  # Downsampling convolutional layer: same as `simple_conv` plus `stride: 2`.
  down_conv:
    type: conv
    out_channels: $out
    kernel_size: 3
    stride: 2
    padding: 1
    padding_mode: 'zeros'

  # Upsampling layer: bilinear interpolation with a scale factor of 2.
  up_conv:
    type: upsample
    out_channels: $out
    scale_factor: 2
    mode: 'bilinear'
    align_corners: true

  # Residual block.
  ResBlock:
    type: res
    # Channels that the skip connection contributes to this block's input.
    # The encoder and bottleneck entries in this file pass 0; the decoder
    # entries pass the width of the skip connection they consume.
    skip_channels: $in
    out_channels: $out
    kernel_size: 3
    padding: 1
    padding_mode: 'zeros'
    # Number of times the block is repeated.
    repeat: $n

  # Attention block. Takes no `$` placeholders.
  AttBlock:
    type: att

  # Residual block followed by an attention block. Takes the same fields as
  # `ResBlock`.
  ResAndAttBlock:
    type: res_and_att
    # Channels that the skip connection contributes to this block's input.
    # The encoder and bottleneck entries in this file pass 0; the decoder
    # entries pass the width of the skip connection they consume.
    skip_channels: $in
    out_channels: $out
    kernel_size: 3
    padding: 1
    padding_mode: 'zeros'
    # Number of times the block is repeated.
    repeat: $n

# Encoder: a stem convolution followed by four stages at widths 80, 160, 320
# and 320, with a `down_conv` between consecutive stages (three in total).
# Every entry passes `in: 0` where the template has a `$in` placeholder.
#
# The "(+k)" in each inline comment is the number of skip connections the
# entry adds: +1 for each convolution and +1 per repetition (`n`) of a
# residual block. The annotations below sum to 12.
encoder:
  - use: simple_conv  # skip connections (+1) v
    with:
      out: 80

  - use: ResAndAttBlock  # skip connections (+2) v
    with:
      in: 0
      out: 80
      n: 2

  - use: down_conv  # skip connections (+1) v
    with:
      out: 80

  - use: ResAndAttBlock  # skip connections (+2) v
    with:
      in: 0
      out: 160
      n: 2

  - use: down_conv  # skip connections (+1) v
    with:
      out: 160

  - use: ResAndAttBlock  # skip connections (+2) v
    with:
      in: 0
      out: 320
      n: 2

  - use: down_conv  # skip connections (+1) v
    with:
      out: 320

  # The last stage uses a plain ResBlock (no attention).
  - use: ResBlock  # skip connections (+2) v
    with:
      in: 0
      out: 320
      n: 2

# Bottleneck: one ResAndAttBlock followed by one ResBlock, both 320 channels
# wide with `in: 0` and a single repetition. Neither entry carries a
# skip-connection annotation.
# NOTE(review): presumably applied between the encoder and the decoder --
# confirm against the model builder.
bottleneck:
  - use: ResAndAttBlock
    with:
      in: 0
      out: 320
      n: 1

  - use: ResBlock
    with:
      in: 0
      out: 320
      n: 1

# Decoder: residual stages separated by three `up_conv` entries.
#
# The "(-k)" in each inline comment is the number of skip connections the
# entry consumes; it equals the entry's `n`. The annotations below sum to 12,
# matching the 12 added by the encoder. The `up_conv` entries carry no
# annotation.
#
# `in` is the width of the consumed skip connections. Read top to bottom the
# values are 320 (x3), 320 (x2), 160, 160 (x2), 80, 80 (x3), which is the
# encoder's sequence of output widths in reverse.
# NOTE(review): this is consistent with skips being consumed last-in,
# first-out -- confirm against the model builder before reordering entries.
decoder:
  - use: ResBlock  # skip connections (-3)
    with:
      in: 320
      out: 320
      n: 3

  - use: up_conv
    with:
      out: 320

  - use: ResAndAttBlock  # skip connections (-2)
    with:
      in: 320
      out: 320
      n: 2

  - use: ResAndAttBlock  # skip connections (-1)
    with:
      in: 160
      out: 320
      n: 1

  - use: up_conv
    with:
      out: 320

  - use: ResAndAttBlock  # skip connections (-2)
    with:
      in: 160
      out: 160
      n: 2

  - use: ResAndAttBlock  # skip connections (-1)
    with:
      in: 80
      out: 160
      n: 1

  - use: up_conv
    with:
      out: 160

  - use: ResAndAttBlock  # skip connections (-3)
    with:
      in: 80
      out: 80
      n: 3