# Model definition: the top-level flow-matching model plus the diffusion backbone,
# text encoder and VAE configs nested under it. Each `class_name` is a dotted Python path.
model:
  # NOTE(review): this path has no `minVid.` prefix, unlike every other class_name
  # in this file -- confirm it resolves from the training entry point.
  class_name: models.video_latent_flow_matching_ar.VideoLatentFlowMatching
  # Boolean here; the alternative noted by the author is a checkpoint directory:
  # /mnt/localssd/minVid/ckpt/wan/Wan2.1-T2V-1.3B/
  generator_ckpt: false
  drop_text_prob: 0.0
  # Three latent frames per window.
  ar_window_size: 3
  adjust_timestep_shift: true
  # Other value noted by the author: 2 (update_every below must change with it).
  num_repeat: 1
  frame_independent_noise: true
  logit_normal_weighting: true
  diffusion_config:
    class_name: minVid.models.wan.wan_warpper_versatile.WanDiffusionWrapper
    generator_ckpt: /mnt/localssd/minVid/ckpt/wan/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors
    # NOTE(review): same value (3.0) as model.timestep_shift and
    # model.sample_timestep_shift -- presumably meant to stay in sync; confirm.
    shift: 3.0
    extra_one_step: true
    model_config:
      class_name: minVid.models.wan.wan_base.modules.wan_model_warpper.WanModel
      model_type: t2v
      patch_size:
        - 1
        - 2
        - 2
      in_dim: 16
      dim: 1536
      ffn_dim: 8960
      freq_dim: 256
      num_heads: 12
      num_layers: 30
      out_dim: 16
      # Written with an explicit mantissa dot so YAML 1.1 loaders (e.g. PyYAML)
      # resolve a float instead of the string "1e-06".
      eps: 1.0e-06
      text_len: 512
      # Author's note: "full global attention".
      # NOTE(review): 100 exceeds num_layers (30) -- confirm how the consumer
      # interprets this value.
      attn_every_n_layers: 100
      efficient_attn_config:
        - class_name: minVid.models.blocks.ar_lact_swa_repeat.ARFastWeightSwiGLU
          # dim and num_heads repeat the values from model_config above.
          dim: 1536
          num_heads: 12
          qk_norm: true
          o_norm: true
          # Size (in tokens) of the window attention. 4680 = 3 * 1560, i.e.
          # presumably ar_window_f frames of 1560 tokens each -- TODO confirm.
          local_window_size: 4680
          # 9360 = 2 * 4680. Author's note: use 14040 (= 3 * 4680) if repeat=2.
          update_every: 9360
          qk_l2_norm: true
          qkv_silu: false
          w_init: clean
          inter_multi: 2
          lr_dim: 1
          fw_head_dim: 768
          # NOTE(review): "moun" is presumably "Muon"; the key spelling is left
          # untouched because it must match what the consumer reads.
          use_moun: false
          num_moun_iters: 5
          weight_norm: true
          ttt_scale: 1.0
          # Other value noted by the author: false.
          learnable_ttt_scale: true
          # Number of latent frames per window; used for the RoPE implementation.
          # Same value as model.ar_window_size.
          ar_window_f: 3
          # Must be filled in: used for reshaping the input tokens.
          # NOTE(review): presumably has to equal batch_size_per_gpu -- confirm.
          batch_size: 1
          # Total number of latent frames; used for the RoPE implementation.
          n_latent_f: 21
  text_encoder_config:
    class_name: minVid.models.wan.wan_text_vae_warpper.WanTextEncoder
    model_path: /mnt/localssd/minVid/ckpt/wan/Wan2.1-T2V-1.3B/models_t5_umt5-xxl-enc-bf16.pth
  vae_config:
    class_name: minVid.models.wan.wan_text_vae_warpper.WanVAEWrapper
    model_path: /mnt/localssd/minVid/ckpt/wan/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth
  timestep_shift: 3.0
  timestep_sample_method: uniform
  denoising_loss_type: flow
  # Not used for continuous flow matching.
  num_train_timestep: 1000
  mixed_precision: true
  sample_timestep_shift: 3.0

# ---- Experiment identity and output location ----
exp_name: debug
output_path: /mnt/localssd/lact_ar_video_exp

# ---- Reproducibility ----
seed: 1
deterministic: false

# ---- Distributed training ----
mixed_precision: true
sharding_strategy: hybrid_full
text_encoder_fsdp_wrap_strategy: size
text_encoder_fsdp_transformer_module: 'transformer'

# ---- Optimizer and schedule ----
# Other learning rate noted by the author: 5e-6.
lr: 0.00002
beta1: 0.9
beta2: 0.95
# Other weight decay noted by the author: 0.0.
weight_decay: 0.01
lr_scheduler_type: cosine
max_fwdbwd_passes: 10000
warmup: 1000
# 1 means no gradient accumulation.
grad_accum_steps: 1
grad_clip_norm: 1.0

# ---- Logging and checkpointing ----
# Replace the 'xxx' placeholders with your own wandb host, entity and project.
wandb_host: 'xxx'
# YAML file that should contain your API key.
api_key_path: './api_keys.yaml'
wandb_entity: 'xxx'
wandb_project: 'xxx'
wandb_log_every: 10
save_every: 1000

keep_last_iter: 2000

# ---- Per-device batch size ----
batch_size_per_gpu: 1

# Training-loop options: which weights are trained, garbage collection, EMA and FSDP wrapping.
train:
  # true: only the attention blocks are trained; false: the whole model is trained.
  attn_only: true
  first_stage: false
  train_disable_auto_gc: true
  train_manual_gc_interval: 200
  use_ema: true
  ema_weight: 0.995
  # Learning-rate multiplier applied to the new weights added by the
  # test-time-training part.
  lr_multiplier: 10.0
  continue_training: false
  # NOTE(review): the top-level sharding_strategy is `hybrid_full` while this is
  # `full` -- confirm which of the two the trainer actually reads.
  fsdp_strategy: full
  fsdp_wrap_policy: module
  # Dotted class paths listed for the `module` wrap policy.
  fsdp_modules:
    - minVid.models.wan.wan_base.modules.wan_model_warpper.WanAttentionBlock

# Resolution [480, 832]; 5 s at 16 fps => 81 frames.
# Resulting video shape: [81, 3, 480, 832].
dataset_train:
  target: gpu_pin_hybrid_loader
  params:
