from functools import partial

from mmengine import FUNCTIONS
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Expose PyTorch's stock size-based FSDP auto-wrap policy under its own name so
# configs can refer to it by string. ``force=True`` lets this entry overwrite
# any existing registration with the same name in the registry.
FUNCTIONS.register_module(
    module=size_based_auto_wrap_policy,
    name="size_based_auto_wrap_policy",
    force=True,
)
# Register a LLaMA-specific FSDP auto-wrap policy under the name
# "llama_auto_wrap_policy". It is PyTorch's transformer_auto_wrap_policy with
# ``transformer_layer_cls`` pre-bound to LlamaDecoderLayer, so the decoder layer
# class is what the policy matches on.
# ``transformer_layer_cls`` is passed as a set, matching the
# ``Set[Type[nn.Module]]`` type PyTorch documents for this argument (a list
# happened to work only because the value is converted with ``tuple(...)``).
# ``force=True`` lets this entry overwrite any existing registration with the
# same name in the registry.
FUNCTIONS.register_module(
    name="llama_auto_wrap_policy",
    module=partial(
        transformer_auto_wrap_policy, transformer_layer_cls={LlamaDecoderLayer}
    ),
    force=True,
)
