# Inherit shared runtime settings from the base config; the top-level keys in
# this file presumably override / merge into it (MMEngine-style `_base_`
# inheritance — confirm against the config loader in use).
# Keep this a plain `name = value` assignment (no type annotation): MMEngine
# locates `_base_` by parsing the file's syntax tree.
_base_ = "../../_base_/default_runtime.py"

# Distributed strategy: FSDP sharding with bf16 mixed precision, saving full
# (unsharded) checkpoints.
strategy = {
    "type": "FSDPStrategy",
    # FSDP wrapper options; field names mirror PyTorch's FullyShardedDataParallel
    # arguments (presumably forwarded as-is — confirm in the strategy class).
    "model_wrapper": {
        # Size-based auto-wrapping with a 1e8-parameter threshold.
        "auto_wrap_policy": {
            "type": "size_based_auto_wrap_policy",
            "min_num_params": 1e8,
        },
        "use_orig_params": True,
        # Parameters, buffers and gradient reduction all use bf16.
        # NOTE(review): a bf16 reduce_dtype is cheaper but less precise than
        # reducing gradients in fp32 — confirm this is intentional.
        "mixed_precision": {
            "param_dtype": "bfloat16",
            "buffer_dtype": "bfloat16",
            "reduce_dtype": "bfloat16",
            "cast_forward_inputs": True,
        },
    },
    # Checkpoint full model and optimizer state, offloaded to CPU and kept
    # only on rank 0.
    "state_dict_cfg": {
        "state_dict_type": "FULL_STATE_DICT",
        "state_dict_config": {
            "type": "FullStateDictConfig",
            "offload_to_cpu": True,
            "rank0_only": True,
        },
        "optim_state_dict_config": {
            "type": "FullOptimStateDictConfig",
            "offload_to_cpu": True,
            "rank0_only": True,
        },
    },
}


# Per-sample training preprocessing, applied in list order.
pipeline = [
    # Load a video / audio / keypoint segment using the paths stored under the
    # listed annotation keys: at most 256 frames, audio sampled at 16000 Hz,
    # segment picked by the "random" rule.
    dict(
        type="LoadVideoAudioWithKeypointSegment",
        video_path_key="video_path",
        audio_path_key="audio_path",
        keypoint_path_key="keypoint_path",
        max_num_frames=256,
        sampling_rate=16000,
        segment_rule="random",
        video_only=False,
        strict_length=False,
    ),
    # Resize with aspect ratio kept, then center-crop to 512x512, the
    # resolution used by the model (mae_encoder / mae_decoder img_size).
    dict(
        type="ResizeVideo",
        video_keys=["video"],
        size_candidates=[(512, 512)],
        keep_ratio=True,
    ),
    dict(type="CenterCropVideo", video_keys=["video"], crop_size=(512, 512)),
    dict(
        # Per-channel mean=std=0.5 (maps [0, 1] pixels to [-1, 1], assuming
        # that input range). These are NOT the DINOv2 / ImageNet statistics;
        # the Sapiens encoder re-normalizes internally (mae_encoder.norm_in=True).
        type="NormalizeVideo",
        video_keys=["video"],
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5],
    ),
]

