# Experiment identity. `experiment` selects the pipeline to run:
# 'diffusion' here; the other supported values are 'gan' and 'metrics'.
experiment_name: 'Constrained_Time_Series_Diffusion_Training'
experiment: 'diffusion'

 
# Hydra defaults list. `_self_` is listed explicitly and last so the values
# in this file (e.g. hydra.launcher.n_jobs below) take precedence over the
# joblib launcher plugin's own defaults, instead of relying on Hydra's
# implicit `_self_` placement.
defaults:
  - override hydra/launcher: joblib
  - _self_

hydra:
  launcher:
    # Number of jobs the joblib launcher may run concurrently.
    n_jobs: 2

# Filesystem locations.
# Directory the program was launched from (Hydra `hydra:runtime.cwd` resolver).
work_dir: "${hydra:runtime.cwd}"
base_path: ""
save_path: "${experiment_name}/"
# NOTE(review): with base_path left as "" this resolves to "/mlflow" at the
# filesystem root; set base_path (e.g. as a command-line override) first.
mlflow_folder: "${base_path}/mlflow"

# Reproducibility and hardware.
seed: 42
device: "cuda"
# NOTE(review): dataset.num_workers is also set to 8 in this file; confirm
# which of the two the data loader actually reads.
num_workers: 8

# Data source. dataset_name presumably selects the matching
# `<dataset_name>_dataset` section (waveforms, traffic, air_quality, stocks)
# — confirm in the data-loader code. It is also interpolated into run_name.
dataset_name: "air_quality"
dataloader_file: general_pl_dataloader
dataloader_model: TimeweaverDataLoader

# Denoiser and trainer selection.
# Alternative denoiser: sss_timeseries_denoiser_v2.
denoiser_name: "csdi_timeseries_denoiser_v5"
model_file: constrained_timeseries_diffusion
model_name: ConstrainedTimeSeriesDiffusionModelTrainer
# Empty string presumably means "no checkpoint to resume from" — confirm.
model_checkpoint_path: ''

# Checkpointing and compilation.
# Presumably the metric monitored when saving checkpoints — confirm.
save_key: val_loss
# Recommended: false for SSSD denoisers, true for CSDI denoisers.
# NOTE(review): denoiser_name is currently the CSDI denoiser, yet compilation
# is disabled here; confirm whether that is intentional.
should_compile_torch: false
store_intermediate_checkpoints: false

######################### GENERAL DATASET CONFIG ########################################
# Settings shared by every dataset section below.
dataset:
  # Presumably the fraction of samples assigned to training — confirm.
  train_test_split: 0.8
  batch_size: 64
  num_workers: 8

######################## SPECIFIC DATASET CONFIG ########################################

# Waveforms: the only section with non-zero discrete/continuous label counts,
# and the only one without an equality_constraints_to_extract list.
waveforms_dataset:
  num_channels: 1
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ''
  num_discrete_labels: 4
  num_discrete_conditions: 1
  num_continuous_labels: 3
  discrete_condition_embedding_dim: 128
  latent_dim: 48

# Traffic: single-channel series, no labels or discrete conditions.
traffic_dataset:
  num_channels: 1
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ""
  num_discrete_labels: 0
  num_discrete_conditions: 0
  num_continuous_labels: 0
  discrete_condition_embedding_dim: 128
  latent_dim: 48
  # NOTE(review): if series are cropped to required_time_series_length (64),
  # 'val_at_72' and 'val_at_96' fall outside them; confirm constraints are
  # extracted from the full 96-step series.
  equality_constraints_to_extract:
    - 'argmax'
    - 'max and argmax'
    - 'argmin'
    - 'min and argmin'
    - 'mean'
    - 'mean change'
    - 'val_at_1'
    - 'val_at_24'
    - 'val_at_48'
    - 'val_at_72'
    - 'val_at_96'

 
# Air quality: six-channel series, no labels or discrete conditions.
# NOTE(review): unlike waveforms_dataset and traffic_dataset, no latent_dim is
# set here; confirm the model does not need it for this dataset.
air_quality_dataset:
  num_channels: 6
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ""
  num_discrete_labels: 0
  num_discrete_conditions: 0
  num_continuous_labels: 0
  discrete_condition_embedding_dim: 128
  # NOTE(review): if series are cropped to required_time_series_length (64),
  # 'val_at_72' and 'val_at_96' fall outside them; confirm.
  equality_constraints_to_extract:
    - 'argmax'
    - 'max and argmax'
    - 'argmin'
    - 'min and argmin'
    - 'mean'
    - 'mean change'
    - 'val_at_1'
    - 'val_at_24'
    - 'val_at_48'
    - 'val_at_72'
    - 'val_at_96'


# Stocks: six-channel series, no labels or discrete conditions.
# NOTE(review): unlike waveforms_dataset and traffic_dataset, no latent_dim is
# set here; confirm the model does not need it for this dataset.
stocks_dataset:
  num_channels: 6
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ""
  num_discrete_labels: 0
  num_discrete_conditions: 0
  num_continuous_labels: 0
  discrete_condition_embedding_dim: 128
  # NOTE(review): if series are cropped to required_time_series_length (64),
  # 'val_at_72' and 'val_at_96' fall outside them; confirm.
  equality_constraints_to_extract:
    - 'argmax'
    - 'max and argmax'
    - 'argmin'
    - 'min and argmin'
    - 'mean'
    - 'mean change'
    - 'val_at_1'
    - 'val_at_24'
    - 'val_at_48'
    - 'val_at_72'
    - 'val_at_96'

########################### TRAINING CONFIG ########################################

training:
  schedule: 'linear'  # alternative: 'cosine'
  max_epochs: 5000
  # Written with a decimal point so it is a float under every YAML loader;
  # plain YAML 1.1 loaders (e.g. PyYAML) read a bare `1e-4` as a string.
  learning_rate: 1.0e-4
  n_plots: 4
  save_after_every_iters: 100
  save_after_every_n_epochs: 50
  auto_lr_find: true
  check_val_every_n_epoch: 10
  log_every_n_steps: 1
  num_devices: 1  # Only a single device is supported for now.
  # strategy: "ddp_find_unused_parameters_true"
  strategy: "auto"

########################### TIME SERIES DENOISER CONFIG ########################################
# Hyperparameters for csdi_timeseries_denoiser_v5 (presumably looked up via
# `${denoiser_name}_config` — confirm in the model code).
csdi_timeseries_denoiser_v5_config:
  positional_embedding_dim: 128
  channel_embedding_dim: 16
  channels: 256
  n_heads: 16
  n_layers: 10
  dropout_pos_enc: 0.2
  # NOTE(review): metadata_encoder_config is presumably unused while this is
  # false — confirm in the denoiser.
  use_metadata: false

  metadata_encoder_config:
    use_sa_layer: true
    # Kept equal to the denoiser width via interpolation.
    channels: ${csdi_timeseries_denoiser_v5_config.channels}
    n_heads: 8
    num_encoder_layers: 2
    dropout: 0.1

# Run identifier: "<denoiser>_<dataset>", e.g. csdi_timeseries_denoiser_v5_air_quality.
run_name: '${denoiser_name}_${dataset_name}'
