# Base config this file builds on (mmengine-style `_base_` key; presumably the
# config loader merges default_runtime.py underneath the settings below — confirm).
_base_ = "../../_base_/default_runtime.py"

# Training strategy: FSDP wrapping options plus how checkpoints are gathered.
strategy = {
    "type": "FSDPStrategy",
    "model_wrapper": {
        # Size-based auto-wrap policy with a 1e8-parameter threshold.
        "auto_wrap_policy": {
            "type": "size_based_auto_wrap_policy",
            "min_num_params": 1e8,
        },
        "use_orig_params": True,
        # bfloat16 is used for parameters, buffers and gradient reduction alike.
        "mixed_precision": {
            "param_dtype": "bfloat16",
            "buffer_dtype": "bfloat16",
            "reduce_dtype": "bfloat16",
            "cast_forward_inputs": True,
        },
        # Sub-modules listed here are passed to the wrapper as ignored modules.
        "ignored_modules": ["vae", "text_encoder.shared"],
    },
    "state_dict_cfg": {
        "state_dict_type": "FULL_STATE_DICT",
        # Model and optimizer state share the same options:
        # offloaded to CPU and materialized on rank 0 only.
        "state_dict_config": {
            "type": "FullStateDictConfig",
            "offload_to_cpu": True,
            "rank0_only": True,
        },
        "optim_state_dict_config": {
            "type": "FullOptimStateDictConfig",
            "offload_to_cpu": True,
            "rank0_only": True,
        },
    },
}

# Per-sample transform chain, applied in list order.
pipeline = [
    # Step 1: load a video/audio segment together with keypoints and a
    # reference image. Segments are 18 frames, picked with the "random" rule.
    {
        "type": "LoadVideoAudioSegmentWithKeypointRef",
        "video_path_key": "video_path",
        "keypoint_path_key": "keypoint_path",
        "audio_path_key": "audio_path",
        "filter_min_num_frames": 18,
        "segment_num_frames": 18,
        "sampling_rate": 16000,  # presumably the audio sample rate in Hz — confirm
        "segment_rule": "random",
        "video_only": False,
        "frame_multiple": 8,
        "frame_multiple_add": 2,
        "use_ref_img": True,
        "assert_fps": 25,
        "num_ref_img": 1,
        "ref_img_rule": "random",
    },
    # Step 2: aspect-preserving resize of both the clip and the reference image.
    {
        "type": "ResizeVideo",
        "video_keys": ["video", "ref_img"],
        "size_candidates": [(512, 512)],
        "keep_ratio": True,
    },
    # Step 3: center-crop both to 512x512.
    {
        "type": "CenterCropVideo",
        "video_keys": ["video", "ref_img"],
        "crop_size": (512, 512),
    },
    # Step 4: build a "lower_face" mask from the keypoints.
    {
        "type": "SapiensKeypoint2Mask",
        "mask_area": "lower_face",
        "mask_expand": (0, 0, 0, 20),
    },
    # Step 5: normalize with mean 0.5 / std 0.5 on every channel.
    # NOTE(review): an earlier comment here read "w.r.t dinov2"; confirm which
    # encoder these statistics are meant for.
    {
        "type": "NormalizeVideo",
        "video_keys": ["video", "ref_img"],
        "mean": [0.5, 0.5, 0.5],
        "std": [0.5, 0.5, 0.5],
    },
    # Step 6: no caption file is read (text_path_key is None); a single fixed
    # dummy caption is supplied instead.
    {
        "type": "LoadText",
        "text_path_key": None,
        "dummy_captions": [
            "Adjust the speaker’s mouth shapes based on the input audio."
        ],
    },
]


# Training dataloader: one sample per batch, one worker, shuffled sampling.
train_dataloader = {
    "batch_size": 1,
    "num_workers": 1,
    "collate_fn": {"type": "flexible_collate"},
    "sampler": {"type": "DefaultSampler", "shuffle": True},
    "dataset": {
        "type": "TextVideoAudioKeypointDataset",
        "data_dir": "data/",
        "anno_file": "data/annotations/train_anno.json",
        # Reuses the `pipeline` list defined earlier in this file.
        "pipeline": pipeline,
        "refetch": True,
    },
}
# Epoch-based training loop, capped at 32 epochs.
train_cfg = {"by_epoch": True, "max_epochs": 32}

