---
# Settings shared by the train and test stages.
common_args:
  data_path: ./datasets/ETTh.csv
  seed: 42

  mask_choice: unconditional
  timefreq_transform: stft
  # STFT geometry for the values below (assuming a one-sided, centered STFT,
  # e.g. torch.stft defaults — NOTE(review): confirm against the transform):
  #   num_frames = 1 + seq_len // hop_length = 1 + 128 // 15 = 9
  #   n_freq     = n_fft // 2 + 1            = 63 // 2 + 1  = 32
  seq_len: 128
  n_fft: 63
  hop_length: 15
  overlapping_seqs_stride: 1
  scale_time: true
  decompose: true
  time_covariates: false
  pad_frames: false

  model:
    name: STDiff

    # YAML loads this as the string "(8,1)", not a tuple or list; the
    # consumer is responsible for parsing it. Quoted to make that explicit.
    patch_size: "(8,1)"
    depth: 8
    hidden_size: 384
    num_heads: 6
    mlp_ratio: 4.0
    dropout: 0.0
    use_freq_bias: true
    use_cov_bias: true
    use_checkpoint_every: 0  # if >0, use checkpointing every N blocks
    alternate_attn_order: true

  gaussian_diffusion:
    # Kept as a quoted string (not an int), as in the original config.
    timestep_respacing: "1000"
    diffusion_steps: 1000
    predict_xstart: false
    learn_sigma: true
    noise_schedule: "cosine"

  batch_size: 128
  compile: true
  vae_encode: false

# Training-stage settings.
train:
  shuffle_tloader: true
  epochs: 1000
  optimizer: adamw
  # Learning-rate schedule, presumably cosine annealing with warmup and
  # restarts (per its name). NOTE(review): confirm whether the *_steps
  # values count epochs or optimizer steps in the scheduler implementation.
  # Float literals are written with a decimal point and signed exponent
  # (e.g. 2.0e-4, not 2e-4) so YAML 1.1 loaders such as PyYAML resolve them
  # as floats rather than strings — keep that form when editing.
  scheduler:
    name: cos_ann_warmup_restarts
    warmup_steps: 15
    first_cycle_steps: 100  # First cycle step size
    cycle_mult: 2.0
    max_lr: 2.0e-4  # First cycle's max learning rate
    min_lr: 1.0e-6  # Min learning rate
    gamma: 0.5
  weight_decay: 5.0e-3

  save_chkpt_every: 10
  early_stopping_patience: 10

  # Per-metric weights used when scoring validation; they sum to 1.0.
  validation_metrics_weight:
    context_fid: 0.25
    cross_correlation: 0.25
    discriminative_score: 0.25
    predictive_score: 0.25


# Sampling and evaluation settings.
test:
  n_samples_to_generate: 3000
  # NOTE(review): presumably selects EMA weights when loading the checkpoint.
  load_ema: true
  # NOTE(review): presumably DDIM sampling with 200 steps; confirm the format
  # expected by the sampler.
  ddim: ddim200
  psample_clip_denoised: true
  apply_scale_factor: true
  overlapping_seqs_avg: true
  metrics_seq_len: null  # if null, it will be set to the seq_len
  metrics_iterations: 5
  skip_plots: false
  skip_metrics: false