# Dataset settings; these values are read by code outside this file.
data:
  # Left null here. NOTE(review): presumably the dataset root directory is
  # supplied at launch time — confirm against the data loader.
  root: null
  # NOTE(review): presumably the number of frames per sample and how many of
  # them are given as input/conditioning — confirm against the dataset code.
  total_length: 15
  input_length: 4
  train:
    filename: "train.npz.npy"
  test:
    filename: "test.npz.npy"
    # NOTE(review): presumably how many test samples are drawn — confirm.
    n_subsample: 20

# Model definition. Each `target` is a dotted Python path and `params` holds
# the settings passed to it. NOTE(review): assumes the usual target/params
# instantiation helper — confirm against the training script.
model:
  target: dydiff.dydiff_turbulence.DynamicalLDMForTurbulenceWithEncoderCondition
  params:
    image_size: 64

    # Noise (beta) schedule.
    linear_start: 0.00085
    linear_end: 0.012
    beta_schedule: "cosine"

    # Gamma schedule.
    linear_start_gamma: 0.0
    linear_end_gamma: 0.0
    gamma_schedule: "cosine-0.5"

    log_every_t: 100
    timesteps: 1000

    # Batch keys.
    first_stage_key: "image"
    first_stage_key_prev: "prev"

    scale_factor: 0.18215
    shift_factor: 0.0
    # scale_by_std: True

    # Align with input_length.
    # NOTE(review): this is the literal string "none", not YAML null;
    # presumably replaced at runtime — confirm.
    channels: "none"
    monitor: "val/loss_simple_ema"
    use_ema: false
    # Number of steps for which noise is added during evaluation.
    num_timesteps_cond: 1
    # The "v" parameterization is not supported.
    parameterization: "eps"
    x_channels: 2
    z_channels: 3
    new_prev_ema: true
    # frame_weighting: True
    use_x_ema: true

    # Conditioning (concat); the conditioning stage is an identity module.
    cond_stage_key: "cond"
    conditioning_key: "concat-video-mask-1st"
    cond_stage_config:
      target: torch.nn.Identity

    unconditional_guidance_scale: 1.0
    visualize_intermediates: true
    # rollout: 20

    # First-stage autoencoder.
    first_stage_config:
      # NOTE(review): `model` carries no value and sits beside `target` and
      # `params`; it looks like a leftover from a standalone autoencoder
      # config — confirm it is unused before removing it.
      model: null
      # base_learning_rate: 4.5e-6
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        ckpt_path: null
        monitor: "val/rec_loss"
        embed_dim: 3
        # The loss module is replaced by an identity.
        lossconfig:
          target: torch.nn.Identity

        ddconfig:
          double_z: true
          z_channels: 3
          resolution: 64
          in_channels: 2
          out_ch: 2
          ch: 128
          # num_down = len(ch_mult) - 1
          ch_mult: [1, 2, 4]
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0

    # Denoising backbone.
    unet_config:
      target: dit.video_dit.DiT_S_4
      params:
        # NOTE(review): the zeros below look like placeholders that the code
        # fills in at runtime — confirm.
        num_video_frames: 0
        # NOTE(review): presumably the latent resolution (64 downsampled by
        # the two stages implied by ch_mult) — confirm.
        input_size: 16
        in_channels: 0
        out_channels: 0

    # Sampler settings used during validation.
    validate_kwargs:
      ddim: true
      ddim_steps: 50
      ddim_eta: 0.0

# Training loop, optimiser and logging settings.
training:
  max_iterations: 1000005
  # NOTE(review): presumably attributes assigned onto the model object before
  # training — confirm against the training script.
  model_attrs:
    # Written with an explicit mantissa dot: YAML 1.1 loaders such as PyYAML
    # resolve a bare `1e-4` as the string "1e-4" rather than a float, whereas
    # `1.0e-4` is a float under both YAML 1.1 and YAML 1.2.
    learning_rate: 1.0e-4
  batch_size: 16
  num_workers: 16
  validation_freq: null
  accumulate_grad_batches: 1
  logger:
    # Relative path. NOTE(review): presumably resolved against the working
    # directory at launch — confirm.
    save_dir: ../logs/turbulence
    logger_freq: 5000
    checkpoint_freq: 10000

# Evaluation settings.
eval:
  # NOTE(review): presumably the number of samples to visualise — confirm.
  num_vis: 5