# Shared sequence length. Anchored as &seqlen and reused through *seqlen
# aliases by the model, UNet, condition-stage and data sections.
seq_length: &seqlen 96
model:
  base_learning_rate: 0.001
  # Dotted path of the model class; presumably resolved and instantiated by
  # the training script with `params` as keyword arguments -- confirm.
  target: ldm.models.diffusion.ddpm_time.LatentDiffusion
  params:
    linear_start: 0.0005
    linear_end: 0.1
    num_timesteps_cond: 1
    log_every_t: 40
    timesteps: 1000
    loss_type: l1
    # first_stage_key and cond_stage_key are set to the same value.
    first_stage_key: 'context'
    cond_stage_key: 'context'
    seq_len: *seqlen
    channels: 1
    cond_stage_trainable: true
    concat_mode: false
    # NOTE(review): the previous inline note on this key read "True",
    # presumably the alternative setting -- confirm which value is intended.
    scale_by_std: false
    monitor: 'val/loss_simple_ema'
    conditioning_key: crossattn
    cond_drop_prob: 0.5

    # Learning-rate scheduler settings; warm_up_steps below is 1000.
    # NOTE(review): the earlier comment on this key said "10000 warmup steps",
    # which disagreed with the configured value -- confirm which is intended.
    scheduler_config:
      target: ldm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [1000]
        cycle_lengths: [10000000000000]
        f_start: [1.e-6]
        # f_max and f_min are equal; presumably the multiplier therefore stays
        # constant after warm-up -- confirm against LambdaLinearScheduler.
        f_max: [1.]
        f_min: [1.]

    unet_config:
      target: ldm.modules.diffusionmodules.ts_unet.UNetModel
      params:
        seq_len: *seqlen
        dims: 1
        in_channels: 1
        out_channels: 1
        model_channels: 64
        attention_resolutions: [1, 2, 4]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_heads: 8
        use_scale_shift_norm: true
        resblock_updown: true
        # Same value as `dim` / `latent_dim` in cond_stage_config; presumably
        # these must stay in sync -- confirm.
        context_dim: 32
        repre_emb_channels: 32
        latent_unit: 1
        use_spatial_transformer: true
        use_pam: true
        # NOTE(review): 0 is an unusual kernel size; presumably a sentinel
        # handled by UNetModel -- confirm.
        kernel_size: 0
        plot: false

    # No first-stage model is used for time-series data.
    first_stage_config:
      target: ldm.models.autoencoder.IdentityFirstStage

    cond_stage_config:
      target: ldm.modules.encoders.modules.DomainUnifiedPrototyper
      params:
        dim: 32
        window: *seqlen
        # NOTE(review): the previous inline note on this key read "32 * 3" --
        # presumably an alternative value; confirm before changing.
        latent_dim: 32
        num_latents: 16
        num_channels: 1

data:
  target: ldm.data.tsg_dataset.TSGDataModule
  params:
    # One training-array path per dataset name. {DATA_ROOT} and {SEQ_LEN} are
    # literal text as far as YAML is concerned; presumably the loading code
    # substitutes them -- confirm.
    data_path_dict:
      solar: '{DATA_ROOT}/solar_{SEQ_LEN}_train.npy'
      electricity: '{DATA_ROOT}/electricity_{SEQ_LEN}_train.npy'
      traffic: '{DATA_ROOT}/traffic_{SEQ_LEN}_train.npy'
      kddcup: '{DATA_ROOT}/kddcup_{SEQ_LEN}_train.npy'
      taxi: '{DATA_ROOT}/taxi_{SEQ_LEN}_train.npy'
      exchange: '{DATA_ROOT}/exchange_{SEQ_LEN}_train.npy'
      fred_md: '{DATA_ROOT}/fred_md_{SEQ_LEN}_train.npy'
      nn5: '{DATA_ROOT}/nn5_{SEQ_LEN}_train.npy'
      temp: '{DATA_ROOT}/temp_{SEQ_LEN}_train.npy'
      rain: '{DATA_ROOT}/rain_{SEQ_LEN}_train.npy'
      pedestrian: '{DATA_ROOT}/pedestrian_{SEQ_LEN}_train.npy'
      wind_4_seconds: '{DATA_ROOT}/wind_4_seconds_{SEQ_LEN}_train.npy'
    # Uses the same shared sequence length as the model sections.
    window: *seqlen
    val_portion: 0.1
    batch_size: 256
    num_workers: 8
    normalize: centered_pit
    drop_last: true
    reweight: true

lightning:
  callbacks:
    # Logging callback configured under the `image_logger` key.
    image_logger:
      target: utils.callback_utils.TSLogger
      params:
        batch_frequency: 2000
        max_images: 8
        increase_log_steps: false

  trainer:
    benchmark: true
    max_steps: 50000