# Diffusion model definition: engine class, optimisation rate and the
# engine-level switches. The sub-configs nested under `params` follow below.
model:
  # Same value as before (5e-5), written with an explicit fractional digit.
  base_learning_rate: 5.0e-5
  target: vwm.models.diffusion.DiffusionEngine
  params:
    use_ema: true
    input_key: img_seq
    scale_factor: 0.18215
    disable_first_stage_autocast: true
    en_and_decode_n_samples_a_time: 1
    # Anchor: aliased as *num_frames by the denoiser, loss, sigma sampler,
    # guider, data and image-logger sections of this file.
    num_frames: &num_frames 25
    slow_spatial_layers: false
    # NOTE(review): network_config in this file sets no adapter-related
    # option; confirm the adapters this flag refers to exist by default.
    train_peft_adapters: true
    # Anchor: aliased as *replace_cond_frames by loss_fn_config.
    replace_cond_frames: &replace_cond_frames true
    # only used for logging images
    fixed_cond_frames:
      - [0]

    # Denoiser settings: the class to build and the scaling class it uses.
    denoiser_config:
      target: vwm.modules.diffusionmodules.denoiser.Denoiser
      params:
        # Alias of model.params.num_frames (25), kept in sync via the anchor.
        num_frames: *num_frames

        # No params are given for the scaling class; it is configured by
        # `target` alone.
        scaling_config:
          target: vwm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    # Hyperparameters handed to the VideoUNet class named in `target`.
    network_config:
      target: vwm.modules.diffusionmodules.video_model.VideoUNet
      params:
        # NOTE(review): presumably 3 x 256 from the fps_id, motion_bucket_id
        # and cond_aug embedders (outdim 256 each) - verify against the model.
        adm_in_channels: 768
        num_classes: sequential
        use_checkpoint: true
        # NOTE(review): in_channels is twice out_channels / z_channels (4);
        # presumably latent + concatenated conditioning - TODO confirm.
        in_channels: 8
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        use_linear_in_transformer: true
        transformer_depth: 1
        context_dim: 1024
        spatial_transformer_attn_type: softmax-xformers
        extra_ff_mix_layer: true
        use_spatial_context: true
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]

    # Conditioning embedders. `emb_models` is an ordered list with one entry
    # per batch key (`input_key`). Every entry is marked non-trainable;
    # `ucg_rate` is 0.15 for the image and action inputs and 0.0 for the
    # fps_id / motion_bucket_id / cond_aug scalars.
    conditioner_config:
      target: vwm.modules.GeneralConditioner
      params:
        emb_models:
          # --- image conditioning -------------------------------------------
          - input_key: cond_frames_without_noise
            is_trainable: false
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
            params:
              n_cond_frames: 1
              n_copies: 1
              open_clip_embedding_config:
                target: vwm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
                params:
                  freeze: true

          # --- scalar conditioning (256-dim each) ---------------------------
          - input_key: fps_id
            is_trainable: false
            ucg_rate: 0.0
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          - input_key: motion_bucket_id
            is_trainable: false
            ucg_rate: 0.0
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          # --- conditioning frames encoded by an autoencoder ----------------
          - input_key: cond_frames
            is_trainable: false
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            params:
              disable_encoder_autocast: true
              n_cond_frames: 1
              n_copies: 1
              is_ae: true

              encoder_config:
                target: vwm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  embed_dim: 4
                  monitor: val/rec_loss

                  ddconfig:
                    attn_type: vanilla-xformers
                    double_z: true
                    z_channels: 4
                    resolution: 256
                    in_channels: 3
                    out_ch: 3
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    num_res_blocks: 2
                    attn_resolutions: []
                    dropout: 0.0

                  loss_config:
                    target: torch.nn.Identity

          - input_key: cond_aug
            is_trainable: false
            ucg_rate: 0.0
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 256

          # --- action conditioning ------------------------------------------
          # All five entries below share one embedding width through the
          # &action_emb_dim anchor and differ only in `num_features`.
          - input_key: command
            is_trainable: false
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              # Anchor definition; the remaining action embedders alias it.
              outdim: &action_emb_dim 128
              num_features: 1
              add_sequence_dim: true

          - input_key: trajectory
            is_trainable: false
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 8
              add_sequence_dim: true

          - input_key: speed
            is_trainable: false
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 4
              add_sequence_dim: true

          - input_key: angle
            is_trainable: false
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 4
              add_sequence_dim: true

          - input_key: goal
            is_trainable: false
            ucg_rate: 0.15
            target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: *action_emb_dim
              num_features: 2
              add_sequence_dim: true

    # First-stage autoencoder: an image Encoder paired with a VideoDecoder.
    # Both halves carry the same architecture settings; the decoder adds
    # `video_kernel_size`.
    first_stage_config:
      target: vwm.models.autoencoder.AutoencodingEngine
      params:
        # torch.nn.Identity stands in for the loss module here.
        loss_config:
          target: torch.nn.Identity

        regularizer_config:
          target: vwm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

        encoder_config:
          target: vwm.modules.diffusionmodules.model.Encoder
          params:
            # NOTE(review): the cond_frames encoder in conditioner_config uses
            # `vanilla-xformers`, this one `vanilla` - confirm that is intended.
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0

        decoder_config:
          target: vwm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            double_z: true
            z_channels: 4
            resolution: 256
            in_channels: 3
            out_ch: 3
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            attn_resolutions: []
            dropout: 0.0
            video_kernel_size: [3, 1, 1]

    # Learning-rate schedule. Every list holds a single entry, and f_max
    # equals f_min (both 1.0), with f_start at 1e-6 over 1000 warm-up steps.
    # NOTE(review): presumably one list entry per cycle - confirm in
    # LambdaLinearScheduler.
    scheduler_config:
      target: vwm.lr_scheduler.LambdaLinearScheduler
      params:
        warm_up_steps: [1000]
        cycle_lengths: [10000000000000]
        f_start: [1.0e-6]
        f_max: [1.0]
        f_min: [1.0]

    # Training loss: StandardDiffusionLoss with an extra loss term enabled,
    # EDM sigma sampling and V weighting.
    loss_fn_config:
      target: vwm.modules.diffusionmodules.loss.StandardDiffusionLoss
      params:
        use_additional_loss: true
        offset_noise_level: 0.02
        additional_loss_weight: 0.1
        # Both values below are aliases of the anchors set in model.params.
        num_frames: *num_frames
        replace_cond_frames: *replace_cond_frames
        # Candidate sets of conditioning-frame indices, from none up to the
        # first three frames.
        cond_frames_choices:
          - []
          - [0]
          - [0, 1]
          - [0, 1, 2]

        sigma_sampler_config:
          target: vwm.modules.diffusionmodules.sigma_sampling.EDMSampling
          params:
            p_mean: 1.0
            p_std: 1.6
            num_frames: *num_frames

        loss_weighting_config:
          target: vwm.modules.diffusionmodules.loss_weighting.VWeighting

    # Sampler settings: a 15-step EulerEDMSampler with its own
    # discretization and guider sub-configs.
    sampler_config:
      target: vwm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 15

        # Only sigma_max is set here; any other discretization option is left
        # to the class defaults.
        discretization_config:
          target: vwm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 700.0

        # Guidance scale bounds; num_frames is the alias of
        # model.params.num_frames.
        guider_config:
          target: vwm.modules.diffusionmodules.guiders.LinearPredictionGuider
          params:
            num_frames: *num_frames
            max_scale: 3.0
            min_scale: 1.5

