
---
# Settings shared by the train and test stages.
common_args:
  seed: 42
  mask_choice: unconditional

  # Data source and time-frequency preprocessing.
  data_path: ./datasets/mujoco.pt
  timefreq_transform: stft
  seq_len: 24
  n_fft: 11
  hop_length: 3
  overlapping_seqs_stride: 1
  scale_time: true
  decompose: true
  time_covariates: false
  pad_frames: false

  model:
    name: STDiff

    # YAML has no tuple type, so this loads as the literal string "(2,1)";
    # it is quoted to make that explicit. NOTE(review): presumably the loader
    # parses this string into a tuple — confirm before turning it into a
    # YAML list such as [2, 1], which would change the loaded type.
    patch_size: "(2,1)"
    depth: 6
    hidden_size: 192
    num_heads: 4
    mlp_ratio: 4.0
    dropout: 0.1
    use_freq_bias: true
    use_cov_bias: true
    use_checkpoint_every: 0  # if >0, use checkpointing every N blocks
    alternate_attn_order: false

  gaussian_diffusion:
    # Quoted so it loads as a string, not an int.
    timestep_respacing: "1000"
    diffusion_steps: 1000
    predict_xstart: false
    learn_sigma: true
    noise_schedule: "cosine"

  batch_size: 64
  compile: true
  vae_encode: false

# Training-stage settings.
train:
  shuffle_tloader: true
  epochs: 1000
  optimizer: adamw
  scheduler:
    name: cos_ann_warmup_restarts
    warmup_steps: 15
    first_cycle_steps: 100  # length of the first cycle, in scheduler steps
    cycle_mult: 2.0
    # Floats below keep a dot and a signed exponent (e.g. 2.0e-4) so that
    # YAML 1.1 loaders such as PyYAML resolve them as floats, not strings.
    max_lr: 2.0e-4  # max learning rate of the first cycle
    min_lr: 1.0e-6  # minimum learning rate
    gamma: 0.5  # NOTE(review): presumably the per-cycle decay of max_lr — confirm
  weight_decay: 5.0e-3

  save_chkpt_every: 20
  early_stopping_patience: 10

  # Per-metric weights (they currently sum to 1.0). NOTE(review): presumably
  # combined into a single validation score — confirm against the trainer.
  validation_metrics_weight:
    context_fid: 0.25
    cross_correlation: 0.25
    discriminative_score: 0.25
    predictive_score: 0.25

# Test-stage settings.
test:
  augment_data: false
  test_proportion: 1.0

  load_ema: true
  ddim: ddim200  # NOTE(review): looks like a DDIM sampling spec with 200 steps — confirm
  psample_clip_denoised: true

  apply_scale_factor: true
  overlapping_seqs_avg: true
  metrics_seq_len: null  # if null, it will be set to seq_len
  metrics_iterations: 5
  skip_plots: false
  skip_metrics: false