# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utility functions for Megatron optimizer."""


from megatron.core import mpu


def shard_buffer(buffer, data_parallel_world_size=None):
    """
    Shard buffer into dp_size chunks of equal size.

    Args:
        buffer: Object exposing ``numel()`` and supporting slice indexing
            along its first dimension (presumably a flat, 1-D torch.Tensor;
            confirm against callers).
        data_parallel_world_size: Number of shards to produce. When None
            (the default), the value returned by
            ``mpu.get_data_parallel_world_size(with_context_parallel=True)``
            is used, which matches the original behavior.

    Returns:
        A list with ``data_parallel_world_size`` entries, where entry ``r``
        is ``buffer[r * shard_size : (r + 1) * shard_size]``. The entries are
        whatever slicing ``buffer`` yields (for a torch.Tensor these are
        expected to be views rather than copies).

    Raises:
        AssertionError: If ``buffer.numel()`` is not evenly divisible by the
            number of shards.
    """
    if data_parallel_world_size is None:
        # Fall back to the globally configured data-parallel group size.
        data_parallel_world_size = mpu.get_data_parallel_world_size(with_context_parallel=True)

    numel = buffer.numel()
    assert (
        numel % data_parallel_world_size == 0
    ), f"buffer.numel() ({numel}) is not divisible by the number of shards ({data_parallel_world_size})"

    shard_size = numel // data_parallel_world_size
    # Consecutive, non-overlapping slices; shard r covers elements
    # [r * shard_size, (r + 1) * shard_size).
    return [
        buffer[start : start + shard_size]
        for start in range(0, numel, shard_size or 1)
    ][:data_parallel_world_size] if shard_size else [
        buffer[0:0] for _ in range(data_parallel_world_size)
    ]
