# Copyright (c) Alibaba, Inc. and its affiliates.
from megatron.core.dist_checkpointing.strategies.torch import (
    TorchDistSaveShardedStrategy,
)
from megatron.training import get_args, global_vars, initialize, training

from swift.utils import get_logger

logger = get_logger()


def patch_megatron_tokenizer(tokenizer):
    """Make ``global_vars.build_tokenizer`` return the given tokenizer.

    Overwrites the ``build_tokenizer`` attribute of Megatron's ``global_vars``
    module with a closure that hands back ``tokenizer`` unchanged. The
    replacement is module-wide and is not undone by this function.

    Args:
        tokenizer: Object that the replacement ``build_tokenizer`` returns on
            every call.
    """

    def _prebuilt_tokenizer(args):
        # ``args`` only keeps the one-argument call shape; it is never read.
        return tokenizer

    global_vars.build_tokenizer = _prebuilt_tokenizer


def patch_torch_dist_shard(thread_count):
    """Force ``thread_count`` on every new ``TorchDistSaveShardedStrategy``.

    Replaces ``TorchDistSaveShardedStrategy.__init__`` with a wrapper that
    overwrites the ``thread_count`` keyword argument and then delegates to the
    unpatched initializer. The patch is class-wide and is not undone by this
    function.

    The function may be called repeatedly: each call re-wraps the *unpatched*
    initializer, so the most recent ``thread_count`` is the one that applies.

    Args:
        thread_count: Value passed as the ``thread_count`` keyword to the
            unpatched ``__init__``, replacing any keyword value the caller
            supplied.
    """
    from functools import wraps

    installed_init = TorchDistSaveShardedStrategy.__init__
    # A wrapper left by an earlier call remembers the unpatched initializer.
    # Unwrap it: otherwise the older wrapper would run last and write its own
    # thread_count back into kwargs, silently cancelling this call.
    original_init = getattr(
        installed_init, "_swift_original_init", installed_init
    )

    @wraps(original_init)
    def _patched_init(*args, **kwargs):
        kwargs["thread_count"] = thread_count
        return original_init(*args, **kwargs)

    # Marker used by later calls to find the unpatched initializer.
    _patched_init._swift_original_init = original_init
    TorchDistSaveShardedStrategy.__init__ = _patched_init
