# Run identity.
# `name` is interpolated into `output_dir` below and into the wandb logger's
# `project`; `version` is interpolated into the wandb logger's run `name`.
name: mistral-video
tags: ["svd"]
description: ""
version: 'webvlaion+latte' # if not specified, will be set to version_{index}
# Root directory for this experiment's outputs; also used as the trainer's
# `default_root_dir`.
output_dir: "outputs/${name}"

# Shared values. Other sections pull these in through `${extras.*}`
# interpolation so the data and system settings stay in sync.
extras:
  # Frame counts; referenced by the datasets, `system`, and `system.dit`.
  num_video_frames: 16
  num_image_frames: 8
  # NOTE(review): neither path below is referenced by an interpolation visible
  # in this file, and `system.text_model_path` points at a different
  # Llama-2-7b-hf location — confirm which one the code actually reads.
  tokenizer_path: /mnt/petrelfs/wenhao1/LLaMA2-Accessory/Large-DiT-T2I/pretrain/Llama-2-7b-hf
  model_path: /mnt/petrelfs/wenhao1/LLaMA2-Accessory/Large-DiT-T2I/pretrain/Llama-2-7b-hf
  # Currently trainer.num_nodes (2) * trainer.devices (8); update together.
  num_replicas: 16 # !important same as nodes * devices
  # Passed to the training dataset as `frame_interval`.
  frame_interval: 6
  # Passed as `size` to ScaleResize16xVideo and to the dummy val/test datasets.
  size: 256

# Random seed. NOTE(review): presumably applied by the training entry point —
# confirm where it is consumed.
seed: 42
# Disabled: example checkpoint path from an earlier run. NOTE(review):
# presumably uncommenting this resumes training from that checkpoint — confirm.
# resume: outputs/mistral-video/xrotcaaz/checkpoints/epoch=0-step=1.ckpt
# Data module configuration. Every `_target_` is a dotted import path to the
# class being configured.
data:
  _target_: datasets.datamodule.VideoDataModule

  # Training set: configured with both a video path list and an image path list.
  train_dataset:
    _target_: datasets.webvideolaion_datasets_v2.WebVideoLaion
    # !important
    # NOTE(review): presumably the remaining constructor arguments are supplied
    # at runtime, which is why this must stay partial — confirm.
    _partial_: true
    video_ceph_dir: webvideo
    num_video_frames: ${extras.num_video_frames}
    num_image_frames: ${extras.num_image_frames}
    frame_interval: ${extras.frame_interval}
    video_data_path_list:
      # WebVid videos
      - /mnt/petrelfs/share_data/maxin/datasets/webvid/aes_4plus_sub_files_128/
    image_data_path_list:
      # LAION-5B images
      - /mnt/petrelfs/share_data/maxin/datasets/laion5b/sub_files_128/
    num_replicas: ${extras.num_replicas}
    temporal_sample:
      _target_: datasets.video_transforms.TemporalRandomCrop
      _partial_: true
    # Transform pipeline; the entries are listed in the order given to Compose.
    video_transform:
      _target_: torchvision.transforms.Compose
      transforms:
        - _target_: datasets.video_transforms.ToTensorVideo
        - _target_: datasets.video_transforms.ScaleResize16xVideo
          size: ${extras.size}
        - _target_: torchvision.transforms.Normalize
          mean: [0.5, 0.5, 0.5]
          std: [0.5, 0.5, 0.5]
          inplace: true
  train_batch_size: 1

  # Validation set: DummyVideoDataset configured with a fixed prompt list.
  val_dataset:
    _target_: datasets.dummy.DummyVideoDataset
    num_video_frames: ${extras.num_video_frames}
    # Only the uncommented prompts are active; the others are kept for reference.
    prompt_list:
      - 'A man talking in the classroom.'
      - 'A woman walking in the garden.'
      # - 'A man playing with guitar.'
      # - 'A Man talking to a woman'
      # - 'A man playing the basketball.'
      # - 'A man is dancing,'
      # - 'A woman is dancing'
      # - 'The man and the woman are dancing'
      # - 'A child is eating'
      # - 'A child is sleeping'
      # - 'A man is working out'
      # - 'A woman working out'
      # - 'A man is swimming'
      # - 'A woman is swimming'
      # - 'A child is doing his homework'
      # - 'A man at work'
    size: ${extras.size}
  val_batch_size: 1
  # Disabled alternative: validation on the PixabayLaion dataset instead of
  # the dummy prompt dataset above.
  # val_dataset:
  #   _target_: datasets.pixabaylaion_datasets.PixabayLaion
  #   _partial_: true # !important
  #   num_video_frames: ${extras.num_video_frames}
  #   num_image_frames: ${extras.num_image_frames}
  #   frame_interval: ${extras.frame_interval}
  #   num_samples: 1
  #   video_data_path_list:
  #     - /mnt/petrelfs/share_data/maxin/datasets/pixabay/aes_4plus_sub_files_128/ # pixabay videos
  #   image_data_path_list:
  #     - /mnt/petrelfs/share_data/maxin/datasets/pixabay_imgs/aes_4_sub_files_128/ # pixabay images
  #   num_replicas: ${extras.num_replicas}
  #   temporal_sample:
  #     _target_: datasets.video_transforms.TemporalRandomCrop
  #     _partial_: true
  #   video_transform:
  #     _target_: torchvision.transforms.Compose
  #     transforms:
  #       - _target_: datasets.video_transforms.ToTensorVideo
  #       - _target_: datasets.video_transforms.ScaleResize16xVideo
  #         size: ${extras.size}
  #       - _target_: torchvision.transforms.Normalize
  #         mean: [0.5, 0.5, 0.5]
  #         std: [0.5, 0.5, 0.5]
  #         inplace: True
  # val_batch_size: 1

  # Test set: same dummy dataset class, with an empty prompt list in this file.
  test_dataset:
    _target_: datasets.dummy.DummyVideoDataset
    num_video_frames: ${extras.num_video_frames}
    # NOTE(review): the original note reads "prompt.txt" — presumably prompts
    # are filled in from that file at test time; confirm.
    prompt_list: []  # prompt.txt
    size: ${extras.size}
  test_batch_size: 1

  # Loader settings given to the data module.
  num_workers: 32
  pin_memory: true