# Data pipeline: the Sampler class draws from the listed subsets.
data:
  target: vwm.data.dataset.Sampler
  params:
    batch_size: 1
    num_workers: 16
    # `subsets` and `probs` are parallel lists (one weight per subset); both
    # weights are equal here.
    subsets: [YouTube, NuScenes]
    probs: [1, 1]
    samples_per_epoch: 256000
    target_height: 576
    target_width: 1024
    # Alias of model.params.num_frames so clips match the model's length.
    num_frames: *num_frames

# Lightning runtime settings: logging callback, checkpointing and trainer.
lightning:
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        num_frames: *num_frames
        disabled: false
        enable_autocast: false
        batch_frequency: 1000
        increase_log_steps: true
        log_first_step: false
        log_images_kwargs:
          # Quoted on purpose: a bare N is a boolean under the YAML 1.1 type
          # rules, so some loaders would turn this key into `false`.
          'N': *num_frames

  modelcheckpoint:
    params:
      # Saves once per epoch. Step-based alternative: `every_n_train_steps`
      # (an earlier note suggested 5000).
      # NOTE(review): that note also said to match
      # image_logger.batch_frequency, which is 1000 here - confirm which
      # value is intended before switching.
      every_n_epochs: 1

  trainer:
    # Quoted so the comma-separated device ids always load as the string
    # "0,1" regardless of how a loader treats commas in plain scalars.
    devices: '0,1'
    benchmark: true
    num_sanity_val_steps: 0
    accumulate_grad_batches: 1
    max_epochs: 100
    strategy: deepspeed_stage_2
    gradient_clip_val: 0.3