# @package _global_

# Hydra defaults list: each `override` entry swaps the option chosen for an
# already-declared config group. The keys further down in this file then
# adjust individual values on top of the selected options.
defaults:
  - override /dataset: cyclone_small
  # Alternative dataset option, kept commented out for quick switching.
  # - override /dataset: cyclone_single
  - override /training: muon
  - override /autoencoder: vqvae

# Training-loop settings for this experiment; they sit on top of the `muon`
# training option picked in the defaults list.
training:
  n_epochs: 200
  batch_size: 8
  # Entries excluded from weight decay.
  # NOTE(review): presumably matched against parameter names - confirm in the
  # optimizer setup code.
  exclude_from_wd:
    - "cond"
  # Options noted by the author: "none", "full", "pseudo".
  gradnorm_balancer: "none"
  loss_type: complex_mse
  # Optional integral loss, currently disabled.
  # integral_loss_type: int_norm_mse

# Run-tracking metadata for this experiment.
logging:
  # NOTE(review): presumably toggles reporting of the autoencoder's
  # compression ratio - confirm in the logging code.
  compression_ratio: true
  # Labels attached to the run.
  tags:
    - "v5"
    - "v5_vqvae"
    - "v5_base"
    - "cr10240"
  # Suffix for the run name.
  name_suffix: "vqvae_cmse_base"

# Autoencoder settings for this experiment; they sit on top of the `vqvae`
# autoencoder option picked in the defaults list.
autoencoder:
  latent_dim: 1024
  act_fn: GELU
  decouple_mu: true
  # Options noted by the author:
  #   "xavier_uniform", "truncnormal", "kaiming_uniform"
  init_weights: "xavier_uniform"
  patching_init_weights: "xavier_uniform"
  # Options noted by the author:
  #   "kaiming_uniform", "normal_smallvar", "xavier_uniform"
  cond_init_weights: "xavier_uniform"
  # Names of the conditioning inputs.
  conditioning: ["itg", "dg", "s_hat", "q"]

  # Explicit null for the commitment term.
  # NOTE(review): presumably means "no schedule" - confirm in the loss code.
  loss_scheduler:
    vq_commit: null
  # NOTE(review): `vq.commitment_weight` (0.3) below is a second weight on the
  # commitment term - confirm whether it and `vq_commit` here compound.
  loss_weights:
    df: 1.0
    vq_commit: 1.0
  # Optional auxiliary losses, currently disabled.
  # extra_loss_weights:
  #   flux_int: 1.0
  #   phi_int: 1.0
  #   # Spectral Losses
  #   kxspec: 1.0
  #   kyspec: 1.0
  #   qspec: 1.0
  #   phi_zf: 1.0
  #   mass: 1.0
  #   # Monotonicity constraints (optional)
  #   qspec_monotonicity: 0.0
  #   kyspec_monotonicity: 0.0

  # Vector-quantizer settings.
  # NOTE(review): `embedding_dim` equals `bottleneck.dim` (128) below -
  # presumably the two must stay in sync; confirm before changing either.
  vq:
    codebook_size: 8192
    embedding_dim: 128
    commitment_weight: 0.3
    codebook_type: "euclidean"
    ema_decay: 0.99
    entropy_loss_weight: 0.01
    threshold_ema_dead_code: 2
    rotation_trick: false

  # Five-entry patch/window lists.
  # NOTE(review): the 0 in the second position of both lists presumably marks
  # an axis that is left unpatched - confirm against the patching code.
  patch:
    patch_size: [2, 0, 2, 5, 2]
    window_size: [8, 0, 4, 9, 8]

  # Transformer settings.
  # NOTE(review): `num_heads` and `depth` are single-element lists -
  # presumably one entry per stage; confirm against the model code.
  vit:
    num_heads: [16]
    depth: [4]
    gradient_checkpoint: true
    use_abs_pe: true
    use_rpb: true
    use_rope: true
    gated_attention: true
    modulation: dit
    mid_norm_learnable: false
    norm_layer: LayerNorm

  # Bottleneck settings.
  bottleneck:
    dim: 128
    use_linear: true
    norm_learnable: true
    num_heads: 2
    depth: 2