import logging

import torch
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import mappings
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase


# Set the "torch" and "apex" loggers to WARNING so that records below that
# severity are not emitted by them while the tests run.
for _logger_name in ("torch", "apex"):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)
del _logger_name  # keep the module namespace free of the loop variable


class MappingTestBase:
    """Backend-independent tests for the tensor-parallel ``mappings`` helpers.

    This mixin does not set up any process group itself. It reads
    ``self.world_size`` and ``self.rank`` and calls ``self.assertTrue``, so it
    must be combined with a test-case class that provides them.
    """

    def _tensor_model_parallel_sizes(self):
        """Return every size in ``[1, world_size]`` that divides ``world_size``."""
        return [
            size
            for size in range(1, self.world_size + 1)
            if self.world_size % size == 0
        ]

    def test_reduce(self):
        """``_reduce`` of a tensor filled with 50 yields 50 times the group size."""
        for tp_size in self._tensor_model_parallel_sizes():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tp_size
            )
            # Tear the parallel state down even when the assertion fails, so a
            # failure never leaves initialized model-parallel state behind.
            try:
                device = f"cuda:{self.rank}"
                shape = (10, 10, 10, 10)
                t = torch.full(shape, 50, device=device)
                expected = torch.full(shape, 50 * tp_size, device=device)
                self.assertTrue(
                    torch.equal(mappings._reduce(t), expected),
                    msg=f"tensor_model_paralell_world_size: {tp_size}",
                )
            finally:
                parallel_state.destroy_model_parallel()

    def test_split(self):
        """``_split_along_last_dim`` returns the chunk at this rank's index."""
        for tp_size in self._tensor_model_parallel_sizes():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tp_size
            )
            try:
                # One (10, 1) column per tensor-parallel rank, concatenated
                # along dim 1; the split is expected to hand back the column
                # whose index is this process's tensor-parallel rank.
                chunks = [torch.randn(10, 1) for _ in range(tp_size)]
                x = torch.cat(chunks, 1)
                out = mappings._split_along_last_dim(x)
                tp_rank = parallel_state.get_tensor_model_parallel_rank()
                self.assertTrue(
                    torch.equal(out, chunks[tp_rank]),
                    msg=f"tensor_model_paralell_world_size: {tp_size}",
                )
            finally:
                parallel_state.destroy_model_parallel()

    def test_gather(self):
        """``_gather_along_last_dim`` of each rank's index yields ``0..size-1``."""
        for tp_size in self._tensor_model_parallel_sizes():
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tp_size
            )
            try:
                device = f"cuda:{self.rank}"
                # Each process contributes a one-element tensor holding its own
                # tensor-parallel rank.
                local = torch.tensor(
                    [parallel_state.get_tensor_model_parallel_rank()],
                    device=device,
                )
                gathered = mappings._gather_along_last_dim(local)
                expected = torch.tensor(list(range(tp_size)), device=device)
                self.assertTrue(
                    torch.equal(gathered, expected),
                    msg=f"tensor_model_paralell_world_size: {tp_size}",
                )
            finally:
                parallel_state.destroy_model_parallel()


class NcclMappingTest(MappingTestBase, NcclDistributedTestBase):
    """The ``MappingTestBase`` tests combined with ``NcclDistributedTestBase``."""
class UccMappingTest(MappingTestBase, UccDistributedTestBase):
    """The ``MappingTestBase`` tests combined with ``UccDistributedTestBase``."""


# When executed as a script, hand control to the test runner provided by
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    common_utils.run_tests()
