# Model section. `target` names the class to build; `params` are presumably handed
# to its constructor (target/params convention) - confirm against the training script.
model:
  # Same value as lr_scheduler.params.max_lr below.
  # Alternative value left by the author: 1.5e-04
  base_learning_rate: 2.0e-05
  scale_lr: false
  # Two separate checkpoints are configured: one for a reference model and one for
  # the model being trained (presumably frozen DPO reference vs. policy - TODO confirm).
  ref_model_checkpoint: checkpoints/t2v-turbo/t2v-turbo_ref_model.ckpt
  pretrained_checkpoint: checkpoints/t2v-turbo/t2v-turbo_model.ckpt
  target: lvdm.models.ddpm3d.T2VTurboDPO
  lr_scheduler:
    target: torch.optim.lr_scheduler.OneCycleLR
    params:
      # Peak learning rate of the cycle; kept equal to base_learning_rate above.
      max_lr: 2.0e-5
      # Fraction of the cycle spent increasing the learning rate (PyTorch docs).
      pct_start: 0.2
      anneal_strategy: cos
      # Per the PyTorch docs: min_lr = initial_lr / final_div_factor.
      final_div_factor: 10
      # NOTE(review): OneCycleLR also requires total_steps (or epochs and
      # steps_per_epoch); none is set here, so the training code presumably
      # supplies it - confirm.
  # Scalar hyper-parameters for the model class named in `model.target`.
  params:
    pretrained_unet_path: checkpoints/t2v-turbo/unet_lora.pt
    loss_type: dpo
    log_every_t: 200
    # Start/end of a linear schedule spanning `timesteps` steps (presumably the
    # diffusion beta schedule - confirm in the model class).
    linear_start: 0.00085
    linear_end: 0.012
    num_timesteps_cond: 1
    timesteps: 1000
    # Presumably the batch-dict keys for the video tensor and its caption - TODO confirm.
    first_stage_key: video
    cond_stage_key: caption
    cond_stage_trainable: false
    conditioning_key: crossattn
    # [40, 64] is exactly 1/8 of data.params.train.params.resolution ([320, 512])
    # in each dimension; keep the two in sync when changing either.
    image_size: [40, 64]
    # Same value as unet_config in_channels/out_channels and the first-stage embed_dim.
    channels: 4
    scale_by_std: false
    scale_factor: 0.18215
    use_ema: false
    uncond_type: empty_seq
    monitor: train/loss_simple
    encoder_type: 2d
    use_scale: true
    scale_b: 0.7
    # Denoising network configuration (class given by `target`).
    unet_config:
      target: lvdm.modules.networks.openaimodel3d.UNetModel
      params:
        time_cond_proj_dim: 256
        # in/out channel counts equal params.channels (4).
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [4, 2, 1]
        num_res_blocks: 2
        # Four entries; presumably per-level multipliers of model_channels - confirm.
        channel_mult: [1, 2, 4, 4]
        num_head_channels: 64
        transformer_depth: 1
        context_dim: 1024
        use_linear: true
        use_checkpoint: true
        # Temporal-layer switches.
        temporal_conv: true
        temporal_attention: true
        temporal_selfatt_only: true
        use_relative_position: false
        use_causal_attention: false
        # NOTE(review): the data section sets video_length: 2 while this is 16;
        # confirm the two are meant to differ.
        temporal_length: 16
        addition_attention: true
        fps_cond: true
    # First-stage autoencoder configuration (class given by `target`).
    first_stage_config:
      target: lvdm.models.autoencoder.AutoencoderKL
      params:
        # Same value as ddconfig.z_channels and params.channels (4).
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 512
          # 3 input / 3 output channels (presumably RGB - confirm).
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult: [1, 2, 4, 4]
          num_res_blocks: 2
          # Explicit empty list (a bare `attn_resolutions:` would be null instead).
          attn_resolutions: []
          dropout: 0.0
        # torch.nn.Identity as the loss module; presumably a placeholder because
        # the autoencoder is not trained in this run - TODO confirm.
        lossconfig:
          target: torch.nn.Identity
    # Conditioning (text) encoder configuration; class given by `target`.
    # Consistent with cond_stage_trainable: false in the enclosing params.
    cond_stage_config:
      target: lvdm.modules.encoders.condition.FrozenOpenCLIPEmbedder
      params:
        freeze: true
        # Presumably selects which encoder layer's output is used - confirm in the
        # embedder class.
        layer: penultimate
        
# Data section: the data module class named by `target`, configured via `params`.
data:
  target: data.lightning_data.DataModuleFromConfig
  params:
    # With lightning.trainer.accumulate_grad_batches: 4 this gives 16 * 4 = 64
    # samples per optimizer step (assuming batch_size is per process - confirm).
    batch_size: 16
    num_workers: 16
    wrap: false
    # Training dataset; only a `train` split is configured in this file.
    train:
      target: data.video_data.TextVideoDPO
      params:
        # Note: data_root points at a YAML file, not a directory.
        data_root: configs/t2v_turbo_dpo/vidpro/train_data.yaml
        # 8x model.params.image_size ([40, 64]) in each dimension; keep in sync.
        resolution: [320, 512]
        # NOTE(review): unet_config sets temporal_length: 16 while this is 2;
        # confirm the two are meant to differ.
        video_length: 2
        subset_split: all

# Lightning section: trainer options and checkpoint callbacks.
lightning:
  trainer:
    benchmark: true
    log_every_n_steps: 10
    # NOTE(review): num_workers is not a documented pytorch_lightning Trainer
    # argument (data.params.num_workers already sets 16); kept because the
    # project's trainer script may read it - confirm before removing.
    num_workers: 16
    num_nodes: 1
    # Combined with data.params.batch_size: 16 -> 16 * 4 = 64 samples per
    # optimizer step (assuming batch_size is per process - confirm).
    accumulate_grad_batches: 4
    max_epochs: 5
  # Two ModelCheckpoint callbacks: one per epoch, one every 25 training steps.
  # NOTE(review): neither sets save_top_k or monitor; under the ModelCheckpoint
  # defaults older files from the same callback may be replaced by newer ones -
  # confirm whether the trainer script overrides this.
  callbacks:
    model_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        every_n_epochs: 1
        # Quoted because the value starts with `{`, which YAML would otherwise
        # read as a flow mapping. Placeholders are zero-padded (4 and 6 digits).
        filename: '{epoch:04}-{step:06}'
        save_weights_only: true
    metrics_over_trainsteps_checkpoint:
      target: pytorch_lightning.callbacks.ModelCheckpoint
      params:
        # Wider zero padding (6 and 9 digits) than the per-epoch callback above.
        filename: '{epoch:06}-{step:09}'
        save_weights_only: true
        every_n_train_steps: 25
    