# Inherit the shared runtime config (mmengine-style `_base_` inheritance);
# settings defined in this file are merged on top of it and take precedence.
_base_ = "../../_base_/default_runtime.py"

# Distributed training with Fully Sharded Data Parallel (FSDP): bf16 mixed
# precision for parameters, buffers and gradient reduction, and full
# (unsharded) state dicts for model and optimizer checkpoints.
strategy = {
    "type": "FSDPStrategy",
    "model_wrapper": {
        # Size-based auto-wrapping with a threshold of 1e8 parameters per
        # FSDP unit (PyTorch `size_based_auto_wrap_policy`).
        "auto_wrap_policy": {
            "type": "size_based_auto_wrap_policy",
            "min_num_params": 1e8,
        },
        # Keep the original (unflattened) parameters visible to model and
        # optimizer code.
        "use_orig_params": True,
        "mixed_precision": {
            "param_dtype": "bfloat16",
            "buffer_dtype": "bfloat16",
            "reduce_dtype": "bfloat16",
            # Cast floating-point forward inputs to the parameter dtype.
            "cast_forward_inputs": True,
        },
    },
    "state_dict_cfg": {
        "state_dict_type": "FULL_STATE_DICT",
        # Full model / optimizer state is offloaded to CPU and materialized
        # on rank 0 only when a checkpoint is produced.
        "state_dict_config": {
            "type": "FullStateDictConfig",
            "offload_to_cpu": True,
            "rank0_only": True,
        },
        "optim_state_dict_config": {
            "type": "FullOptimStateDictConfig",
            "offload_to_cpu": True,
            "rank0_only": True,
        },
    },
}


# Per-sample video preprocessing; transforms run in list order.
pipeline = [
    # Decode a labelled segment of the clip ("random" segment rule), video
    # only. max_num_frames=5; strict_length presumably enforces that exact
    # length — confirm against LoadVideoWithLabelSegment.
    {
        "type": "LoadVideoWithLabelSegment",
        "video_path_key": "video_path",
        "audio_path_key": "audio_path",
        "label_key": "emotion",
        "max_num_frames": 5,
        "sampling_rate": 16000,
        "segment_rule": "random",
        "video_only": True,
        "strict_length": True,
    },
    # Aspect-ratio-preserving resize towards 512x512 ...
    {
        "type": "ResizeVideo",
        "video_keys": ["video"],
        "size_candidates": [(512, 512)],
        "keep_ratio": True,
    },
    # ... followed by a 512x512 center crop.
    {"type": "CenterCropVideo", "video_keys": ["video"], "crop_size": (512, 512)},
    # Per-channel mean/std of 0.5; maps pixels to [-1, 1] assuming frames are
    # in [0, 1] at this point (TODO confirm). The backbone's norm_in setting
    # relies on this normalization.
    {
        "type": "NormalizeVideo",
        "video_keys": ["video"],
        "mean": [0.5, 0.5, 0.5],
        "std": [0.5, 0.5, 0.5],
    },
]

# Classifier on top of a Sapiens ViT motion-extractor backbone whose weights
# come from an MAE pretraining run (per the checkpoint path below).
model = {
    "type": "TrainerSlipmaeClassifier",
    # NOTE(review): 1024 is presumably the feature width of sapiens_0.3b — confirm.
    "backbone": {
        "type": "SapiensMotionExtractorV2",
        "arch": "sapiens_0.3b",
        # Inputs arrive normalized with mean 0.5 / std 0.5; norm_in
        # re-normalizes them with the Sapiens mean and std.
        "norm_in": True,
        "img_size": (512, 512),
        "patch_size": 16,
        "in_channels": 3,
        "drop_rate": 0.0,
        "drop_path_rate": 0.0,
        "qkv_bias": True,
        "norm_cfg": {"type": "LN", "eps": 1e-6},
        "frozen_stages": -1,  # presumably -1 = nothing frozen — confirm
        "interpolate_mode": "bicubic",
        "layer_scale_init_value": 0.0,
        "pre_norm": False,
        "num_extra_tokens": 3,
        # Has no effect here: no masking is applied during fine-tuning.
        "mask_ratio": 0.75,
        # Load only the "mae_encoder."-prefixed weights from the checkpoint.
        "init_cfg": {
            "type": "Pretrained",
            "checkpoint": "work_dirs/motion_extractor_mae_v3_crossattn_shuffleframe/iter_168000.pth",
            "prefix": "mae_encoder.",
        },
    },
    # Either "multiclass" or "multilabel". NOTE(review): this config trains
    # the "emotion" label as multiclass — confirm the task type matches
    # each label set before reusing it for appearance/action.
    "task": "multiclass",
    "num_classes": 8,  # 8 for emotion, 40 for appearance, 35 for action
    "label_key": "emotion",
}

# Training data: shuffled batches of 3 samples, 8 loader workers, built from
# the combined CelebV-HQ + MEAD training annotations.
train_dataloader = {
    "batch_size": 3,
    "num_workers": 8,
    "sampler": {"type": "DefaultSampler", "shuffle": True},
    "collate_fn": {"type": "flexible_collate"},
    "dataset": {
        "type": "TextVideoAudioLabelDataset",
        "data_dir": "data/",
        "anno_file": "data/annotations/train_celebvhq_mead_anno.json",
        "pipeline": pipeline,
        # presumably re-fetches another sample when one fails to load — confirm.
        "refetch": True,
    },
}


# Epoch-based training for 100 epochs. Because val_interval equals max_epochs,
# validation (if a validation loop is configured) runs only after epoch 100.
train_cfg = {"by_epoch": True, "max_epochs": 100, "val_interval": 100}

# AdamW (lr 5e-5, no weight decay) wrapped for bf16 automatic mixed precision.
# NOTE(review): FSDP mixed precision above already uses bf16; confirm the
# extra autocast layer is intended.
optim_wrapper = {
    "type": "AmpOptimWrapper",
    "dtype": "bfloat16",
    "optimizer": {
        "type": "AdamW",
        "lr": 5e-5,
        "betas": [0.9, 0.99],
        "weight_decay": 0.0,
    },
}
