# Human-readable experiment name; also used to build save_path.
experiment_name: 'Time_Series_Diffusion_Training'
# Pipeline to run. Other values: 'gan' and 'metrics'.
experiment: 'diffusion'

 
# Hydra defaults list: replace the default launcher with the joblib launcher
# so multirun sweeps execute their jobs in parallel.
defaults:
  - override hydra/launcher: joblib
hydra:
  launcher:
    # Maximum number of sweep jobs joblib runs concurrently.
    # NOTE(review): with device 'cuda', concurrent jobs presumably share the
    # same GPU — confirm there is enough memory headroom.
    n_jobs: 2

# Original directory Hydra was launched from (unaffected by Hydra's per-job
# working-directory handling).
work_dir: ${hydra:runtime.cwd}
# Root directory for artifacts. It is joined with "/" below, so give it
# without a trailing slash. Defaults to "." because an empty value made
# mlflow_folder resolve to "/mlflow", i.e. a folder at the filesystem root.
base_path: "."
# Per-experiment output folder derived from experiment_name.
save_path: "${experiment_name}/"
# Folder presumably used for MLflow tracking data, placed under base_path.
mlflow_folder: "${base_path}/mlflow"

# Random seed for the run.
seed: 42
# Compute device, e.g. 'cuda' or 'cpu'.
device: 'cuda'
# Top-level worker count (the dataset section also sets num_workers).
num_workers: 8

# Active dataset. Presumably selects the matching `<name>_dataset` section
# (e.g. air_quality -> air_quality_dataset) — confirm in the loader code.
dataset_name: 'air_quality'
# Module and class name of the data loader (presumably PyTorch Lightning).
dataloader_file: general_pl_dataloader
dataloader_model: TimeweaverDataLoader

# Active denoiser; presumably paired with the `<name>_config` section.
# Alternative: sss_timeseries_denoiser_v2.
# NOTE(review): no sss_timeseries_denoiser_v2_config sits alongside the other
# denoiser configs here — confirm where it is defined.
denoiser_name: 'csdi_timeseries_denoiser_v4'
# Module and class name of the diffusion trainer.
model_file: timeseries_diffusion
model_name: TimeSeriesDiffusionModelTrainer
# Checkpoint to load; an empty string presumably means "train from scratch".
model_checkpoint_path: ''

# Logged metric presumably used to select checkpoints.
save_key: val_loss
# Whether to torch.compile the model (inferred from the name).
# Recommended: false for SSSD denoisers, true for CSDI denoisers.
# NOTE(review): denoiser_name currently selects a CSDI model while this is
# false — confirm compilation is disabled on purpose.
should_compile_torch: false
store_intermediate_checkpoints: false

######################### GENERAL DATASET CONFIG ########################################
dataset:
  # Presumably the fraction of samples assigned to the training split.
  train_test_split: 0.8
  batch_size: 64
  # Reuse the top-level worker count so the two settings cannot drift apart.
  num_workers: ${num_workers}

######################## SPECIFIC DATASET CONFIG ########################################

# Settings for the waveforms dataset.
waveforms_dataset:
  num_channels: 1
  # Stored series length vs. the length the model consumes
  # (presumably cropped) — confirm in the data loader.
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ''
  num_discrete_labels: 4
  num_discrete_conditions: 1
  num_continuous_labels: 3
  discrete_condition_embedding_dim: 128
  latent_dim: 48

# Settings for the traffic dataset.
traffic_dataset:
  num_channels: 1
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ''
  num_discrete_labels: 135
  num_discrete_conditions: 7
  num_continuous_labels: 4
  discrete_condition_embedding_dim: 128
  latent_dim: 48
 
# Settings for the air-quality dataset.
air_quality_dataset:
  num_channels: 6
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ''
  num_discrete_labels: 101
  num_discrete_conditions: 6
  num_continuous_labels: 5
  discrete_condition_embedding_dim: 128
  # NOTE(review): unlike waveforms_dataset and traffic_dataset, no latent_dim
  # is set here — confirm the models used with this dataset do not need it.

# Settings for the stocks dataset. All label/condition counts are zero, so it
# presumably trains without conditioning metadata.
stocks_dataset:
  num_channels: 6
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ''
  num_discrete_labels: 0
  num_discrete_conditions: 0
  num_continuous_labels: 0
  discrete_condition_embedding_dim: 128


########################### TRAINING CONFIG ########################################

training:
  # Presumably the diffusion noise schedule; alternative: 'cosine'.
  schedule: 'linear'
  max_epochs: 5000
  # Written as 1.0e-4 rather than 1e-4: YAML 1.1 loaders such as PyYAML only
  # resolve exponent floats that contain a '.', and would otherwise return
  # the string "1e-4".
  learning_rate: 1.0e-4
  n_plots: 4
  save_after_every_iters: 100
  save_after_every_n_epochs: 50
  auto_lr_find: true
  check_val_every_n_epoch: 10
  log_every_n_steps: 1
  num_devices: 1  # ONLY SUPPORTS 1 DEVICE FOR NOW.
  # strategy: "ddp_find_unused_parameters_true"
  strategy: "auto"

########################### TIME SERIES DENOISER CONFIG ########################################

csdi_timeseries_denoiser_v4_config:
  positional_embedding_dim: 128
  channel_embedding_dim: 16
  channels: 256
  n_heads: 16
  n_layers: 10
  dropout_pos_enc: 0.2
  use_metadata: true

  # Presumably the encoder for conditioning metadata; its channel width is
  # interpolated from the denoiser's `channels` so the two stay in sync.
  metadata_encoder_config:
    use_sa_layer: true
    channels: ${csdi_timeseries_denoiser_v4_config.channels}
    n_heads: 8
    num_encoder_layers: 2
    dropout: 0.1

unet_timeseries_denoiser_v1_config:
  channels: 64
  kernel_size: 7
  # Equals kernel_size // 2, presumably "same" padding for stride-1 convs —
  # keep in sync if kernel_size changes.
  padding: 3
  dim_mults: [1, 2, 4, 8]
  resnet_block_groups: 8
  sinusoidal_pos_emb_theta: 10000
  attn_dim_head: 32
  attn_heads: 4

  # Presumably the encoder for conditioning metadata; its channel width is
  # interpolated from the denoiser's `channels` so the two stay in sync.
  metadata_encoder_config:
    use_sa_layer: true
    channels: ${unet_timeseries_denoiser_v1_config.channels}
    n_heads: 8
    num_encoder_layers: 2
    dropout: 0.1

# Run label of the form "<denoiser_name>_<dataset_name>".
run_name: '${denoiser_name}_${dataset_name}'
