import logging

import torch
from torch.testing._internal import common_utils

logging.getLogger("torch").setLevel(logging.WARNING)

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import utils
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase

logging.getLogger("apex").setLevel(logging.WARNING)


class TransformerUtilsTest(NcclDistributedTestBase):
    """Distributed tests for ``apex.transformer.tensor_parallel.utils``."""

    def test_split_tensor_along_last_dim(self):
        """Check ``utils.split_tensor_along_last_dim`` for every valid TP size.

        For each tensor-model-parallel size that evenly divides
        ``self.world_size``, model parallelism is initialized. A
        ``(100, 100, 100)`` CPU tensor is split into ``num_partitions`` pieces
        along its last dimension, and the test verifies that:

        * every piece has last-dimension size ``100 // num_partitions``, and
        * concatenating the pieces along the last dimension reproduces the
          input exactly.

        Model-parallel state is always torn down before the next iteration,
        even when an assertion fails.
        """
        num_partitions = 10
        last_dim = 100
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            # Only tensor-parallel sizes that divide the world size are valid.
            if self.world_size % tensor_model_parallel_world_size > 0:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            # Destroy in ``finally`` so a failed check does not leave
            # model-parallel state initialized behind this test.
            try:
                msg = (
                    "tensor_model_parallel_world_size: "
                    f"{tensor_model_parallel_world_size}"
                )
                device = "cpu"
                input_tensor = torch.randn((100, 100, last_dim), device=device)
                splits = utils.split_tensor_along_last_dim(
                    input_tensor, num_partitions
                )

                # Compare plain Python ints. The check then does not depend on
                # the dtype ``torch.full`` / ``torch.tensor`` would infer.
                last_dim_sizes = [int(split.size(-1)) for split in splits]
                self.assertEqual(
                    last_dim_sizes,
                    [last_dim // num_partitions] * num_partitions,
                    msg=msg,
                )
                # The pieces must tile the input exactly, in order.
                self.assertTrue(
                    torch.equal(torch.cat(splits, dim=-1), input_tensor),
                    msg=msg,
                )
            finally:
                parallel_state.destroy_model_parallel()


if __name__ == "__main__":
    # Script entry point: hand control to PyTorch's internal test runner
    # (torch.testing._internal.common_utils.run_tests).
    runner = common_utils.run_tests
    runner()