# Trainer combining wav2vec2 audio features with a Sapiens-0.3B ViT encoder
# and an MAE-style decoder.
model = {
    "type": "TrainerSlipmaePretrain",
    # Raw-waveform processor, loaded from a local checkpoint only.
    "audio_processor": {
        "type": "Wav2Vec2Processor",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/wav2vec2-base-960h",
            "local_files_only": True,
            "use_fast": True,
        },
    },
    "audio_encoder": {
        "type": "Wav2Vec2InterpModel",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/wav2vec2-base-960h",
            "local_files_only": True,
        },
    },
    # Image encoder on the "sapiens_0.3b" architecture; its embedding width is
    # presumably 1024, matching mae_decoder.embed_dim below.
    "mae_encoder": {
        "type": "SlipmaeEncoder",
        "arch": "sapiens_0.3b",
        # The pipeline normalizes with mean=std=0.5; norm_in=True makes the
        # encoder re-normalize its input with the Sapiens mean and std.
        "norm_in": True,
        "img_size": (512, 512),
        "patch_size": 16,
        "in_channels": 3,
        "drop_rate": 0.0,
        "drop_path_rate": 0.0,
        "qkv_bias": True,
        "norm_cfg": {"type": "LN", "eps": 1e-6},
        "frozen_stages": -1,
        "interpolate_mode": "bicubic",
        "layer_scale_init_value": 0.0,
        "pre_norm": False,
        # Presumably must agree with mae_decoder.num_extra_tokens.
        "num_extra_tokens": 3,
        "mask_ratio": 0.75,
        "init_cfg": {
            "type": "Pretrained",
            "checkpoint": "checkpoints/sapiens-pose-0.3b/backbone.pth",
        },
    },
    "mae_decoder": {
        "type": "SlipmaeDecoder",
        "img_size": (512, 512),
        "patch_size": 16,
        "in_chans": 3,
        "embed_dim": 1024,
        "decoder_embed_dim": 512,
        "decoder_depth": 8,
        "decoder_num_heads": 16,
        "mlp_ratio": 4,
        "norm_cfg": {"type": "LN", "eps": 1e-6},
        "predict_feature_dim": None,
        # One extra token per condition type: id, nonvocal, vocal.
        "num_extra_tokens": 3,
    },
    "train_minibatch": 48,
    "loss_cfg": {
        # Number of temporal neighbor frames treated as positive samples in
        # the contrastive loss; 0 presumably disables this.
        "temporal_neighbors": 0,
        "pixel_loss_weight": 1,
        "cl_loss_weight": 1,
        "ortho_loss_weight": 1e-1,
        "ortho_loss_mode": "cov",
    },
    "audio_layer": "all",
    # NOTE(review): 10752 == 14 * 768 (wav2vec2-base hidden size); presumably
    # tied to audio_layer="all" — confirm against the audio encoder's outputs.
    "audio_adapter": {
        "in_feature": 10752,
        "out_feature": 1024,
    },
    "audio_filter_model": {
        "type": "Wav2Vec2InterpModel",
        "from_pretrained": {
            "pretrained_model_name_or_path": "checkpoints/wav2vec2-base-960h",
            "local_files_only": True,
        },
    },
    "motion_augment": "color",
    # Whole-model checkpoint, presumably from an earlier pre-training run.
    "init_cfg": {
        "type": "Pretrained",
        "checkpoint": "work_dirs/slipmae_pretrain/iter_273000.pth",
    },
}

# Training data loader: one sample per loader step, shuffled sampling and a
# flexible collate function, reading the annotated train split.
train_dataloader = {
    "batch_size": 1,
    "num_workers": 8,
    "sampler": {"type": "DefaultSampler", "shuffle": True},
    "collate_fn": {"type": "flexible_collate"},
    "dataset": {
        "type": "TextVideoAudioKeypointDataset",
        "data_dir": "data/",
        "anno_file": "data/annotations/train_anno.json",
        # Preprocessing steps from the module-level `pipeline` list.
        "pipeline": pipeline,
        "refetch": True,
    },
}

# Epoch-based loop for 32 epochs; with val_interval equal to max_epochs,
# validation (if configured) presumably runs only after the final epoch.
train_cfg = {"by_epoch": True, "max_epochs": 32, "val_interval": 32}

# Mixed-precision (bf16) optimizer wrapper around AdamW with no weight decay.
# NOTE(review): under FSDPStrategy, recent MMEngine versions let
# AmpOptimWrapper take `use_fsdp=True` so it uses PyTorch's ShardedGradScaler —
# confirm whether the installed version needs it here.
optim_wrapper = {
    "type": "AmpOptimWrapper",
    "dtype": "bfloat16",
    "optimizer": {
        "type": "AdamW",
        "lr": 1e-5,
        "betas": [0.9, 0.99],
        "weight_decay": 0.0,
    },
}
