# Root of the configuration: the class named by `target` is paired with the
# keyword arguments listed under `params`.
model:
  target: vwm.models.diffusion.DiffusionEngine
  params:
    # Frame count; anchored so that denoiser_config can alias the same value.
    num_frames: &num_frames 25
    input_key: img_seq

    # NOTE(review): presumably the latent scaling constant used with the
    # first-stage autoencoder; confirm against DiffusionEngine.
    scale_factor: 0.18215
    use_ema: true
    disable_first_stage_autocast: true

    # NOTE(review): 14 does not divide num_frames (25) evenly; confirm that the
    # consumer handles a final partial chunk.
    en_and_decode_n_samples_a_time: 14

    # Denoiser settings; the scaling rule is supplied as a nested target.
    denoiser_config:
      target: vwm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: vwm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

        # Alias of the &num_frames anchor declared under model.params.
        num_frames: *num_frames

    # Backbone network settings (VideoUNet), with keys grouped by theme.
    network_config:
      target: vwm.modules.diffusionmodules.video_model.VideoUNet
      params:
        # Channel layout and per-level structure.
        # NOTE(review): in_channels (8) is twice out_channels (4); presumably
        # the input is concatenated with the 4-channel `cond_frames` latent.
        # Confirm against the model code.
        in_channels: 8
        out_channels: 4
        model_channels: 320
        channel_mult: [1, 2, 4, 4]
        num_res_blocks: 2

        # Attention and transformer options.
        attention_resolutions: [4, 2, 1]
        num_head_channels: 64
        transformer_depth: 1
        use_linear_in_transformer: true
        spatial_transformer_attn_type: softmax-xformers
        context_dim: 1024
        use_spatial_context: true

        # NOTE(review): 768 equals 3 x 256, which matches the three embedders
        # configured with outdim 256 and no add_sequence_dim (fps_id,
        # motion_bucket_id, cond_aug). Confirm that this is how the vector
        # conditioning is assembled.
        adm_in_channels: 768
        num_classes: sequential

        # Remaining VideoUNet options.
        extra_ff_mix_layer: true
        merge_strategy: learned_with_images
        video_kernel_size: [3, 1, 1]
        use_checkpoint: false

    # Conditioning stack. Each emb_models entry pairs an input_key with an
    # embedder target; all entries are marked non-trainable. The entries are a
    # YAML sequence, so their order is significant and must be kept.
    conditioner_config:
      target: vwm.modules.GeneralConditioner
      params:
        emb_models:
          # OpenCLIP image embedder (freeze: true) behind a prediction wrapper.
          - target: vwm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
            input_key: cond_frames_without_noise
            is_trainable: false
            params:
              n_cond_frames: 1
              n_copies: 1
              open_clip_embedding_config:
                target: vwm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
                params:
                  freeze: true

          # fps_id and motion_bucket_id, each embedded with outdim 256.
          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: fps_id
            is_trainable: false
            params:
              outdim: 256

          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: motion_bucket_id
            is_trainable: false
            params:
              outdim: 256

          # Embedder for cond_frames, configured with its own KL autoencoder.
          # NOTE(review): this encoder uses attn_type vanilla-xformers while
          # first_stage_config uses vanilla for otherwise identical settings;
          # confirm the difference is intentional.
          - target: vwm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            input_key: cond_frames
            is_trainable: false
            params:
              n_cond_frames: 1
              n_copies: 1
              is_ae: true
              disable_encoder_autocast: true

              encoder_config:
                target: vwm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  embed_dim: 4
                  monitor: val/rec_loss

                  loss_config:
                    target: torch.nn.Identity

                  ddconfig:
                    attn_type: vanilla-xformers
                    double_z: true
                    z_channels: 4
                    resolution: 256
                    in_channels: 3
                    out_ch: 3
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    num_res_blocks: 2
                    attn_resolutions: []
                    dropout: 0.0

          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: cond_aug
            is_trainable: false
            params:
              outdim: 256

          # The five entries below share one outdim through the
          # &action_emb_dim anchor (128) and all set add_sequence_dim. They
          # differ only in input_key and num_features.
          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: command
            is_trainable: false
            params:
              outdim: &action_emb_dim 128
              num_features: 1
              add_sequence_dim: true

          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: trajectory
            is_trainable: false
            params:
              outdim: *action_emb_dim
              num_features: 8
              add_sequence_dim: true

          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: speed
            is_trainable: false
            params:
              outdim: *action_emb_dim
              num_features: 4
              add_sequence_dim: true

          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: angle
            is_trainable: false
            params:
              outdim: *action_emb_dim
              num_features: 4
              add_sequence_dim: true

          - target: vwm.modules.encoders.modules.ConcatTimestepEmbedderND
            input_key: goal
            is_trainable: false
            params:
              outdim: *action_emb_dim
              num_features: 2
              add_sequence_dim: true

    # First-stage autoencoder: an Encoder and a VideoDecoder configured with
    # matching channel settings, an Identity loss, and a diagonal-Gaussian
    # regularizer target.
    first_stage_config:
      target: vwm.models.autoencoder.AutoencodingEngine
      params:
        regularizer_config:
          target: vwm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

        loss_config:
          target: torch.nn.Identity

        # NOTE(review): attn_type is vanilla here, whereas the encoder nested
        # in conditioner_config uses vanilla-xformers; confirm the difference
        # is intentional.
        encoder_config:
          target: vwm.modules.diffusionmodules.model.Encoder
          params:
            attn_type: vanilla
            attn_resolutions: []
            in_channels: 3
            out_ch: 3
            resolution: 256
            z_channels: 4
            double_z: true
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            dropout: 0.0

        # Same settings as the encoder, plus video_kernel_size.
        # NOTE(review): double_z is also passed to the decoder; whether
        # VideoDecoder reads it is not visible from this file.
        decoder_config:
          target: vwm.modules.autoencoding.temporal_ae.VideoDecoder
          params:
            attn_type: vanilla
            attn_resolutions: []
            in_channels: 3
            out_ch: 3
            resolution: 256
            z_channels: 4
            double_z: true
            ch: 128
            ch_mult: [1, 2, 4, 4]
            num_res_blocks: 2
            dropout: 0.0
            video_kernel_size: [3, 1, 1]