# Training system: the DiT backbone, the noise scheduler, and the settings
# passed to SVDSystem.
system:
  _target_: systems.svd_system.SVDSystem
  lr: 1.0e-5
  use_ema: true
  num_video_frames: ${extras.num_video_frames}
  num_image_frames: ${extras.num_image_frames}
  # NOTE(review): differs from extras.tokenizer_path / extras.model_path, which
  # point at another Llama-2-7b-hf location — confirm this is intentional.
  text_model_path: /mnt/petrelfs/share_data/maxin/work/pretrained/Llama-2-7b-hf
  # Backbone. Frame limits follow the shared values under `extras`.
  dit:
    _target_: models.mistral.model.Latte_XL_2
    positional_embeddings: rope2d
    max_hw_len: 256
    max_video_frame: ${extras.num_video_frames}
    max_image_frame: ${extras.num_image_frames}
    # Disabled: alternative normalisation layer.
    # norm_layer:
    #   _target_: models.mistral.model.AdaRMSNorm
    #   _partial_: true
  # Disabled: EulerDiscreteScheduler alternative to the active scheduler below.
  # scheduler:
  #   _target_: diffusers.EulerDiscreteScheduler
  #   beta_start: 0.0001
  #   beta_end: 0.02
  #   beta_schedule: linear
  #   timestep_spacing: leading
  #   num_train_timesteps: 1000
  #   prediction_type: epsilon
  #   interpolation_type: linear
  #   clip_sample: false
  #   sample_max_value: 1.0
  #   set_alpha_to_one: false
  #   skip_prk_steps: true
  #   use_karras_sigmas: false
  #   trained_betas: null  # come from sd xl
  # Active scheduler: DDPM with a linear beta schedule (DDIM left as a
  # commented-out alternative target).
  scheduler:
    # _target_: diffusers.DDIMScheduler
    _target_: diffusers.DDPMScheduler
    beta_start: 0.0001
    beta_end: 0.02
    beta_schedule: linear

  base_model_id: stabilityai/stable-video-diffusion-img2vid
  # NOTE(review): `variant` is fp16 while trainer.precision is bf16-true —
  # presumably this only selects which pretrained weight files to load; confirm.
  variant: fp16
  # NOTE(review): the meaning of `cfg` is not visible in this file (guidance
  # scale vs. condition-dropout probability) — confirm in SVDSystem.
  cfg: 0.1
  num_inference_steps: 1000
  


# Lightning Trainer arguments.
trainer:
  _target_: lightning.Trainer
  # !important: the project uses a custom sampler, so the Trainer's own
  # distributed sampler is switched off.
  use_distributed_sampler: false
  default_root_dir: ${output_dir}
  max_steps: 1000000
  # Validation cadence (earlier alternatives kept for reference).
  # check_val_every_n_epoch: 5
  # val_check_interval: 8
  val_check_interval: 20000
  log_every_n_steps: 10
  num_sanity_val_steps: 1
  enable_progress_bar: true
  accumulate_grad_batches: 8
  strategy: ddp_find_unused_parameters_true
  # strategy: deepspeed_stage_1
  accelerator: gpu
  # Topology: num_nodes * devices = 16, which must match extras.num_replicas.
  # Single-node debug alternative:
  # devices: 2
  # num_nodes: 1
  num_nodes: 2
  devices: 8
  # precision: 16-mixed # mixed precision for extra speed-up
  precision: bf16-true
  gradient_clip_val: 1.0

# Trainer callbacks, keyed by an arbitrary label.
callbacks:
  # Periodic checkpointing. Per Lightning's ModelCheckpoint documentation,
  # save_top_k: -1 keeps every checkpoint instead of pruning to the best k.
  model_checkpoint:
    _target_: lightning.pytorch.callbacks.ModelCheckpoint
    save_top_k: -1
    # every_n_train_steps: 1
    every_n_train_steps: 2500
    # Disabled: write checkpoints to an s3-style path instead of the default
    # location.
    # dirpath: wenhao:s3://wenhao/${name}/${version}/checkpoints
  # Project-specific callback. NOTE(review): its behaviour is not visible in
  # this file — see callbacks.DataCallback.
  data_load:
    _target_: callbacks.DataCallback


# Experiment loggers. Only wandb is active; the TensorBoard logger is kept
# commented out for reference.
logger:
  # tensorboard:
  #   _target_: lightning.pytorch.loggers.tensorboard.TensorBoardLogger
  #   save_dir: "${output_dir}"
  #   name: ""
  #   version: "${version}"
  # #   sub_dir: "tb_logs"
  wandb:
    _target_: lightning.pytorch.loggers.WandbLogger
    # The wandb project is the experiment `name`; the run name is `version`.
    project: "${name}"
    # NOTE(review): this is the literal "outputs", not "${output_dir}" as in
    # the disabled TensorBoard logger. The example `resume` path near the top of
    # the file (outputs/mistral-video/<run id>/checkpoints) suggests this is
    # intentional — confirm before changing.
    save_dir: "outputs"
    name: "${version}"