# Optimizer wrapper: AdamW under bfloat16 automatic mixed precision.
optim_wrapper = dict(
    type="AmpOptimWrapper",
    dtype="bfloat16",
    # This config trains with `FSDPStrategy`, so ask the wrapper for the
    # sharded gradient scaler. mmengine documents `use_fsdp` as the switch that
    # selects `ShardedGradScaler` and says it should be enabled when using
    # FullyShardedDataParallel; without it the plain `GradScaler` is used.
    # NOTE(review): assumes "AmpOptimWrapper" resolves to mmengine's
    # implementation (which accepts `use_fsdp`) — confirm against the registry.
    use_fsdp=True,
    optimizer=dict(type="AdamW", lr=1e-4, betas=[0.9, 0.99], weight_decay=0.0),
)


# Trainer model definition. Every sub-dict names a component `type` plus the
# arguments used to build or load it.
model = {
    "type": "WanVaceSlipmaeTrainerV5",
    # --- Base model from wan2.1-vace-1.3b ---------------------------------
    # vae / transformer / tokenizer / text_encoder all load from sub-folders of
    # the same local checkpoint directory (local_files_only=True).
    "vae": {
        "type": "AutoencoderKLWan",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/Wan2.1-VACE-1.3B-diffusers",
            "local_files_only": True,
            "subfolder": "vae",
        },
    },
    "transformer": {
        "type": "AudiopackWanVACETransformer3DModel",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/Wan2.1-VACE-1.3B-diffusers",
            "local_files_only": True,
            "subfolder": "transformer",
            # Audio-conditioning options for this transformer variant.
            "audio_inject": "input",
            "audio_dim": 1024,
            "audio_hidden_size": 1536,
            "low_cpu_mem_usage": False,
        },
    },
    "tokenizer": {
        "type": "T5Tokenizer",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/Wan2.1-VACE-1.3B-diffusers",
            "local_files_only": True,
            "subfolder": "tokenizer",
        },
    },
    "text_encoder": {
        "type": "UMT5EncoderModel",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/Wan2.1-VACE-1.3B-diffusers",
            "local_files_only": True,
            "subfolder": "text_encoder",
            "low_cpu_mem_usage": False,
        },
    },
    "scheduler": {
        "type": "FlowMatchEulerDiscreteScheduler",
        "num_train_timesteps": 1000,
        "shift": 5.0,
        "use_dynamic_shifting": False,
        "base_shift": 0.5,
        "max_shift": 1.15,
    },
    # --- Audio branch -----------------------------------------------------
    # Processor and encoder both load from the same local wav2vec2 checkpoint.
    "audio_processor": {
        "type": "Wav2Vec2Processor",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/wav2vec2-base-960h",
            "local_files_only": True,
            "use_fast": True,
        },
    },
    "audio_encoder": {
        "type": "Wav2Vec2InterpModel",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/wav2vec2-base-960h",
            "local_files_only": True,
        },
    },
    # Adapter from 10752 input features to 1024 output features, initialized
    # from the "audio_adapter." entries of the slipmae pretraining checkpoint.
    "audio_adapter": {
        "type": "AudioAdapter",
        "in_feature": 10752,
        "out_feature": 1024,
        "init_cfg": {
            "type": "Pretrained",
            "checkpoint": "work_dirs/slipmae_pretrain/best_iter_58000.pth",
            "prefix": "audio_adapter.",
        },
    },
    # --- Slipmae encoder --------------------------------------------------
    # Original note on this entry: "1024" (presumably its feature width —
    # TODO confirm). Initialized from the "mae_encoder." entries of the same
    # pretraining checkpoint used by the audio adapter.
    "slipmae_encoder": {
        "type": "SlipmaeEncoder",
        "arch": "sapiens_0.3b",
        # Inputs arrive normalized with mean 0.5 / std 0.5, so they have to be
        # re-normalized with the Sapiens mean and std.
        "norm_in": True,
        "img_size": (512, 512),
        "patch_size": 16,
        "in_channels": 3,
        "drop_rate": 0.0,
        "drop_path_rate": 0.0,
        "qkv_bias": True,
        "norm_cfg": {"type": "LN", "eps": 1e-6},
        "frozen_stages": -1,
        "interpolate_mode": "bicubic",
        "layer_scale_init_value": 0.0,
        "pre_norm": False,
        "num_extra_tokens": 3,
        "mask_ratio": 0.75,
        "init_cfg": {
            "type": "Pretrained",
            "checkpoint": "work_dirs/slipmae_pretrain/best_iter_58000.pth",
            "prefix": "mae_encoder.",
        },
    },
    # --- Trainer-level switches -------------------------------------------
    "audio_drop_rate": 0.0,
    "vace_only": True,
}
