# Sequence/window length shared across this config. It is anchored as `seqlen`
# and aliased by model.params.seq_len, unet_config.params.seq_len,
# cond_stage_config.params.window and data.params.window, so changing it here
# changes all of them together.
seq_length: &seqlen 96
# Model definition. Each `target` is a dotted path and `params` holds the
# settings for it (presumably passed as constructor kwargs — confirm against
# the config loader).
model:
  # NOTE(review): the `lr` anchor is not aliased anywhere in the visible part
  # of this file; it is kept in case another section references it.
  base_learning_rate: &lr 0.001
  target: ldm.models.improved_diffusion.improved_diffusion.LatentDiffusion
  params:
    linear_start: 0.0005
    linear_end: 0.1
    num_timesteps_cond: 1
    log_every_t: 40
    timesteps: 200
    loss_type: l1
    # Both stages read the same batch key.
    first_stage_key: "context"
    cond_stage_key: "context"
    seq_len: *seqlen
    channels: 1
    cond_stage_trainable: true
    concat_mode: false
    # NOTE(review): the original inline comment here read "True" — presumably
    # the alternative setting; confirm before switching.
    scale_by_std: false
    monitor: 'val/loss_simple_ema'
    conditioning_key: crossattn
    cond_drop_prob: 0.5
    model_var_type: 3

    # 1000 warm-up steps.
    # NOTE(review): this comment previously said 10000 while the value below
    # is 1000; the comment was aligned to the value — confirm which is intended.
    # f_max equals f_min, so presumably the LR factor stays at 1 once warm-up
    # ends — verify against LambdaLinearScheduler.
    scheduler_config:
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [1000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        f_max: [1.]
        f_min: [1.]

    unet_config:
      target: ldm.modules.diffusionmodules.DiT_def_ks.DiT
      params:
        seq_len: *seqlen
        patch_size: 1
        in_channels: 1
        hidden_size: 512
        depth: 6
        num_heads: 16
        mlp_ratio: 4.0
        learn_sigma: false
        # Same value as cond_stage_config dim/latent_dim below; presumably they
        # must stay in sync — confirm.
        context_dim: 32
        latent_unit: 1
        use_pam: true
        kernel_size: 0
        plot: false
        def_kernel_size: []
        kernel_type: 'dct_conv'

    first_stage_config:  # no first stage model for ts data
      target: ldm.models.autoencoder.IdentityFirstStage

    cond_stage_config:
      target: ldm.modules.encoders.modules.DomainUnifiedPrototyper
      params:
        dim: 32
        window: *seqlen
        latent_dim: 32  # 32 * 3
        num_latents: 1
        num_channels: 1

# Training data module: one `.npy` training file per dataset, keyed by dataset
# name. `{DATA_ROOT}` and `{SEQ_LEN}` are literal placeholders in these strings
# (not YAML syntax); presumably the training code substitutes them before the
# files are opened — confirm in TSGDataModule.
data:
  target: ldm.data.tsg_dataset.TSGDataModule
  params:
    data_path_dict:
      solar: "{DATA_ROOT}/solar_{SEQ_LEN}_train.npy"
      electricity: "{DATA_ROOT}/electricity_{SEQ_LEN}_train.npy"
      traffic: "{DATA_ROOT}/traffic_{SEQ_LEN}_train.npy"
      kddcup: "{DATA_ROOT}/kddcup_{SEQ_LEN}_train.npy"
      taxi: "{DATA_ROOT}/taxi_{SEQ_LEN}_train.npy"
      exchange: "{DATA_ROOT}/exchange_{SEQ_LEN}_train.npy"
      fred_md: "{DATA_ROOT}/fred_md_{SEQ_LEN}_train.npy"
      nn5: "{DATA_ROOT}/nn5_{SEQ_LEN}_train.npy"
      temp: "{DATA_ROOT}/temp_{SEQ_LEN}_train.npy"
      rain: "{DATA_ROOT}/rain_{SEQ_LEN}_train.npy"
      pedestrian: "{DATA_ROOT}/pedestrian_{SEQ_LEN}_train.npy"
      wind_4_seconds: "{DATA_ROOT}/wind_4_seconds_{SEQ_LEN}_train.npy"
    # Aliases the top-level `seqlen` anchor, so it matches the model's seq_len.
    window: *seqlen
    val_portion: 0.1
    batch_size: 512
    num_workers: 8
    normalize: centered_pit
    drop_last: true
    reweight: true

# Logging callback and trainer settings (presumably consumed by the PyTorch
# Lightning setup code — confirm against the training entry point).
lightning:
  callbacks:
    image_logger:
      target: utils.callback_utils.TSLogger
      params:
        batch_frequency: 10000
        max_images: 8
        increase_log_steps: false

  trainer:
    benchmark: true
    max_steps: 50000