# Hydra primary config for training time-series GAN models.
experiment_name: "Time_Series_GAN_Training"
# NOTE(review): "diffusion" looks like a leftover from a diffusion-model config
# even though this file trains a GAN; confirm what reads this key.
experiment: "diffusion"

defaults:
  - override hydra/launcher: joblib
  # Listing _self_ last lets this file's values override the defaults list
  # (the order Hydra 1.1+ uses implicitly) and silences the "missing _self_"
  # warning.
  - _self_

hydra:
  launcher:
    # Number of parallel jobs the joblib launcher runs during --multirun sweeps.
    n_jobs: 2

# Directory the application was launched from (Hydra's runtime.cwd).
work_dir: ${hydra:runtime.cwd}
# NOTE(review): with base_path left empty, mlflow_folder below resolves to the
# absolute path "/mlflow" (filesystem root); set base_path if that is not
# intended.
base_path: ""
save_path: "${experiment_name}/"
mlflow_folder: "${base_path}/mlflow"


seed: 42
device: "cuda"
# NOTE(review): another num_workers lives under `dataset:`; confirm which one
# the dataloader actually reads.
num_workers: 8

# presumably selects the matching "<dataset_name>_dataset" block — verify
dataset_name: "air_quality"
dataloader_file: general_pl_dataloader
dataloader_model: TimeweaverDataLoader

# presumably selects the matching "<gan_name>_config" block — verify
gan_name: "wavegan_v1"
model_file: timeseries_gan
model_name: TimeSeriesGANModelTrainer

save_key: val_loss
# presumably toggles torch.compile for the model — confirm in the trainer
should_compile_torch: false

# NOTE(review): bare `none` is the *string* "none" in YAML, not null. Keep it
# only if the consumer compares against "none"; otherwise use `null`.
autoencoder_checkpoint_path: none

######################### GENERAL DATASET CONFIG ########################################
dataset:
  # presumably the fraction of samples used for training — confirm in loader
  train_test_split: 0.8
  batch_size: 256
  num_workers: 8

######################## SPECIFIC DATASET CONFIG ########################################
# One block per supported dataset. NOTE(review): only traffic_dataset and
# waveforms_dataset define latent_dim; confirm it is optional for the others.
air_quality_dataset:
  num_channels: 6
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ""
  num_discrete_labels: 101
  num_discrete_conditions: 6
  num_continuous_labels: 5
  discrete_condition_embedding_dim: 128

traffic_dataset:
  num_channels: 1
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ""
  num_discrete_labels: 135
  num_discrete_conditions: 7
  num_continuous_labels: 4
  discrete_condition_embedding_dim: 128
  latent_dim: 48

stocks_dataset:
  num_channels: 6
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ""
  num_discrete_labels: 0
  num_discrete_conditions: 0
  num_continuous_labels: 0
  discrete_condition_embedding_dim: 128

waveforms_dataset:
  num_channels: 1
  time_series_length: 96
  required_time_series_length: 64
  log_dir: ""
  num_discrete_labels: 4
  num_discrete_conditions: 1
  num_continuous_labels: 3
  discrete_condition_embedding_dim: 128
  latent_dim: 48

########################### TRAINING CONFIG ########################################

training:
  max_epochs: 5000
  # Written as 1.0e-4 rather than 1e-4: YAML 1.1 loaders such as PyYAML only
  # recognise floats that contain a decimal point and would read 1e-4 as a
  # string.
  learning_rate: 1.0e-4
  # presumably the Adam betas (beta1, beta2) — confirm in the trainer
  b1: 0.5
  b2: 0.9
  n_plots: 4
  train_generator_every: 1
  # presumably the gradient-penalty weight (lambda) — confirm in the trainer
  lmbda: 10.0
  save_after_every_iters: 100
  save_after_every_n_epochs: 100
  auto_lr_find: true
  check_val_every_n_epoch: 10
  log_every_n_steps: 1
  num_devices: 1  # ONLY SUPPORTS 1 DEVICE FOR NOW.
  strategy: "auto"


########################### TIME SERIES GAN CONFIG ########################################

wavegan_v1_config:
  use_metadata: false
  generator_config:
    kernel_size: 3
    latent_dim: 48
    repeat_num: 2
    stride: 2
    # Suggested blow_up_factor values: 30 for traffic unconditional, 10 for
    # traffic conditional, 5 for air quality unconditional, 2 for air quality
    # conditional, 30 for waveforms, 5 for stocks.
    # NOTE(review): dataset_name is "air_quality" and use_metadata is false
    # (presumably unconditional), which per the list above suggests 5, not 10
    # — confirm the intended value.
    blow_up_factor: 10
    # NOTE(review): unquoted None is the *string* "None" in YAML, not null;
    # confirm the model code expects that string.
    final_activation: None
    metadata_encoder_config:
      channels: 64
      n_heads: 8
      num_encoder_layers: 2
      use_sa_layer: true
      dropout: 0.1

  discriminator_config:
    model_size: 40
    kernel_size: 5
    stride_list: [2, 1, 2, 1, 2, 1]
    # presumably the LeakyReLU negative slope — confirm in the model
    alpha: 0.2
    # presumably the WaveGAN phase-shuffle shift — confirm in the model
    shift_factor: 2
    output_condition: false
    metadata_encoder_config:
      channels: 64
      n_heads: 8
      num_encoder_layers: 2
      use_sa_layer: true
      dropout: 0.1


# Run identifier built from the selected GAN and dataset names
# (e.g. "wavegan_v1_air_quality" with the values set in this file).
run_name: '${gan_name}_${dataset_name}'