# mypy: allow-untyped-defs

import copy
import json
import itertools
import math
import os
import random
import sys
import tempfile
import time
from collections import namedtuple, OrderedDict, defaultdict
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from functools import reduce
from typing import Union, NamedTuple, Callable, Any
import unittest
import numpy as np
import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD
import torch.distributed.algorithms.model_averaging.utils as model_averaging_utils
import torch.nn as nn
import torch.nn.functional as F
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
from torch.utils._python_dispatch import TorchDispatchMode
from torch.autograd import DeviceType
from torch.cuda.amp import GradScaler, autocast

from torch.distributed.algorithms.ddp_comm_hooks import (
    post_localSGD_hook as post_localSGD,
    powerSGD_hook as powerSGD,
    default_hooks as default,
    quantization as quantization_hooks,
)
from torch.distributed.optim import _apply_optimizer_in_backward

from torch.distributed.distributed_c10d import (
    get_world_size,
    _get_default_group,
    _get_pg_config,
)
from torch.distributed.utils import (
    _verify_param_shape_across_processes,
    _sync_module_states,
)
from torch.profiler import (
    ExecutionTraceObserver,
    ProfilerActivity,
)

from torch.nn.parallel import DistributedDataParallel
from torch.nn.parallel.distributed import _dump_DDP_relevant_env_vars, _MixedPrecision
from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    TEST_SKIPS,
    init_multigpu_helper,
    initialize_temp_directories,
    cleanup_temp_dir,
    simple_sparse_reduce_tests,
    skip_if_rocm_multiprocess,
    skip_if_small_worldsize,
    skip_if_odd_worldsize,
    skip_if_lt_x_gpu,
    nccl_skip_if_lt_x_gpu,
    skip_if_no_gpu,
    require_n_gpus_for_nccl_backend,
    requires_nccl_version,
    captured_output,
    with_nccl_blocking_wait,
    with_dist_debug_levels,
    verify_ddp_error_logged,
    DistTestCases,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_MACOS,
    IS_WINDOWS,
    FILE_SCHEMA,
    IS_FBCODE,
    NO_MULTIPROCESSING_SPAWN,
    IS_SANDCASTLE,
    skip_but_pass_in_sandcastle,
    skip_but_pass_in_sandcastle_if,
)

import torch.distributed.optim.post_localSGD_optimizer as post_localSGD_optimizer

from torch.utils.data.distributed import DistributedSampler
import operator

try:
    import torchvision

    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False

if sys.platform == "win32":
    import msvcrt
else:
    import fcntl


class NetWithBuffers(nn.Module):
    """Two chained bias-free linear layers plus a registered buffer.

    The buffer named ``buffer`` is incremented in place on every forward
    call, independently of the input.
    """

    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 1, bias=False)
        self.register_buffer("buffer", torch.randn(1, 2))

    def forward(self, x):
        """Bump the buffer by one, then return ``b(a(x))``."""
        self.buffer.add_(1)
        hidden = self.a(x)
        return self.b(hidden)


class Foo:
    """Small attribute container with attribute-wise equality.

    Equality walks every attribute stored on ``self`` and compares it with
    the attribute of the same name on ``other``: tensors are compared with
    ``torch.equal``, everything else with ``==``. Attributes that exist only
    on ``other`` are ignored (the comparison is driven by ``self``).
    """

    def __init__(self, x):
        # Can be tensor or int
        self.x = x

    def __eq__(self, other):
        """Return True when every attribute of ``self`` matches ``other``.

        Returns ``NotImplemented`` for operands that are not ``Foo`` so that
        Python falls back to its default comparison instead of raising.
        """
        if not isinstance(other, Foo):
            return NotImplemented

        def eq(value, other_value):
            value_is_tensor = isinstance(value, torch.Tensor)
            other_is_tensor = isinstance(other_value, torch.Tensor)
            if value_is_tensor or other_is_tensor:
                # A tensor never equals a non-tensor; this also avoids the
                # ambiguous truth value of an elementwise ``==`` result.
                return (
                    value_is_tensor
                    and other_is_tensor
                    and torch.equal(value, other_value)
                )
            return value == other_value

        # Sentinel distinguishes "attribute missing" from "attribute is None".
        missing = object()
        for attr, value in self.__dict__.items():
            other_value = other.__dict__.get(attr, missing)
            if other_value is missing or not eq(value, other_value):
                return False
        return True


# Foo instance that carries an extra attribute (``bar``) on top of ``x``.
f = Foo(10)
f.bar = 1

# Foo instance whose payload is a CPU tensor instead of an int.
foo_cpu_tensor = Foo(torch.randn(3, 3))


# Assorted Python objects: a nested dict, the two Foo instances above, a
# string, and a nested heterogeneous list.
# NOTE(review): presumably the payloads for object-based collective tests —
# confirm against the users of this list.
COLLECTIVES_OBJECT_TEST_LIST = [
    {"key1": 3, "key2": 4, "key3": {"nested": True}},
    f,
    foo_cpu_tensor,
    "foo",
    [1, 2, True, "string", [4, 5, "nested"]],
]

# The three allowlists below currently hold the same four backends; they are
# kept separate so each profiling mode can be restricted independently.

# Allowlist of distributed backends where profiling collectives is supported.
PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.NCCL,
    dist.Backend.GLOO,
    dist.Backend.MPI,
    dist.Backend.UCC,
]

# Allowlist of distributed backends where profiling is supported with use_cuda=True
CUDA_PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.GLOO,
    dist.Backend.MPI,
    dist.Backend.NCCL,
    dist.Backend.UCC,
]

# Allowlist of distributed backends where profiling is supported for
# point-to-point (send/recv) ops.
SEND_RECV_PROFILING_SUPPORTED_BACKENDS = [
    dist.Backend.MPI,
    dist.Backend.GLOO,
    dist.Backend.NCCL,
    dist.Backend.UCC,
]

# Dummy NamedTuple data structures to test DDP support for NamedTuple types.
# EXPECTED_FIELDS holds the field names of the collections.namedtuple variant
# defined right below.
EXPECTED_FIELDS = ("a", "b")
TestNamedTupleInput_0 = namedtuple("NamedTuple", EXPECTED_FIELDS)


class TestNamedTupleInput_1(NamedTuple):
    """Typed ``NamedTuple`` with two tensor fields, ``a`` and ``b``."""

    # Annotated with the tensor *type* (torch.Tensor); torch.tensor is the
    # factory function and is not a valid type annotation.
    a: torch.Tensor
    b: torch.Tensor


# Decorator that skips a test when torchvision could not be imported.
skipIfNoTorchVision = skip_but_pass_in_sandcastle_if(
    not HAS_TORCHVISION, "no torchvision"
)

# BACKEND is mandatory: importing this module raises KeyError when the
# environment variable is unset. INIT_METHOD falls back to "env://".
BACKEND = os.environ["BACKEND"]
INIT_METHOD = os.getenv("INIT_METHOD", "env://")

# Default per-test timeout and per-test overrides keyed by test name, as
# looked up by get_timeout(). NOTE(review): presumably seconds — confirm.
DEFAULT_TIMEOUT = 300
CUSTOMIZED_TIMEOUT = {"test_DistributedDataParallel": 500}


def get_profiling_event(event_name, profiler, dedup_gpu_user_annotation=False):
    """Return the profiler events whose name starts or ends with ``event_name``.

    Events are read from ``profiler.events()`` for a ``torch.profiler.profile``
    instance and from ``profiler.function_events`` for anything else. When
    ``dedup_gpu_user_annotation`` is true, events whose ``device_type`` is
    ``DeviceType.CUDA`` are dropped.
    """
    if isinstance(profiler, torch.profiler.profile):
        candidates = profiler.events()
    else:
        candidates = profiler.function_events

    def _is_match(evt):
        name_matches = evt.name.endswith(event_name) or evt.name.startswith(
            event_name
        )
        if not name_matches:
            return False
        # The device type is only inspected when de-duplication is requested.
        return not dedup_gpu_user_annotation or evt.device_type != DeviceType.CUDA

    return [evt for evt in candidates if _is_match(evt)]

def get_profiler_nccl_meta(prof):
    """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
    We will need to test metadata obtained from profiler here

    The profile is exported as a chrome trace to a temporary JSON file, the
    ``traceEvents`` list is loaded back, and only the events named
    ``record_param_comms`` are returned. The temporary file is always
    removed, even when exporting or parsing fails.
    """
    # delete=False plus an immediate close: only the path is needed, since the
    # profiler opens the file itself.
    tf = tempfile.NamedTemporaryFile(
        mode="w+t", suffix=".json", delete=False
    )
    tf.close()
    trace_file = tf.name

    try:
        prof.export_chrome_trace(trace_file)
        with open(trace_file) as f:
            events = json.load(f)["traceEvents"]
        print(f"Trace saved to {trace_file}")
    finally:
        # Comment out the removal to keep the trace around for debugging.
        os.remove(trace_file)

    return [e for e in events if e.get("name") == "record_param_comms"]

# Substrings of DDP error messages, kept as module constants so they are
# spelled in one place.

# Base error message substring on unfinished reductions.
ddp_prev_reduction_unfinished_str = (
    "Expected to have finished reduction in the prior iteration"
)
# Error message substring when find_unused_parameters=True has not been passed
ddp_recommend_find_unused_params_str = (
    "passing the keyword argument `find_unused_parameters=True`"
)
# Error message substring when find_unused_parameters=True is enabled
ddp_find_unused_params_enabled_str = "Since `find_unused_parameters=True` is enabled"
# Error message substring for possibility of not all model outputs being used
# in loss computation
ddp_outputs_not_used_in_loss_str = (
    "`forward` function outputs participate in calculating loss"
)
# Error message substring suggesting to use TORCH_DISTRIBUTED_DEBUG
ddp_suggest_debug_mode_str = (
    "set the environment variable TORCH_DISTRIBUTED_DEBUG to either INFO or DETAIL"
)


class DDPUnevenTestInput(NamedTuple):
    """Bundle of settings describing one uneven-input DDP test case.

    NOTE(review): field meanings below are inferred from their names —
    confirm against the tests that consume this tuple.
    """

    # Label for the test case.
    name: str
    # Model under test.
    model: nn.Module
    # Model input: a single tensor or a tuple of inputs.
    inp: Union[torch.Tensor, tuple]
    sync_interval: int
    throw_on_early_termination: bool = False
    # Optional hook and its state; both default to None.
    hook: Union[Callable, None] = None
    state: Any = None


class _FC2(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(10, 50, bias=True)
        self.fc.bias.requires_grad = False

    def forward(self, x):
        x = self.fc(x)
        return x


class Net(nn.Module):
    """Three-stage MLP (2 -> 10 -> 50 -> 4) ending in a softmax over dim 1.

    Also carries an integer parameter that does not require grad.
    """

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(2, 10, bias=False)
        self.fc2 = _FC2()
        self.fc3 = nn.Linear(50, 4, bias=False)
        self.relu = nn.ReLU()
        self.no_grad_param = nn.Parameter(
            torch.tensor([2, 2]).long(), requires_grad=False
        )

    def forward(self, x):
        """Run fc1 -> relu -> fc2 -> relu -> fc3 -> softmax."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return F.softmax(logits, dim=1)


class LargeNet(nn.Module):
    """Two bias-free linear layers: 1000 -> 2000 -> 500."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(1000, 2000, bias=False)
        self.fc2 = nn.Linear(2000, 500, bias=False)

    def forward(self, x):
        """Return ``fc2(fc1(x))``."""
        return self.fc2(self.fc1(x))


class Task(nn.Module):
    """Module with a single 2x2 parameter of ones that is added to the input."""

    def __init__(self) -> None:
        super().__init__()
        self.p = nn.Parameter(torch.ones(2, 2))

    def forward(self, x):
        """Return ``p + x``."""
        return torch.add(self.p, x)


class BatchNormNet(nn.Module):
    """Linear -> BatchNorm1d -> Linear network ending in a softmax over dim 1.

    ``affine`` is forwarded to the BatchNorm1d layer.
    """

    def __init__(self, affine=True):
        super().__init__()
        self.fc1 = nn.Linear(2, 40, bias=False)
        self.bn = nn.BatchNorm1d(4, affine=affine)
        self.fc2 = nn.Linear(40, 4, bias=False)

    def forward(self, x):
        """Reshape to (-1, 4, 10) for batch norm, flatten back, then classify."""
        projected = self.fc1(x)
        normalized = self.bn(torch.reshape(projected, (-1, 4, 10)))
        flattened = torch.reshape(normalized, (-1, 40))
        return F.softmax(self.fc2(flattened), dim=1)


class UnusedParamTwoLinLayerNet(nn.Module):
    """Two parallel linear layers plus a third (``c``) never used in forward."""

    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 10, bias=False)
        self.c = nn.Linear(5, 5, bias=False)

    def forward(self, x):
        """Return the tuple ``(a(x), b(x))``; layer ``c`` is not touched."""
        out_a = self.a(x)
        out_b = self.b(x)
        return out_a, out_b


class DictOutputModule(nn.Module):
    """Wraps ``UnusedParamTwoLinLayerNet`` and returns its outputs in a dict."""

    def __init__(self) -> None:
        super().__init__()
        self.module = UnusedParamTwoLinLayerNet()

    def forward(self, x):
        """Return ``{"predictions": (a, b), "loss": (a + b).sum()}``."""
        predictions = self.module(x)
        first, second = predictions[0], predictions[1]
        loss = (first + second).sum()
        return {"predictions": predictions, "loss": loss}


class TwoLinLayerNet(nn.Module):
    """Two parallel bias-free linear layers applied to the same input."""

    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(10, 10, bias=False)
        self.b = nn.Linear(10, 1, bias=False)

    def forward(self, x):
        """Return the tuple ``(a(x), b(x))``."""
        out_a = self.a(x)
        out_b = self.b(x)
        return out_a, out_b


class EmbeddingNetDifferentParams(nn.Module):
    """
    Embedding followed by a linear layer, where the embedding dimension or
    the number of parameters varies with the rank.

    The embedding dimension is 500 when ``diff_num_params`` is set or the
    rank is 0, and 50 otherwise. With ``diff_num_params`` an extra, unused
    ``lin2`` layer is registered.
    """

    def __init__(self, rank, diff_num_params=False):
        super().__init__()
        if diff_num_params or rank == 0:
            embedding_dim = 500
        else:
            embedding_dim = 50
        self.embedding = nn.Embedding(num_embeddings=10, embedding_dim=embedding_dim)
        self.lin = nn.Linear(embedding_dim, 1)
        if diff_num_params:
            self.lin2 = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        """Embed ``x`` and project each embedding to a single value."""
        return self.lin(self.embedding(x))


class ControlFlowToyModel(nn.Module):
    """Two linear layers where the second is used only for an all-ones input."""

    def __init__(self) -> None:
        super().__init__()
        self.lin1 = nn.Linear(10, 10, bias=False)
        self.lin2 = nn.Linear(10, 10, bias=False)

    def forward(self, x):
        """Return ``relu(lin1(x))``, passed through ``lin2`` for a 20x10 ones input."""
        # Second layer is used dependent on input x.
        all_ones = torch.ones(20, 10, device=x.device)
        use_second_layer = torch.equal(x, all_ones)
        hidden = F.relu(self.lin1(x))
        if not use_second_layer:
            return hidden
        return self.lin2(hidden)


# Shared model instances, built once at import time.
DDP_NET = Net()
BN_NET = BatchNormNet()
BN_NET_NO_AFFINE = BatchNormNet(affine=False)
# A bare SyncBatchNorm layer over 2 features with momentum 0.99.
ONLY_SBN_NET = nn.SyncBatchNorm(2, momentum=0.99)


def get_timeout(test_id):
    """Return the timeout for ``test_id``.

    Only the last dot-separated component of ``test_id`` is used as the
    lookup key; tests without an entry in ``CUSTOMIZED_TIMEOUT`` get
    ``DEFAULT_TIMEOUT``.
    """
    test_name = test_id.split(".")[-1]
    return CUSTOMIZED_TIMEOUT.get(test_name, DEFAULT_TIMEOUT)


# Process-group timeout in seconds (it is passed to timedelta(seconds=...)
# when the process group is initialized) for tests without an override.
default_pg_timeout = 60

# Per-test process-group timeout overrides in seconds, keyed by test name.
CUSTOM_PG_TIMEOUT = {
    # This test runs slowly and needs additional time to complete, otherwise can
    # be taken down by TORCH_NCCL_ASYNC_ERROR_HANDLING
    "test_ddp_uneven_inputs": 300,
    # This test has a short timeout since it tests being taken down by
    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
    "test_ddp_model_diff_across_ranks": 5,
    # This test has a short timeout since it tests being taken down by
    # TORCH_NCCL_ASYNC_ERROR_HANDLING which we want to happen quickly.
    "test_ddp_has_finalized": 5,
}

def require_backend_is_available(backends):
    """Build a decorator that skips a test unless ``BACKEND`` is usable.

    The test is skipped when ``BACKEND`` is not listed in ``backends`` or
    when that backend is not available. Otherwise the returned decorator
    leaves the test untouched.
    """

    def check(backend):
        # Built-in backends are probed through their availability functions,
        # in the same order as they are listed here.
        probes = (
            (dist.Backend.GLOO, dist.is_gloo_available),
            (dist.Backend.NCCL, dist.is_nccl_available),
            (dist.Backend.MPI, dist.is_mpi_available),
            (dist.Backend.UCC, dist.is_ucc_available),
        )
        for known_backend, is_available in probes:
            if backend == known_backend:
                return is_available()
        # Plugin backends are treated as available; anything else is not.
        return backend in DistTestCases.backend_feature["plugin"]

    if BACKEND not in backends:
        return skip_but_pass_in_sandcastle(
            f"Test requires backend {BACKEND} to be one of {backends}"
        )

    if check(dist.Backend(BACKEND)):
        return lambda func: func

    return skip_but_pass_in_sandcastle(
        f"Test requires backend {BACKEND} to be available"
    )


def require_world_size(world_size):
    """Build a decorator that skips a test when WORLD_SIZE is too small.

    The ``WORLD_SIZE`` environment variable is read when the decorator is
    created; the test is left untouched when it is at least ``world_size``.
    """
    if int(os.environ["WORLD_SIZE"]) >= world_size:
        return lambda func: func
    skip_reason = "Test requires world size of %d" % world_size
    return skip_but_pass_in_sandcastle(skip_reason)


@contextmanager
def _lock():
    TEMP_DIR = os.environ["TEMP_DIR"]
    lockfile = os.path.join(TEMP_DIR, "lockfile")
    with open(lockfile, "w") as lf:
        try:
            if sys.platform == "win32":
                msvcrt.locking(lf.fileno(), msvcrt.LK_RLCK, 1)
                yield
            else:
                fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
                yield
        finally:
            if sys.platform == "win32":
                msvcrt.locking(lf.fileno(), msvcrt.LK_UNLCK, 1)
            else:
                fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
            lf.close()


@contextmanager
def _rank_temp_file():
    """Yield a temp-file path created on rank 0 and broadcast to all ranks.

    Rank 0 creates the file and removes it again when the context exits;
    the other ranks only receive the name through
    ``dist.broadcast_object_list``.
    """
    name = None
    if dist.get_rank() == 0:
        handle, name = tempfile.mkstemp()
        # Only the path is needed; release the descriptor right away.
        os.close(handle)
    payload = [name]
    dist.broadcast_object_list(payload)
    (name,) = payload
    try:
        yield name
    finally:
        if dist.get_rank() == 0:
            os.remove(name)


def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
    if value is None:
        value = size
    if device_id is None:
        return torch.empty(size, size, size, dtype=dtype).fill_(value)
    else:
        return torch.empty(size, size, size, dtype=dtype).fill_(value).cuda(device_id)


def _build_multidim_tensor(dim, dim_size, value=None, dtype=torch.float):
    if value is None:
        value = dim
    return torch.empty(size=[dim_size for _ in range(dim)], dtype=dtype).fill_(value)


def _create_autograd_profiler():
    return torch.autograd.profiler.profile(record_shapes=True)


def _create_torch_profiler():
    return torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
        ],
        record_shapes=True,
    )


class Barrier:
    """File-based barrier kept under ``$TEMP_DIR/barrier``.

    Each process writes the current barrier id into a file named after its
    pid; ``sync`` returns once enough files hold an id at least that large.
    """

    barrier_id = 0

    @classmethod
    def init(cls):
        """Reset the barrier counter and remove every file in the barrier dir."""
        cls.barrier_id = 0
        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
        for entry in os.listdir(barrier_dir):
            os.unlink(os.path.join(barrier_dir, entry))

    @classmethod
    def sync(cls, wait_for=None, timeout=10):
        """Block until ``wait_for`` processes have reached this barrier.

        ``wait_for`` defaults to ``dist.get_world_size()``. Raises
        ``RuntimeError("barrier timeout")`` once more than ``timeout``
        seconds have elapsed without everyone arriving.
        """
        if wait_for is None:
            wait_for = dist.get_world_size()
        cls.barrier_id += 1
        barrier_dir = os.path.join(os.environ["TEMP_DIR"], "barrier")
        own_marker = os.path.join(barrier_dir, str(os.getpid()))
        # Announce arrival by recording the barrier id under our pid.
        with _lock(), open(own_marker, "w") as marker:
            marker.write(str(cls.barrier_id))

        began = time.time()
        while True:
            with _lock():
                arrived = 0
                for entry in os.listdir(barrier_dir):
                    with open(os.path.join(barrier_dir, entry)) as marker:
                        recorded_id = int(marker.read())
                    if recorded_id >= cls.barrier_id:
                        arrived += 1
            if arrived == wait_for:
                return

            if time.time() - began > timeout:
                raise RuntimeError("barrier timeout")
            time.sleep(0.1)


class TestDistBackend(MultiProcessTestCase):
    """Multi-process test case that sets up a process group for each test.

    Every spawned process initializes a process group with the backend named
    by ``BACKEND``, runs the requested test between two file-based barriers,
    destroys the process group, and exits.
    """

    @classmethod
    def setUpClass(cls):
        os.environ["MASTER_ADDR"] = str(MASTER_ADDR)
        # MASTER_PORT is deliberately left unset so that a random free port
        # is picked.
        super().setUpClass()

    def setUp(self):
        super().setUp()
        # initialize temp directories
        initialize_temp_directories()
        # initialize Barrier
        Barrier.init()
        # Skip return code checking for following tests as they are expected to
        # crash a process due to TORCH_NCCL_ASYNC_ERROR_HANDLING.
        self.skip_return_code_checks = [self.test_ddp_has_finalized.__wrapped__]

    def tearDown(self):
        cleanup_temp_dir()
        super().tearDown()

    @property
    def init_method(self):
        """File-based init method URL built from this test's ``file_name``."""
        return f"{FILE_SCHEMA}{self.file_name}"

    @property
    def destroy_pg_upon_exit(self) -> bool:
        # Overriding base test class: do not auto destroy PG upon exit.
        return False

    @classmethod
    def _run(cls, rank, test_name, file_name, pipe, **kwargs):
        """Per-process entry point: init the process group, run one test, exit.

        Exits early with the matching ``TEST_SKIPS`` code when NCCL is
        requested without CUDA, when there are fewer GPUs than ranks, or when
        the backend reports that a recompile is needed.
        """
        if BACKEND == "nccl" and not torch.cuda.is_available():
            sys.exit(TEST_SKIPS["no_cuda"].exit_code)
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name

        if torch.cuda.is_available() and torch.cuda.device_count() < int(
            self.world_size
        ):
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        try:
            # Per-test process-group timeout, falling back to the default.
            pg_timeout_seconds = CUSTOM_PG_TIMEOUT.get(test_name, default_pg_timeout)
            timeout = timedelta(seconds=pg_timeout_seconds)
            dist.init_process_group(
                init_method=self.init_method,
                backend=BACKEND,
                world_size=int(self.world_size),
                rank=self.rank,
                timeout=timeout,
            )
        except RuntimeError as e:
            # A "recompile" message is treated as the backend being
            # unavailable; any other RuntimeError is re-raised.
            if "recompile" in e.args[0]:
                sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)

            raise

        # Execute barrier prior to running test to ensure that every process
        # has finished initialization and that the following test
        # immediately exiting due to a skip doesn't cause flakiness.
        self._barrier()

        self.run_test(test_name, pipe)
        self._barrier()
        dist.destroy_process_group()
        sys.exit(0)

    # Needed since MultiProcessTestCase assumes a world_size of 4, but we
    # run these tests under other various world_sizes.
    # Note: the value is returned as the raw environment string, so callers
    # convert it with int() where a number is needed.
    @property
    def world_size(self):
        return os.environ["WORLD_SIZE"]


class DistributedTest:
    class _DistTestBase:
        def _barrier(self, *args, **kwargs):
            """Wait on the file-based barrier; arguments are forwarded to ``Barrier.sync``."""
            Barrier.sync(*args, **kwargs)

        def _init_group_test(self, **kwargs):
            """Create a group of ranks 1 and 2 and return ``(group, group_id, rank)``.

            Ranks outside the group get ``([], None, rank)``. Extra keyword
            arguments are forwarded to ``dist.new_group``.
            """
            members = [1, 2]
            group_id = dist.new_group(members, **kwargs)
            rank = dist.get_rank()
            if rank in members:
                return (members, group_id, rank)
            return ([], None, rank)

        def _init_full_group_test(self, **kwargs):
            """Create a new group over all ranks and return ``(group, group_id, rank)``.

            Keyword arguments are forwarded to ``dist.new_group``.
            """
            all_ranks = list(range(dist.get_world_size()))
            full_group = dist.new_group(**kwargs)
            return (all_ranks, full_group, dist.get_rank())

        def _init_global_test(self):
            """Return ``(all ranks, dist.group.WORLD, this rank)`` for the default group."""
            all_ranks = list(range(dist.get_world_size()))
            return (all_ranks, dist.group.WORLD, dist.get_rank())

        def _verify_buffers_equal(self, m1, m2):
            """Assert that buffers match between two wrapped models and across ranks.

            First compares ``m2.module``'s named buffers against
            ``m1.module``'s, then all-gathers each buffer pair and checks that
            every rank holds the same values as the local rank.
            """
            # verify buffers across models
            reference = dict(m1.module.named_buffers())
            for buf_name, m2_buf in m2.module.named_buffers():
                self.assertEqual(m2_buf, reference[buf_name])

            # Verify buffers across ranks.
            buffer_pairs = zip(list(m1.buffers()), list(m2.buffers()))
            for local_m1, local_m2 in buffer_pairs:
                from_all_m1 = [
                    torch.empty_like(local_m1) for _ in range(dist.get_world_size())
                ]
                dist.all_gather(from_all_m1, local_m1)
                from_all_m2 = [
                    torch.empty_like(local_m2) for _ in range(dist.get_world_size())
                ]
                for received in from_all_m1:
                    self.assertEqual(received, local_m1)
                dist.all_gather(from_all_m2, local_m2)
                for received in from_all_m2:
                    self.assertEqual(received, local_m2)

        def _sanity_check_profiler_nccl_meta(self, nccl_meta_events):
            """Torch profiler includes nccl metadata in an inserted operator called "record_param_comms"
            We test for basic fields in this profiler event that correspond to the nccl communication
            collectives"""
            per_coll_meta = defaultdict(list)
            for e in nccl_meta_events:
                args = e.get("args", {})
                collname = args.get("Collective name", "")
                self.assertNotEqual(collname, "")
                self.assertNotEqual(args.get("dtype", ""), "")

                per_coll_meta[collname].append(args)
                if collname in {"wait"}:
                    continue

                self.assertEqual(args["Process Group Description"], "default_pg")
                self.assertNotEqual(args["Process Group Ranks"], "")

                self.assertGreaterEqual(args.get("In msg nelems", -1), 0)
                self.assertGreaterEqual(args.get("Out msg nelems", -1), 0)
                self.assertGreaterEqual(args.get("Group size", -1), 0)
                self.assertGreaterEqual(args.get("Global rank start", -1), 0)
                self.assertGreaterEqual(args.get("Global rank stride", -1), 0)

            # print(per_coll_meta)
            return per_coll_meta

        def test_dump_DDP_relevant_env_vars(self):
            """Check which env vars ``_dump_DDP_relevant_env_vars`` prints.

            Each listed DDP variable must appear as ``env:NAME=VALUE`` (with
            ``N/A`` for unset ones); unrelated names must not appear.
            """
            with captured_output() as (out, _):
                _dump_DDP_relevant_env_vars()
                lines = out.getvalue().splitlines()

            def format_line(var):
                return f"env:{var}={os.environ.get(var, 'N/A')}"

            # Check relevant env vars
            relevant_vars = [
                "MASTER_ADDR",
                "MASTER_PORT",
                "WORLD_SIZE",
                "NCCL_TOPO_DUMP_FILE",  # N/A
                "TORCH_NCCL_ASYNC_ERROR_HANDLING",
            ]
            for var in relevant_vars:
                self.assertIn(format_line(var), lines)

            # Check irrelevant env vars
            irrelevant_vars = ["xxx", "yyy", "zzz"]
            for var in irrelevant_vars:
                self.assertNotIn(format_line(var), lines)

        # GET RANK
        def test_get_rank(self):
            """Check that every process reports a distinct rank.

            Each process writes its rank to a file named after its pid; after
            a barrier, the number of distinct ranks found must equal the
            world size. Rank 0 cleans the files up afterwards.
            """
            rank_dir = os.path.join(os.environ["TEMP_DIR"], "test_dir")
            own_file = os.path.join(rank_dir, str(os.getpid()))
            expected_count = dist.get_world_size()
            with open(own_file, "w") as out:
                out.write(str(dist.get_rank()))

            self._barrier()

            seen_ranks = set()
            for entry in os.listdir(rank_dir):
                with open(os.path.join(rank_dir, entry)) as src:
                    seen_ranks.add(int(src.read()))
            self.assertEqual(len(seen_ranks), expected_count)

            self._barrier()

            # Only rank 0 removes the files, after everyone has read them.
            if dist.get_rank() == 0:
                for entry in os.listdir(rank_dir):
                    os.unlink(os.path.join(rank_dir, entry))

            self._barrier()

        def test_get_backend(self):
            """Check ``dist.get_backend`` for the default group and a two-rank subgroup.

            Ranks outside the subgroup must get a ``ValueError`` when querying it.
            """
            members = [1, 2] if dist.get_world_size() > 2 else [0, 1]
            group_id = dist.new_group(members)
            expected_backend = BACKEND.lower()
            self.assertEqual(dist.get_backend(), expected_backend)
            if dist.get_rank() not in members:
                with self.assertRaisesRegex(
                    ValueError, "Invalid process group specified"
                ):
                    dist.get_backend(group_id)
            else:
                self.assertEqual(dist.get_backend(group_id), expected_backend)

        def test_Backend_enum_class(self):
            """Check ``dist.Backend`` parsing of valid names and rejection of non-strings."""
            # test parsing
            expected = BACKEND.lower()
            self.assertEqual(dist.Backend(BACKEND.upper()), expected)
            self.assertEqual(dist.Backend(BACKEND), expected)
            # Non-string inputs must be rejected.
            for invalid in (None, 3, ["gloo"]):
                with self.assertRaises(ValueError):
                    dist.Backend(invalid)

        # Test destroy
        def test_destroy_group(self):
            """Create a two-rank group, synchronize, then destroy the group."""
            members = [1, 2] if dist.get_world_size() > 2 else [0, 1]
            group_id = dist.new_group(members)
            self._barrier()
            dist.destroy_process_group(group_id)

        # Test get rank and size of group
        def test_get_rank_size_group(self):
            """Check group size and rank for members and non-members of a subgroup.

            Members see a size of 2 and a rank of 0 or 1; non-members see -1
            for both.
            """
            members = [1, 2] if dist.get_world_size() > 2 else [0, 1]
            group_id = dist.new_group(members)
            if dist.get_rank() not in members:
                self.assertEqual(dist.get_world_size(group_id), -1)
                self.assertEqual(dist.get_rank(group_id), -1)
                return
            self.assertEqual(dist.get_world_size(group_id), 2)
            self.assertTrue(dist.get_rank(group_id) in [0, 1])

        # Test destroy full groups
        def test_destroy_full_group(self):
            """Create a group over all ranks, synchronize, then destroy it."""
            full_group = self._init_full_group_test()[1]
            self._barrier()
            dist.destroy_process_group(full_group)

        # Test get rank and size of full group
        def test_get_rank_size_full_group(self):
            """A new group over all ranks must report the default group's size and rank."""
            full_group = self._init_full_group_test()[1]
            self.assertEqual(dist.get_world_size(full_group), dist.get_world_size())
            self.assertEqual(dist.get_rank(full_group), dist.get_rank())

        def _test_barrier_timeout(self, group_id, timeout):
            """Make the group's rank 0 time out in ``dist.barrier`` and check the delay.

            Only the group's rank 0 calls the barrier; since nobody else
            joins, it must raise, and not earlier than ``timeout`` (within a
            0.1 s tolerance).
            """
            # Only execute barrier on rank == 0, causing it to timeout
            if dist.get_rank(group_id) != 0:
                return

            expected_time = time.time() + timeout.total_seconds()
            # In debug mode, we execute a monitored_barrier before the
            # collective, so assert on that.
            if dist.get_debug_level() == dist.DebugLevel.DETAIL:
                error_regex = "failed to pass monitoredBarrier"
            else:
                error_regex = " (Timed out|closed|timeout) "
            with self.assertRaisesRegex(Exception, error_regex):
                dist.barrier(group_id)
            self.assertGreaterAlmostEqual(time.time(), expected_time, delta=0.1)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports timeouts"
        )
        @skip_but_pass_in_sandcastle_if(
            not INIT_METHOD.startswith("file://"),
            "Requires file:// initialization method. "
            "Both tcp:// and env:// rely on the TCP store for which "
            "reinitialization has proven racy.",
        )
        def test_barrier_timeout_global(self):
            """Re-create the default group with a 1 s timeout and check barrier timeout."""
            dist.destroy_process_group()

            # Explicitly pass world size to the barrier because we've
            # just destroyed any state in torch.distributed.
            world_size = int(os.environ["WORLD_SIZE"])
            self._barrier(wait_for=world_size)

            # Reinitialize global process group
            short_timeout = timedelta(seconds=1)
            dist.init_process_group(
                init_method=INIT_METHOD,
                backend=BACKEND,
                world_size=int(os.environ["WORLD_SIZE"]),
                rank=self.rank,
                timeout=short_timeout,
            )
            self._test_barrier_timeout(dist.group.WORLD, short_timeout)

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports timeouts"
        )
        def test_barrier_timeout_group(self):
            """Check barrier timeout on the [1, 2] subgroup with a 5 s timeout."""
            group_timeout = timedelta(seconds=5)
            group_id = self._init_group_test(timeout=group_timeout)[1]
            # Ranks outside the subgroup receive None and have nothing to test.
            if group_id is None:
                return
            self._test_barrier_timeout(group_id, group_timeout)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports timeouts"
        )
        def test_barrier_timeout_full_group(self):
            """Check barrier timeout on a new all-ranks group with a 1 s timeout."""
            group_timeout = timedelta(seconds=1)
            group_id = self._init_full_group_test(timeout=group_timeout)[1]
            if group_id is None:
                return
            self._test_barrier_timeout(group_id, group_timeout)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(2)
        def test_new_subgroups(self):
            """Split the world into subgroups of size 2 and check their shape.

            The current rank must belong to its own subgroup, and the number
            of subgroups must be ``world_size / 2``.
            """
            subgroup_size = 2
            own_subgroup, all_subgroups = dist.new_subgroups(subgroup_size)

            expected_count = dist.get_world_size() / subgroup_size
            self.assertEqual(own_subgroup.size(), subgroup_size)
            self.assertEqual(len(all_subgroups), expected_count)
            self.assertFalse(dist._rank_not_in_group(own_subgroup))

            for subgroup in all_subgroups:
                dist.destroy_process_group(subgroup)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_no_gpu
        def test_new_subgroups_group_size_exceeds_world_size(self):
            """A subgroup size of 100 must be rejected with a ValueError."""
            oversized_group = 100
            with self.assertRaisesRegex(ValueError, "must not exceed"):
                dist.new_subgroups(oversized_group)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_world_size_not_divisible_by_group_size(self):
            """A subgroup size of 3 must be rejected with the divisibility error."""
            expected_error = "The world size must be divisible by 'group_size'"
            with self.assertRaisesRegex(ValueError, expected_error):
                dist.new_subgroups(3)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_by_enumeration(self):
            """Create subgroups [0, 2] and [1, 3] and check which one this rank gets.

            Membership is checked through the rank's first mapped GPU id:
            ids 0 and 2 map to the first subgroup, other ids below 4 to the
            second, and ids of 4 or more must have no subgroup at all.
            """
            _group, _group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device_id = rank_to_GPU[rank][0]
            own_subgroup, all_subgroups = dist.new_subgroups_by_enumeration(
                ranks_per_subgroup_list=[[0, 2], [1, 3]]
            )
            if device_id < 4:
                self.assertEqual(own_subgroup.size(), 2)
                self.assertEqual(len(all_subgroups), 2)
                expected_index = 0 if device_id in (0, 2) else 1
                self.assertEqual(own_subgroup, all_subgroups[expected_index])
            else:
                self.assertIsNone(own_subgroup)

            for subgroup in all_subgroups:
                dist.destroy_process_group(subgroup)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self):
            """An enumeration containing a rank equal to the world size must be rejected."""
            _, group_id, _ = self._init_global_test()
            init_multigpu_helper(dist.get_world_size(), BACKEND)
            out_of_range_rank = get_world_size(group_id)
            bad_enumeration = [[0, 1], [out_of_range_rank, 2]]
            expected_message = (
                "The new group's rank should be within the world_size set by init_process_group"
            )

            with self.assertRaisesRegex(ValueError, expected_message):
                dist.new_subgroups_by_enumeration(
                    ranks_per_subgroup_list=bad_enumeration
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_no_gpu
        def test_new_subgroups_by_enumeration_negative_input_rank(self):
            """An enumeration made of negative ranks must be rejected with a ``ValueError``."""
            self._init_global_test()

            negative_enumeration = [[-1, -2], [-3, -4]]
            expected_message = (
                "The new group's rank should be within the world_size set by init_process_group"
            )
            with self.assertRaisesRegex(ValueError, expected_message):
                dist.new_subgroups_by_enumeration(
                    ranks_per_subgroup_list=negative_enumeration
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_new_subgroups_overlap_not_allowed(self):
            """Listing rank 1 in two different subgroups must raise a ``ValueError``."""
            overlapping_enumeration = [[0], [1, 2], [1, 3]]
            expected_message = "Rank 1 has appeared in both subgroup"
            with self.assertRaisesRegex(ValueError, expected_message):
                dist.new_subgroups_by_enumeration(
                    ranks_per_subgroup_list=overlapping_enumeration
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_lt_x_gpu(2)
        def test_average_parameters(self):
            """Check ``average_parameters`` on the default group and on a two-rank subgroup."""
            rank = dist.get_rank()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            gpu_index = gpu_map[rank][0]

            layers = [
                nn.Conv2d(3, 3, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.Linear(1, 5, bias=False),
            ]
            model = nn.Sequential(*layers).cuda(gpu_index)

            # Global averaging: all ranks start from all-ones parameters, so the
            # average must leave every element unchanged.
            for weight in model.parameters():
                weight.data = torch.ones_like(weight.data)
            model_averaging_utils.average_parameters(
                params=model.parameters(), process_group=None
            )
            for weight in model.parameters():
                self.assertEqual(weight.data, torch.ones_like(weight.data))

            # Partial averaging: each rank fills its parameters with its own rank,
            # then only ranks 0 and 1 average with each other.
            for weight in model.parameters():
                weight.data = torch.ones_like(weight.data) * rank
            subgroup = dist.new_group(ranks=[0, 1], backend="nccl")
            model_averaging_utils.average_parameters(
                params=model.parameters(), process_group=subgroup
            )
            # Subgroup members expect the mean of 0 and 1 (0.5); ranks outside the
            # subgroup expect their parameters to be untouched.
            expected_scale = rank if dist._rank_not_in_group(subgroup) else 0.5
            for weight in model.parameters():
                self.assertEqual(
                    weight.data, torch.ones_like(weight.data) * expected_scale
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_lt_x_gpu(2)
        def test_periodic_model_averager(self):
            """Check when ``PeriodicModelAverager`` averages for several warm-up lengths.

            Parameters are reset to the rank value before every step, so after
            each call they must equal either the cross-rank mean (on an averaging
            step) or the untouched rank value.
            """
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(gpu_index)
            weight = next(model.parameters())
            initial_value = torch.ones_like(weight.data) * rank
            averaged_value = (
                torch.ones_like(weight.data) * sum(range(world_size)) / world_size
            )
            period = 4
            for warmup_steps in (12, 13, 14, 15):
                averager = averagers.PeriodicModelAverager(
                    period=period, warmup_steps=warmup_steps
                )
                for step in range(20):
                    # Start every step from the same per-rank parameter value.
                    weight.data = copy.deepcopy(initial_value)
                    for p in model.parameters():
                        # Give each parameter a placeholder gradient.
                        p.grad = torch.ones_like(weight.data)
                    averager.average_parameters(model.parameters())
                    averaging_expected = (
                        step >= warmup_steps and (step - warmup_steps) % period == 0
                    )
                    # Without averaging the parameters must be left as they were.
                    expected = averaged_value if averaging_expected else initial_value
                    self.assertEqual(weight.data, expected)

        @skip_if_lt_x_gpu(2)
        def test_periodic_model_averager_param_group(self):
            """Check ``PeriodicModelAverager`` when it is fed optimizer param groups."""
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(gpu_index)
            param = next(model.parameters())
            opt = torch.optim.SGD(model.parameters(), lr=0.1)

            period = 4
            for warmup_steps in (12, 13, 14, 15):
                averager = averagers.PeriodicModelAverager(
                    period=period, warmup_steps=warmup_steps
                )
                for step in range(20):
                    # Reset every parameter (and a placeholder gradient) to the rank value.
                    for group in opt.param_groups:
                        for p in group["params"]:
                            p.grad = torch.ones_like(param.data) * rank
                            p.data = torch.ones_like(param.data) * rank
                    averager.average_parameters(opt.param_groups)

                    averaging_expected = (
                        step >= warmup_steps and (step - warmup_steps) % period == 0
                    )
                    if averaging_expected:
                        expected = (
                            torch.ones_like(param.data)
                            * sum(range(world_size))
                            / world_size
                        )
                    else:
                        # No model averaging, so the parameters are not updated.
                        expected = torch.ones_like(param.data) * rank

                    for group in opt.param_groups:
                        for p in group["params"]:
                            if p.grad is None:
                                continue
                            self.assertEqual(param.data, expected)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @skip_if_lt_x_gpu(2)
        def test_1_level_hierarchical_model_averager_equivalent_to_periodic_model_averager(
            self,
        ):
            """Run the periodic-averaging schedule check with a one-level hierarchy configured.

            Parameters are reset to the rank value before every step and compared
            against either the cross-rank mean or the untouched rank value.
            """
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(gpu_index)
            weight = next(model.parameters())
            initial_value = torch.ones_like(weight.data) * rank
            averaged_value = (
                torch.ones_like(weight.data) * sum(range(world_size)) / world_size
            )
            period = 4
            for warmup_steps in (12, 13, 14, 15):
                # A single level that averages over the whole world every `period` steps.
                # NOTE(review): this hierarchical averager is replaced by the periodic
                # averager on the next statement, so the loop below only exercises
                # ``PeriodicModelAverager``. Confirm which averager was meant to be
                # under test before changing this.
                averager = hierarchicalSGD.HierarchicalModelAverager(
                    period_group_size_dict=OrderedDict([(period, world_size)]),
                    warmup_steps=warmup_steps,
                )

                averager = averagers.PeriodicModelAverager(
                    period=period, warmup_steps=warmup_steps
                )
                for step in range(20):
                    # Start every step from the same per-rank parameter value.
                    weight.data = copy.deepcopy(initial_value)
                    for p in model.parameters():
                        # Give each parameter a placeholder gradient.
                        p.grad = torch.ones_like(weight.data)
                    averager.average_parameters(model.parameters())
                    averaging_expected = (
                        step >= warmup_steps and (step - warmup_steps) % period == 0
                    )
                    # Without averaging the parameters must be left as they were.
                    expected = averaged_value if averaging_expected else initial_value
                    self.assertEqual(weight.data, expected)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["subgroup"],
            f"The {BACKEND} backend does not support creating subgroups on CUDA devices",
        )
        @require_world_size(4)
        @skip_if_lt_x_gpu(4)
        def test_3_level_hierarchical_model_averager(self):
            """Check group layout and averaging schedule of a three-level hierarchy.

            Verifies the ranks of the two subgroup levels and then, step by step,
            which level's average (if any) ends up in the parameters.
            """
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]

            model = nn.Linear(1, 5, bias=False).cuda(gpu_index)
            weight = next(model.parameters())
            initial_value = torch.ones_like(weight.data) * rank

            # Hierarchy under test, active after the first 10 warmup steps:
            #   - average every 2 steps within each subgroup of size 2,
            #   - average every 4 steps within each subgroup of size 4,
            #   - average globally every 8 steps.
            # When several levels coincide on a step, only the highest level is
            # expected to run.
            warmup_steps = 10
            subgroup_size1, subgroup_avg_period1 = 2, 2
            subgroup_size2, subgroup_avg_period2 = 4, 4
            global_avg_period = 8
            period_group_size_dict = OrderedDict(
                [
                    (subgroup_avg_period1, subgroup_size1),
                    (subgroup_avg_period2, subgroup_size2),
                    (global_avg_period, world_size),
                ]
            )
            averager = hierarchicalSGD.HierarchicalModelAverager(
                period_group_size_dict=period_group_size_dict, warmup_steps=warmup_steps
            )
            self.assertEqual(dist.get_pg_count(), len(period_group_size_dict))

            subgroup1 = averager.period_process_group_dict[subgroup_avg_period1]
            subgroup2 = averager.period_process_group_dict[subgroup_avg_period2]
            actual_ranks1 = _get_pg_config(subgroup1)["ranks"]
            actual_ranks2 = _get_pg_config(subgroup2)["ranks"]

            # Each subgroup is expected to be the block of consecutive ranks
            # containing the current rank.
            first_rank1 = rank // subgroup_size1 * subgroup_size1
            first_rank2 = rank // subgroup_size2 * subgroup_size2
            expected_ranks1 = list(range(first_rank1, first_rank1 + subgroup_size1))
            expected_ranks2 = list(range(first_rank2, first_rank2 + subgroup_size2))
            self.assertEqual(actual_ranks1, expected_ranks1)
            self.assertEqual(actual_ranks2, expected_ranks2)

            ones = torch.ones_like(weight.data)
            subgroup1_average = ones * sum(actual_ranks1) / subgroup_size1
            subgroup2_average = ones * sum(actual_ranks2) / subgroup_size2
            global_average = ones * sum(range(world_size)) / world_size

            # Post-warmup steps at which each level is expected to win.
            global_steps = (16, 24)  # divisible by 8
            subgroup2_steps = (12, 20)  # divisible by 4 but not by 8
            subgroup1_steps = (10, 14, 18, 22)  # divisible by 2 but not by 4 or 8

            for step in range(25):
                # Start every step from the same per-rank parameter value.
                weight.data = copy.deepcopy(initial_value)
                for p in model.parameters():
                    # Give each parameter a placeholder gradient.
                    p.grad = torch.ones_like(weight.data)
                averager.average_parameters(model.parameters())

                if step in global_steps:
                    expected = global_average
                elif step in subgroup2_steps:
                    expected = subgroup2_average
                elif step in subgroup1_steps:
                    expected = subgroup1_average
                else:
                    # No model averaging, so the parameters are not updated.
                    expected = initial_value
                self.assertEqual(weight.data, expected)

        # Coalescing manager (sync mode)
        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
            "Coalescing manager currently tests with NCCL only; internal test flaky"
        )
        def test_coalescing_manager(self):
            """All-reduces issued under the coalescing manager must match one big all-reduce."""
            self._barrier()
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]
            torch.cuda.set_device(gpu_index)
            num_colls = 2
            size_per_coll = 8
            small_tensors = [
                torch.ones(size_per_coll, device=gpu_index) for _ in range(num_colls)
            ]

            with dist._coalescing_manager():
                for small in small_tensors:
                    dist.all_reduce(small)

            # Reference result: one all-reduce over a tensor as large as all the
            # small tensors combined.
            big_tensor = torch.ones(num_colls * size_per_coll, device=gpu_index)
            dist.all_reduce(big_tensor)

            for index, small in enumerate(small_tensors):
                start = index * size_per_coll
                self.assertEqual(small, big_tensor[start : start + size_per_coll])

            self._barrier()

        # Coalescing manager (async mode)
        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl" or IS_FBCODE or IS_SANDCASTLE,
            "Coalescing manager currently tests with NCCL only; internal test flaky"
        )
        def test_coalescing_manager_async(self):
            """Async coalesced all-reduces, once waited on, must match one big all-reduce."""
            self._barrier()
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]
            torch.cuda.set_device(gpu_index)
            num_colls = 2
            size_per_coll = 8
            small_tensors = [
                torch.ones(size_per_coll, device=gpu_index) for _ in range(num_colls)
            ]

            with dist._coalescing_manager(async_ops=True) as cm:
                for small in small_tensors:
                    dist.all_reduce(small)
            # Wait on the coalesced work before reading the results.
            cm.wait()

            # Reference result: one all-reduce over a tensor as large as all the
            # small tensors combined.
            big_tensor = torch.ones(num_colls * size_per_coll, device=gpu_index)
            dist.all_reduce(big_tensor)

            for index, small in enumerate(small_tensors):
                start = index * size_per_coll
                self.assertEqual(small, big_tensor[start : start + size_per_coll])

            self._barrier()

        # NCCL Batch SEND RECV
        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_batch_isend_irecv_nccl(self):
            """Exchange tensors between every pair of ranks via ``batch_isend_irecv``.

            The exchange is run twice, once with ``TORCH_NCCL_BLOCKING_WAIT`` set
            to ``"1"`` and once with ``"0"``. In each round every rank sends a
            tensor filled with the destination rank to every rank (itself
            included) and expects to receive tensors filled with its own rank.
            """
            self._barrier()
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(world_size, BACKEND)
            device_id = rank_to_GPU[rank][0]
            torch.cuda.set_device(device_id)
            recv_tensors = [None for _ in range(world_size)]
            expected_tensors = [None for _ in range(world_size)]

            for val in ["1", "0"]:
                os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val
                # Build a fresh op list for every round. Reusing one list across
                # rounds would re-submit the previous round's ops, including
                # receives into tensors that `recv_tensors` no longer references.
                p2p_op_list = []
                for src in range(0, world_size):
                    # Payload is filled with the destination rank so the receiver
                    # can check it against its own rank.
                    send_tensor = _build_tensor(rank + 1, device_id=device_id).fill_(
                        src
                    )
                    recv_tensors[src] = _build_tensor(
                        src + 1, value=-1, device_id=device_id
                    ).fill_(-1)
                    expected_tensors[src] = _build_tensor(
                        src + 1, value=-1, device_id=device_id
                    ).fill_(rank)
                    recv_op = dist.P2POp(dist.irecv, recv_tensors[src], src)
                    p2p_op_list.append(recv_op)
                    send_op = dist.P2POp(dist.isend, send_tensor, src)
                    p2p_op_list.append(send_op)

                reqs = dist.batch_isend_irecv(p2p_op_list)
                for req in reqs:
                    req.wait()

                for src in range(0, world_size):
                    self.assertEqual(recv_tensors[src], expected_tensors[src])

            self._barrier()

        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_batch_isend_irecv_ring_exchange_nccl(self):
            """Each rank sends to ``rank + 1`` and receives from ``rank - 1`` (mod world size)."""
            self._barrier()
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]
            torch.cuda.set_device(gpu_index)

            outgoing = _build_tensor(world_size, device_id=gpu_index)
            incoming = _build_tensor(world_size, value=-1, device_id=gpu_index)
            next_rank = (rank + 1) % world_size
            prev_rank = (rank - 1 + world_size) % world_size
            ring_ops = [
                dist.P2POp(dist.isend, outgoing, next_rank),
                dist.P2POp(dist.irecv, incoming, prev_rank),
            ]
            for work in dist.batch_isend_irecv(ring_ops):
                work.wait()

            self._barrier()

        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_batch_isend_irecv_self_nccl(self):
            """Rank 0 batches a receive from and a send to itself; other ranks only synchronize."""
            self._barrier()
            # Ensure the process group has been fully initialized (needed by
            # the first sub-group batch_isend_irecv call)
            dist.barrier()
            rank = dist.get_rank()
            gpu_index = init_multigpu_helper(dist.get_world_size(), BACKEND)[rank][0]

            if rank == 0:
                outgoing = _build_tensor(rank + 1, device_id=gpu_index)
                incoming = _build_tensor(rank + 1, value=-1, device_id=gpu_index)
                self_ops = [
                    dist.P2POp(dist.irecv, incoming, 0),
                    dist.P2POp(dist.isend, outgoing, 0),
                ]
                for work in dist.batch_isend_irecv(self_ops):
                    work.wait()

            self._barrier()

        @skip_if_no_gpu
        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_batch_isend_irecv_no_rank_zero_nccl(self):
            """Ranks 1 and 2 exchange tensors with each other; all other ranks only synchronize."""
            self._barrier()
            # Ensure the process group has been fully initialized (needed by
            # the first sub-group batch_isend_irecv call)
            dist.barrier()
            rank = dist.get_rank()
            gpu_index = init_multigpu_helper(dist.get_world_size(), BACKEND)[rank][0]
            torch.cuda.set_device(gpu_index)

            # Only these two ranks take part, each paired with the other.
            peer_of = {1: 2, 2: 1}
            if rank in peer_of:
                peer = peer_of[rank]
                outgoing = _build_tensor(rank + 1, device_id=gpu_index)
                incoming = _build_tensor(peer + 1, value=-1, device_id=gpu_index)
                pair_ops = [
                    dist.P2POp(dist.irecv, incoming, peer),
                    dist.P2POp(dist.isend, outgoing, peer),
                ]
                for work in dist.batch_isend_irecv(pair_ops):
                    work.wait()

            self._barrier()

        # GLOO Batch SEND RECV CPU
        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
        def test_batch_isend_irecv_gloo(self):
            """Batch a receive and a send for every other rank and wait for all of them."""
            self._barrier()
            rank = dist.get_rank()
            batched_ops = []

            other_ranks = (r for r in range(dist.get_world_size()) if r != rank)
            for peer in other_ranks:
                outgoing = _build_tensor(rank + 1)
                incoming = _build_tensor(peer + 1, value=-1)
                batched_ops.append(dist.P2POp(dist.irecv, incoming, peer))
                batched_ops.append(dist.P2POp(dist.isend, outgoing, peer))

            for work in dist.batch_isend_irecv(batched_ops):
                work.wait()

            self._barrier()

        # GLOO Batch SEND RECV CPU with provided tags
        @skip_but_pass_in_sandcastle_if(BACKEND != "gloo", "GLOO Batch Send Recv CPU")
        def test_batch_isend_irecv_gloo_tags(self):
            """Same all-to-all exchange as the untagged case, with explicit tags.

            Every message is tagged with the sending rank: receives use the
            peer's rank as tag and sends use this rank.
            """
            self._barrier()
            rank = dist.get_rank()
            batched_ops = []

            other_ranks = (r for r in range(dist.get_world_size()) if r != rank)
            for peer in other_ranks:
                outgoing = _build_tensor(rank + 1)
                incoming = _build_tensor(peer + 1, value=-1)
                batched_ops.append(dist.P2POp(dist.irecv, incoming, peer, tag=peer))
                batched_ops.append(dist.P2POp(dist.isend, outgoing, peer, tag=rank))

            for work in dist.batch_isend_irecv(batched_ops):
                work.wait()

            self._barrier()

        # NCCL Batch SEND RECV Op Error
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_batch_isend_irecv_op_err(self):
            """On rank 0, a ``P2POp`` built around ``dist.broadcast`` must raise ``ValueError``."""
            self._barrier()
            rank = dist.get_rank()
            if rank != 0:
                return

            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            gpu_index = gpu_map[rank][0]
            with self.assertRaisesRegex(ValueError, "^Invalid ``op``"):
                payload = _build_tensor(rank + 1, device_id=gpu_index)
                invalid_op = dist.P2POp(dist.broadcast, payload, 1)
                dist.batch_isend_irecv([invalid_op])

        # NCCL Batch SEND RECV p2p_op_list Error
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_batch_isend_irecv_op_list_err(self):
            """On rank 0, a list of plain integers must be rejected with ``ValueError``."""
            self._barrier()
            if dist.get_rank() != 0:
                return

            not_p2p_ops = [1, 2]
            with self.assertRaisesRegex(ValueError, "^Invalid ``p2p_op_list``"):
                dist.batch_isend_irecv(not_p2p_ops)

        # NCCL Batch SEND RECV Mixed Backend Error
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Batch Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_batch_isend_irecv_mixed_backend_err(self):
            """On rank 0, batching ops that target two different groups must raise ``ValueError``."""
            self._barrier()
            rank = dist.get_rank()
            init_multigpu_helper(dist.get_world_size(), BACKEND)
            # Both groups are created on every rank, before the rank-0-only check.
            group_gloo = dist.new_group(ranks=[0, 1], backend="gloo")
            group_nccl = dist.new_group(ranks=[0, 1], backend="nccl")
            if rank != 0:
                return

            expected_message = "All ops need to use the same group"
            with self.assertRaisesRegex(ValueError, expected_message):
                payload = _build_tensor(rank + 1)
                mixed_ops = [
                    dist.P2POp(dist.isend, payload, 1, group_gloo),
                    dist.P2POp(dist.isend, payload, 1, group_nccl),
                ]
                dist.batch_isend_irecv(mixed_ops)

        # NCCL SEND RECV
        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def _test_send_recv_nccl(self, profiler_ctx=None):
            """Run a blocking send/recv round over GPU tensors, optionally under a profiler.

            Ranks take turns as sender: the sending rank sends its tensor to every
            other rank, and the others receive and compare against the expected
            tensor. If ``profiler_ctx`` is given and the backend is listed in
            ``SEND_RECV_PROFILING_SUPPORTED_BACKENDS``, the recorded send/recv
            events are checked as well.

            Args:
                profiler_ctx: Optional context manager whose ``as`` value is handed
                    to ``get_profiling_event``; ``None`` disables profiling checks.
            """
            # TODO: now that nccl send/recv is supported, there does not seem to
            # be a need to have nccl send/recv be tested separately.
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            gpu_index = init_multigpu_helper(world_size, BACKEND)[rank][0]
            torch.cuda.set_device(gpu_index)

            tensor = _build_tensor(rank + 1, device_id=gpu_index)
            active_ctx = nullcontext() if profiler_ctx is None else profiler_ctx
            with active_ctx as prof:
                for src in range(world_size):
                    if src != rank:
                        # Recv mode
                        expected_tensor = _build_tensor(src + 1)
                        output_tensor = _build_tensor(
                            src + 1, value=-1, device_id=gpu_index
                        )
                        dist.recv(output_tensor, src)
                        self.assertEqual(output_tensor, expected_tensor)
                        continue
                    # Send mode
                    for dst in range(world_size):
                        if dst != rank:
                            dist.send(tensor, dst)

                self._barrier()

            if profiler_ctx is None:
                return
            backend = dist.get_backend()
            if backend not in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
                return

            # Event order is not deterministic, so only require that each event's
            # input shape is one of the shapes some rank could have used.
            expected_shapes = [[[r + 1] * 3] for r in range(dist.get_world_size())]
            for event_name in (f"{backend}:send", f"{backend}:recv"):
                events = get_profiling_event(
                    event_name, prof, dedup_gpu_user_annotation=True
                )
                self.assertTrue(events)
                for event in events:
                    self.assertTrue(event.input_shapes in expected_shapes)


        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_send_recv_nccl(self):
            """Run the NCCL send/recv round without any profiler."""
            no_profiler = None
            self._test_send_recv_nccl(no_profiler)

        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        def test_send_recv_nccl_autograd_profiler(self):
            """Run the NCCL send/recv round under the autograd profiler with shape recording."""
            self._test_send_recv_nccl(
                torch.autograd.profiler.profile(record_shapes=True)
            )

        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(BACKEND != "nccl", "NCCL Send Recv Only")
        @requires_nccl_version((2, 7, 0), "Need NCCL 2.7+ for send/recv")
        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
        @skip_but_pass_in_sandcastle_if(
            IS_MACOS or IS_WINDOWS,
            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
        )
        def test_send_recv_nccl_torch_profiler(self):
            """Run the NCCL send/recv round under ``torch.profiler`` (CPU and CUDA activities)."""
            profiled_activities = [
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ]
            self._test_send_recv_nccl(
                torch.profiler.profile(
                    activities=profiled_activities,
                    record_shapes=True,
                )
            )

        # SEND RECV
        def _test_send_recv(self, profiler_ctx):
            """Run a blocking send/recv round between all ranks, optionally under a profiler.

            Ranks take turns as sender: the sending rank sends its tensor to every
            other rank, and the others receive and compare against the expected
            tensor. If ``profiler_ctx`` is given and the backend is listed in
            ``SEND_RECV_PROFILING_SUPPORTED_BACKENDS``, the count, async flag and
            input shapes of the recorded send/recv events are checked too.

            Args:
                profiler_ctx: Context manager whose ``as`` value is handed to
                    ``get_profiling_event``, or ``None`` to skip profiling checks.
            """
            rank = dist.get_rank()
            send_size = rank + 1
            tensor = _build_tensor(send_size)
            active_ctx = nullcontext() if profiler_ctx is None else profiler_ctx
            with active_ctx as prof:
                for src in range(dist.get_world_size()):
                    if src != rank:
                        # Recv mode
                        recv_size = src + 1
                        expected_tensor = _build_tensor(recv_size)
                        output_tensor = _build_tensor(recv_size, value=-1)
                        dist.recv(output_tensor, src)
                        self.assertEqual(output_tensor, expected_tensor)
                        continue
                    # Send mode
                    for dst in range(dist.get_world_size()):
                        if dst != rank:
                            dist.send(tensor, dst)

            if profiler_ctx is None:
                return
            backend = dist.get_backend()
            if backend not in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
                return

            # Each rank sends/recvs from all other ranks.
            expected_event_count = dist.get_world_size() - 1
            # Event order is not deterministic, so only require that each event's
            # input shape is one of the shapes some rank could have used.
            expected_shapes = [[[r + 1] * 3] for r in range(dist.get_world_size())]
            for event_name in (f"{backend}:send", f"{backend}:recv"):
                events = get_profiling_event(event_name, prof)
                event_count = sum(e.count for e in events)
                self.assertEqual(event_count, expected_event_count)
                for event in events:
                    self.assertTrue(event.is_async)
                    self.assertTrue(event.input_shapes in expected_shapes)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl send/recv tested by test_send_recv_nccl"
        )
        def test_send_recv(self):
            """Run the pairwise send/recv exchange with no profiler attached."""
            no_profiler = None
            self._test_send_recv(profiler_ctx=no_profiler)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
        )
        def test_send_recv_autograd_profiler(self):
            """Run the pairwise send/recv exchange under ``_create_autograd_profiler()``."""
            self._test_send_recv(profiler_ctx=_create_autograd_profiler())

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
        )
        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode causes hang")
        @skip_but_pass_in_sandcastle_if(
            IS_MACOS or IS_WINDOWS,
            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
        )
        def test_send_recv_torch_profiler(self):
            """Run the pairwise send/recv exchange under ``_create_torch_profiler()``."""
            return self._test_send_recv(profiler_ctx=_create_torch_profiler())

        # SEND RECV ANY SOURCE
        def _test_send_recv_any_source(self, profiler_ctx):
            """Check ``dist.recv`` / ``dist.irecv`` called without a source rank.

            Every rank takes one turn as the receiver. During its turn it posts,
            for each other rank, one ``recv`` and one ``irecv`` with no source
            argument, while every other rank sends it two tensors filled with the
            sender's own rank. The sender reported by the receive call is checked
            against the received payload.

            After the exchange the reported senders are all-gathered so that each
            rank can verify every rank was observed as a sender exactly
            ``2 * (world_size - 1)`` times. This global check and the closing
            barrier run whether or not a profiler is attached.

            Args:
                profiler_ctx: profiler context manager to run the exchange under, or
                    ``None`` to run unprofiled. When given, the send and
                    recvAnySource events are validated for backends listed in
                    ``SEND_RECV_PROFILING_SUPPORTED_BACKENDS``.
            """
            rank = dist.get_rank()
            world_size = dist.get_world_size()
            send_recv_size = 10
            tensor = _build_tensor(send_recv_size, value=rank)
            recv_ranks = []
            irecv_ranks = []

            ctx = profiler_ctx if profiler_ctx is not None else nullcontext()
            with ctx as prof:
                for dst in range(world_size):
                    if dst == rank:
                        # Recv mode: one (recv, irecv) pair per peer. The source is
                        # deliberately unspecified, so only the number of rounds
                        # matters here, not which peer each round belongs to.
                        for _ in range(world_size - 1):
                            for recv in ["recv", "irecv"]:
                                output_tensor = _build_tensor(send_recv_size, value=-1)

                                if recv == "recv":
                                    sender = dist.recv(output_tensor)
                                    recv_ranks.append(sender)
                                elif recv == "irecv":
                                    work = dist.irecv(output_tensor)
                                    work.wait()
                                    sender = work._source_rank()
                                    irecv_ranks.append(sender)

                                # Assert the scalar value "sender" that should be
                                # equal to the rank of the sender is equal to all
                                # values in the received tensor.
                                self.assertTrue(output_tensor.eq(sender).all())
                    else:
                        # Send mode
                        dist.send(tensor, dst)  # recv
                        dist.send(tensor, dst)  # irecv

            if profiler_ctx is not None:
                backend = dist.get_backend()
                if backend in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
                    for event_name in [f"{backend}:send", f"{backend}:recvAnySource"]:
                        events = get_profiling_event(event_name, prof)
                        # Each rank sends/recvs from other rank twice.
                        self.assertEqual(
                            sum(event.count for event in events),
                            2 * (world_size - 1),
                        )
                        for event in events:
                            self.assertTrue(event.is_async)
                            self.assertEqual(event.input_shapes, [[send_recv_size] * 3])

            # Each rank would have 2 * (world_size - 1) sends, verify that
            # globally we receive the same amount on the other end. This must not
            # depend on profiling being enabled.
            recv_ranks_tensor = torch.cat(
                (torch.tensor(recv_ranks), torch.tensor(irecv_ranks)), 0
            )
            global_recv_ranks = [
                torch.empty_like(recv_ranks_tensor) for _ in range(world_size)
            ]
            dist.all_gather(global_recv_ranks, recv_ranks_tensor)
            global_recv_ranks_list = []
            for gathered in global_recv_ranks:
                global_recv_ranks_list += gathered.tolist()

            # groupby only merges adjacent equal items, hence the sort first.
            global_recv_ranks_list.sort()
            frequency = [
                len(list(members))
                for _, members in itertools.groupby(global_recv_ranks_list)
            ]
            self.assertEqual(world_size, len(frequency))
            self.assertEqual([2 * (world_size - 1)] * world_size, frequency)
            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
            f"{BACKEND} does not support send/recv from any source",
        )
        def test_send_recv_any_source(self):
            """Run the any-source send/recv exchange with no profiler attached."""
            no_profiler = None
            self._test_send_recv_any_source(profiler_ctx=no_profiler)

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
            f"{BACKEND} does not support send/recv from any source",
        )
        def test_send_recv_any_source_autograd_profiler(self):
            """Run the any-source send/recv exchange under ``_create_autograd_profiler()``."""
            self._test_send_recv_any_source(profiler_ctx=_create_autograd_profiler())

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["sendrecv anysource"],
            f"{BACKEND} does not support send/recv from any source",
        )
        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
        @skip_but_pass_in_sandcastle_if(
            IS_MACOS or IS_WINDOWS,
            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
        )
        def test_send_recv_any_source_torch_profiler(self):
            """Run the any-source send/recv exchange under ``_create_torch_profiler()``."""
            return self._test_send_recv_any_source(
                profiler_ctx=_create_torch_profiler()
            )

        # SEND RECV WITH TAG
        def _test_send_recv_with_tag(self, profiler_ctx):
            """Exchange tagged tensors between every pair of ranks.

            Ranks take turns being the receiver. Every other rank sends it a tensor
            filled with the sender's rank and tagged with the sender's rank; the
            receiver asks for each peer explicitly, using the peer's rank as the
            tag, and checks the payload equals that rank.

            Args:
                profiler_ctx: profiler context manager to run the exchange under, or
                    ``None`` to run unprofiled. When given, the recorded send/recv
                    events are validated for backends listed in
                    ``SEND_RECV_PROFILING_SUPPORTED_BACKENDS``.
            """
            my_rank = dist.get_rank()
            num_ranks = dist.get_world_size()
            send_recv_size = 10
            payload = _build_tensor(send_recv_size, value=my_rank)
            active_ctx = nullcontext() if profiler_ctx is None else profiler_ctx
            with active_ctx as prof:
                for receiver in range(num_ranks):
                    if receiver != my_rank:
                        # Send mode: tag the message with our own rank.
                        dist.send(payload, receiver, tag=my_rank)
                        continue
                    # Recv mode: collect one message from each peer in rank order.
                    for peer in range(num_ranks):
                        if peer == my_rank:
                            continue
                        received = _build_tensor(send_recv_size, value=-1)
                        dist.recv(received, peer, tag=peer)
                        self.assertTrue(received.eq(peer).all())

            if profiler_ctx is None:
                return
            backend = dist.get_backend()
            if backend not in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
                return
            for event_name in (f"{backend}:send", f"{backend}:recv"):
                events = get_profiling_event(event_name, prof)
                # Each rank sends to / receives from every other rank once.
                observed = sum(evt.count for evt in events)
                self.assertEqual(observed, dist.get_world_size() - 1)
                for evt in events:
                    self.assertTrue(evt.is_async)
                    self.assertEqual(evt.name, event_name)
                    self.assertEqual(evt.input_shapes, [[send_recv_size] * 3])

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
        )
        def test_send_recv_with_tag(self):
            """Run the tagged send/recv exchange with no profiler attached."""
            no_profiler = None
            self._test_send_recv_with_tag(profiler_ctx=no_profiler)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
        )
        def test_send_recv_with_tag_autograd_profiler(self):
            """Run the tagged send/recv exchange under ``_create_autograd_profiler()``."""
            return self._test_send_recv_with_tag(
                profiler_ctx=_create_autograd_profiler()
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "NCCL send/recv tested by test_send_recv_nccl"
        )
        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
        @skip_but_pass_in_sandcastle_if(
            IS_MACOS or IS_WINDOWS,
            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
        )
        def test_send_recv_with_tag_torch_profiler(self):
            """Run the tagged send/recv exchange under ``_create_torch_profiler()``."""
            return self._test_send_recv_with_tag(
                profiler_ctx=_create_torch_profiler()
            )

        # ISEND
        def _test_isend(self, profiler_ctx):
            """Have rank 0 ``isend`` one tensor to every other rank and verify receipt.

            Rank 0 issues ``dist.isend`` for each destination, with the tensor built
            by ``_build_tensor(dest, 10)``, then waits on every request. Each other
            rank receives from rank 0 with ``dist.recv`` and compares against the
            same construction. All ranks hit the barrier inside the profiled region.

            Args:
                profiler_ctx: profiler context manager to run the exchange under, or
                    ``None`` to run unprofiled. When given, the send events (rank 0)
                    or recv event (other ranks) are validated for backends listed in
                    ``SEND_RECV_PROFILING_SUPPORTED_BACKENDS``.
            """
            my_rank = dist.get_rank()
            num_ranks = dist.get_world_size()
            is_sender = my_rank == 0
            active_ctx = nullcontext() if profiler_ctx is None else profiler_ctx
            with active_ctx as prof:
                if is_sender:
                    pending = []
                    for dest in range(1, num_ranks):
                        pending.append(dist.isend(_build_tensor(dest, 10), dest))
                    for req in pending:
                        req.wait()
                        self.assertTrue(req.is_completed())
                else:
                    received = _build_tensor(my_rank, -1)
                    dist.recv(received, 0)
                    self.assertEqual(received, _build_tensor(my_rank, 10))

                self._barrier()

            if profiler_ctx is None:
                return
            backend = dist.get_backend()
            if backend not in SEND_RECV_PROFILING_SUPPORTED_BACKENDS:
                return

            if is_sender:
                expected_event_name = f"{backend}:send"
            else:
                expected_event_name = f"{backend}:recv"
            events = get_profiling_event(expected_event_name, prof)
            event_count = sum(evt.count for evt in events)
            # Rank 0 sends to everyone else; every other rank receives once.
            expected_count = dist.get_world_size() - 1 if is_sender else 1
            self.assertEqual(expected_count, event_count)
            # Event ordering is not guaranteed, so the sender only checks that each
            # shape belongs to some destination rank.
            shapes_by_rank = {
                r: [[r] * 3] for r in range(1, dist.get_world_size())
            }
            for evt in events:
                self.assertTrue(evt.is_async)
                self.assertEqual(evt.name, expected_event_name)
                if is_sender:
                    self.assertTrue(evt.input_shapes in shapes_by_rank.values())
                else:
                    self.assertEqual(evt.input_shapes, shapes_by_rank[my_rank])

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support isend"
        )
        def test_isend(self):
            """Run the isend scenario with no profiler attached."""
            no_profiler = None
            self._test_isend(profiler_ctx=no_profiler)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support isend"
        )
        def test_isend_autograd_profiler(self):
            """Run the isend scenario under ``_create_autograd_profiler()``."""
            self._test_isend(profiler_ctx=_create_autograd_profiler())

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support isend"
        )
        @skip_but_pass_in_sandcastle_if(IS_FBCODE, "Kineto in fbcode code causes hang")
        @skip_but_pass_in_sandcastle_if(
            IS_MACOS or IS_WINDOWS,
            "torch.profiler not enabled for mac/windows: https://github.com/pytorch/pytorch/pull/56124",
        )
        def test_isend_torch_profiler(self):
            """Run the isend scenario under ``_create_torch_profiler()``."""
            self._test_isend(profiler_ctx=_create_torch_profiler())

        # IRECV
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support irecv"
        )
        def test_irecv(self):
            """Have rank 0 ``irecv`` one tensor from every other rank and verify it.

            Rank 0 posts all ``dist.irecv`` requests up front, then waits on them in
            rank order and compares each buffer with ``_build_tensor(src, 10)``.
            Every other rank sends ``_build_tensor(rank, 10)`` to rank 0.
            """
            my_rank = dist.get_rank()
            num_ranks = dist.get_world_size()

            if my_rank != 0:
                dist.send(_build_tensor(my_rank, 10), 0)
            else:
                senders = range(1, num_ranks)
                # Buffers start at -1 so a missed receive cannot pass the check.
                buffers = [_build_tensor(src, -1) for src in senders]
                pending = [
                    dist.irecv(buf, src) for buf, src in zip(buffers, senders)
                ]

                for src, buf, req in zip(senders, buffers, pending):
                    req.wait()
                    self.assertTrue(req.is_completed())
                    self.assertEqual(buf, _build_tensor(src, 10))

            self._barrier()

        # BROADCAST
        def _test_broadcast_helper(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            with_options=False,
        ):
            """Broadcast from every rank in ``group`` for a range of dtypes and verify.

            For each dtype/value pair and each source rank, the source broadcasts a
            tensor built by ``_build_tensor(src + 1, value, dtype)``. Every other
            rank starts from a tensor filled with -1 and must end up with the
            source's contents.

            Args:
                group: iterable of ranks that each take a turn as broadcast source.
                group_id: process group handle. With ``with_options`` its
                    ``broadcast`` method is called directly; otherwise it is passed
                    to ``dist.broadcast``.
                rank: rank of the calling process.
                cuda: move tensors to ``rank_to_GPU[rank][0]`` before broadcasting.
                    Cases flagged as CUDA-only are skipped when this is False.
                rank_to_GPU: mapping from rank to a list of device ids; only read
                    when ``cuda`` is True.
                with_options: use ``dist.BroadcastOptions`` and
                    ``group_id.broadcast`` instead of ``dist.broadcast``.
            """
            # (dtype, fill value, only run when cuda=True)
            dtype_cases = [
                (torch.float, -1e-10, False),
                (torch.double, -1e-100, False),
                (torch.half, -0.1, True),
                (torch.int8, -2, False),
                (torch.uint8, 129, False),
                (torch.int, -1e5, False),
                (torch.long, -1e15, False),
            ]

            def run_broadcast(buf, root):
                # Both the source and the receivers issue the same call; only the
                # buffer they pass differs.
                if with_options:
                    opts = dist.BroadcastOptions()
                    opts.rootTensor = 0
                    opts.rootRank = root
                    self.call_dist_op(
                        ":broadcast", True, group_id.broadcast, [buf], opts
                    )
                else:
                    self.call_dist_op(
                        ":broadcast", False, dist.broadcast, buf, root, group_id
                    )

            for dtype, value, requires_cuda in dtype_cases:
                if requires_cuda and not cuda:
                    continue
                for src in group:
                    expected_tensor = _build_tensor(src + 1, value, dtype)
                    if cuda:
                        expected_tensor = expected_tensor.cuda(rank_to_GPU[rank][0])
                    if rank == src:
                        run_broadcast(expected_tensor, src)
                        continue

                    tensor = _build_tensor(src + 1, -1, dtype)
                    if cuda:
                        tensor = tensor.cuda(rank_to_GPU[rank][0])
                    run_broadcast(tensor, src)
                    self.assertEqual(tensor.size(), expected_tensor.size())
                    # No element may differ from the source's tensor.
                    self.assertEqual(
                        tensor.ne(expected_tensor).max(), torch.tensor(False)
                    )

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_broadcast(self):
            """Run the broadcast helper on the values from ``_init_global_test()``."""
            self._test_broadcast_helper(*self._init_global_test())

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo" and BACKEND != "nccl",
            "Only Gloo and Nccl backend supports CUDA broadcast",
        )
        @skip_if_no_gpu
        def test_broadcast_cuda(self):
            """Run the broadcast helper with tensors placed on this rank's GPU.

            The device is the first entry of ``rank_to_GPU[rank]`` as returned by
            ``init_multigpu_helper``; it is also made the current CUDA device.
            """
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device_id = rank_to_GPU[rank][0]
            torch.cuda.set_device(device_id)
            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU)

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_broadcast_group(self):
            """Run the broadcast helper on the values from ``_init_group_test()``."""
            self._test_broadcast_helper(*self._init_group_test())

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_broadcast_full_group(self):
            """Run the broadcast helper on the values from ``_init_full_group_test()``."""
            self._test_broadcast_helper(*self._init_full_group_test())

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl",
            "Only NCCL backend supports high priority stream",
        )
        @skip_if_no_gpu
        def test_nccl_high_priority_stream(self):
            """Broadcast through an NCCL process group that uses a high-priority stream.

            A separate ``ProcessGroupNCCL`` is built by hand from an ``env://``
            rendezvous on ``MASTER_PORT + 1`` with
            ``Options.is_high_priority_stream`` enabled, and the broadcast helper is
            run against it via the options-based ``group_id.broadcast`` path.
            """
            group, _, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device_id = rank_to_GPU[rank][0]
            torch.cuda.set_device(device_id)

            # Rendezvous on a different port than the one in MASTER_PORT.
            new_port = str(MASTER_PORT + 1)
            os.environ["MASTER_PORT"] = new_port
            gen_iterator = dist.rendezvous("env://", rank, dist.get_world_size())
            store, rank, size = next(gen_iterator)
            store = dist.PrefixStore(new_port, store)

            opts = dist.ProcessGroupNCCL.Options()
            # The option under test: it must be enabled, otherwise this test only
            # repeats the regular CUDA broadcast coverage.
            opts.is_high_priority_stream = True
            group_id = dist.ProcessGroupNCCL(store, rank, size, opts)

            self._test_broadcast_helper(group, group_id, rank, True, rank_to_GPU, True)

        # REDUCE
        def _test_reduce_helper(
            self,
            group,
            group_id,
            rank,
            op,
            master_value,
            worker_value,
            expected_value,
            cuda=False,
            rank_to_GPU=None,
        ):
            """Reduce to every rank in ``group`` in turn and check the destination's result.

            For each destination ``src``, the destination contributes a tensor
            filled with ``master_value`` and every other rank one filled with
            ``worker_value``. After ``dist.reduce`` the destination's tensor must
            equal ``_build_tensor(src + 1, expected_value)``.

            Args:
                group: iterable of ranks that each take a turn as the destination.
                group_id: process group passed to ``dist.reduce``.
                rank: rank of the calling process.
                op: reduce operation passed to ``dist.reduce``.
                master_value: fill value used by the destination rank.
                worker_value: fill value used by all other ranks.
                expected_value: fill value expected on the destination afterwards.
                cuda: move the tensor to ``rank_to_GPU[rank][0]`` before reducing.
                rank_to_GPU: mapping from rank to a list of device ids; only read
                    when ``cuda`` is True.
            """
            for src in group:
                fill_value = master_value if rank == src else worker_value
                tensor = _build_tensor(src + 1).fill_(fill_value)
                if cuda:
                    tensor = tensor.cuda(rank_to_GPU[rank][0])
                self.call_dist_op(
                    ":reduce",
                    False,
                    dist.reduce,
                    tensor,
                    src,
                    op,
                    group_id,
                    tensor_shapes=[tensor.shape],
                )
                # Only the destination's tensor is checked.
                if rank != src:
                    continue
                self.assertEqual(tensor, _build_tensor(src + 1, expected_value))

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_sum(self):
            """SUM-reduce using the values from ``_init_global_test()``.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_global_test()
            num_workers = len(group) - 1
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.SUM,
                master_value=2,
                worker_value=10,
                expected_value=2 + (10 * num_workers),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        @skip_if_no_gpu
        def test_reduce_sum_cuda(self):
            """SUM-reduce with tensors placed on this rank's GPU.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            torch.cuda.set_device(rank_to_GPU[rank][0])
            num_workers = len(group) - 1
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.SUM,
                master_value=2,
                worker_value=10,
                expected_value=2 + 10 * num_workers,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_product(self):
            """PRODUCT-reduce using the values from ``_init_global_test()``.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_global_test()
            # One factor of 10 per non-destination rank, seeded with the
            # destination's own 2.
            worker_factors = [10] * (len(group) - 1)
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.PRODUCT,
                master_value=2,
                worker_value=10,
                expected_value=reduce(operator.mul, worker_factors, 2),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_min(self):
            """MIN-reduce using the values from ``_init_global_test()``.

            The destination contributes 1010 and each other rank 1, so 1 is expected.
            """
            group, group_id, rank = self._init_global_test()
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.MIN,
                master_value=1010,
                worker_value=1,
                expected_value=1,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_max(self):
            """MAX-reduce using the values from ``_init_global_test()``.

            The destination contributes -1 and each other rank 10, so 10 is expected.
            """
            group, group_id, rank = self._init_global_test()
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.MAX,
                master_value=-1,
                worker_value=10,
                expected_value=10,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        @skip_if_small_worldsize
        def test_reduce_group_sum(self):
            """SUM-reduce using the values from ``_init_group_test()``.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_group_test()
            num_workers = len(group) - 1
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.SUM,
                master_value=2,
                worker_value=10,
                expected_value=2 + (10 * num_workers),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        @skip_if_small_worldsize
        def test_reduce_group_product(self):
            """PRODUCT-reduce using the values from ``_init_group_test()``.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_group_test()
            # One factor of 10 per non-destination rank, seeded with the
            # destination's own 2.
            worker_factors = [10] * (len(group) - 1)
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.PRODUCT,
                master_value=2,
                worker_value=10,
                expected_value=reduce(operator.mul, worker_factors, 2),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        @skip_if_small_worldsize
        def test_reduce_group_min(self):
            """MIN-reduce using the values from ``_init_group_test()``.

            The destination contributes 1010 and each other rank 1, so 1 is expected.
            """
            group, group_id, rank = self._init_group_test()
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.MIN,
                master_value=1010,
                worker_value=1,
                expected_value=1,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        @skip_if_small_worldsize
        def test_reduce_group_max(self):
            """MAX-reduce using the values from ``_init_group_test()``.

            The destination contributes -1 and each other rank 10, so 10 is expected.
            """
            group, group_id, rank = self._init_group_test()
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.MAX,
                master_value=-1,
                worker_value=10,
                expected_value=10,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_full_group_sum(self):
            """SUM-reduce using the values from ``_init_full_group_test()``.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_full_group_test()
            num_workers = len(group) - 1
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.SUM,
                master_value=2,
                worker_value=10,
                expected_value=2 + (10 * num_workers),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_full_group_product(self):
            """PRODUCT-reduce using the values from ``_init_full_group_test()``.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_full_group_test()
            # One factor of 10 per non-destination rank, seeded with the
            # destination's own 2.
            worker_factors = [10] * (len(group) - 1)
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.PRODUCT,
                master_value=2,
                worker_value=10,
                expected_value=reduce(operator.mul, worker_factors, 2),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_full_group_min(self):
            """MIN-reduce using the values from ``_init_full_group_test()``.

            The destination contributes 1010 and each other rank 1, so 1 is expected.
            """
            group, group_id, rank = self._init_full_group_test()
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.MIN,
                master_value=1010,
                worker_value=1,
                expected_value=1,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_full_group_max(self):
            """MAX-reduce using the values from ``_init_full_group_test()``.

            The destination contributes -1 and each other rank 10, so 10 is expected.
            """
            group, group_id, rank = self._init_full_group_test()
            self._test_reduce_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.MAX,
                master_value=-1,
                worker_value=10,
                expected_value=10,
            )

        # REDUCE TWICE
        def _test_reduce_twice_helper(
            self,
            group,
            group_id,
            rank,
            op,
            master_value,
            worker_value,
            expected_value,
            cuda=False,
            rank_to_GPU=None,
        ):
            """Issue two reduces per destination and check both results.

            For each destination ``src`` two identically filled tensors are built.
            The first is reduced through ``call_dist_op``; the second reduce is
            handed to it as ``secondary_op_call``. On the destination both tensors
            must equal ``_build_tensor(src + 1, expected_value)``.

            Args:
                group: iterable of ranks that each take a turn as the destination.
                group_id: process group passed to ``dist.reduce``.
                rank: rank of the calling process.
                op: reduce operation passed to ``dist.reduce``.
                master_value: fill value used by the destination rank.
                worker_value: fill value used by all other ranks.
                expected_value: fill value expected on the destination afterwards.
                cuda: move the tensors to ``rank_to_GPU[rank][0]`` before reducing.
                rank_to_GPU: mapping from rank to a list of device ids; only read
                    when ``cuda`` is True.
            """
            for src in group:
                fill_value = master_value if rank == src else worker_value
                first = _build_tensor(src + 1).fill_(fill_value)
                second = _build_tensor(src + 1).fill_(fill_value)
                if cuda:
                    first = first.cuda(rank_to_GPU[rank][0])
                    second = second.cuda(rank_to_GPU[rank][0])
                self.call_dist_op(
                    ":reduce",
                    False,
                    dist.reduce,
                    first,
                    src,
                    op,
                    group_id,
                    secondary_op_call=lambda: dist.reduce(second, src, op, group_id),
                    tensor_shapes=[first.shape],
                )
                # Only the destination's tensors are checked.
                if rank != src:
                    continue
                for reduced in (first, second):
                    self.assertEqual(reduced, _build_tensor(src + 1, expected_value))

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        def test_reduce_sum_twice(self):
            """Run the double SUM-reduce using the values from ``_init_global_test()``.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_global_test()
            num_workers = len(group) - 1
            self._test_reduce_twice_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.SUM,
                master_value=2,
                worker_value=10,
                expected_value=2 + (10 * num_workers),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA reduce"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        @skip_if_no_gpu
        def test_reduce_sum_cuda_twice(self):
            """Run the double SUM-reduce with tensors placed on this rank's GPU.

            The destination contributes 2 and each other rank 10.
            """
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            torch.cuda.set_device(rank_to_GPU[rank][0])
            num_workers = len(group) - 1
            self._test_reduce_twice_helper(
                group,
                group_id,
                rank,
                op=dist.ReduceOp.SUM,
                master_value=2,
                worker_value=10,
                expected_value=2 + 10 * num_workers,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports reduce_scatter_v"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["reduce"],
            f"{BACKEND} does not support reduce",
        )
        @skip_if_no_gpu
        def test_reduce_scatter_v_cuda(self):
            """Check ``dist.reduce_scatter`` with unevenly sized per-rank outputs on GPU.

            The input is split along dim 0 into chunks of size ``src + 1`` for each
            rank ``src`` in the group. Each rank fills its own chunk with 2 and the
            rest with 10, so after a SUM reduce-scatter its output chunk must hold
            ``2 + 10 * (len(group) - 1)``. Both the async and sync paths are run.
            """
            self._barrier()
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device_id = rank_to_GPU[rank][0]

            input_split_sizes = [src + 1 for src in group]
            own_split = input_split_sizes[rank]
            start_len = sum(input_split_sizes[:rank])
            end_len = start_len + own_split
            sum_len = sum(input_split_sizes)
            master_value = 2
            worker_value = 10
            out_shape = (own_split, sum_len, sum_len)

            for async_val in (True, False):
                tensor = _build_tensor(sum_len, worker_value, device_id=device_id)
                # Mark the slice that will be scattered back to this rank.
                tensor[start_len:end_len].fill_(master_value)
                out_tensor = torch.empty(*out_shape, dtype=torch.float)
                out_tensor = out_tensor.fill_(-1).cuda(device_id)

                chunks = list(torch.split(tensor, input_split_sizes))
                req = dist.reduce_scatter(
                    out_tensor,
                    chunks,
                    dist.ReduceOp.SUM,
                    group_id,
                    async_val,
                )
                if async_val:
                    req.wait()

                expected_value = 2 + (10 * (len(group) - 1))
                expected_tensor = torch.empty(*out_shape, dtype=torch.float)
                expected_tensor = expected_tensor.fill_(expected_value).cuda(device_id)

                self.assertEqual(out_tensor, expected_tensor)
            self._barrier()

        # Test reduce_scatter_tensor accepting single tensor as input
        def _reduce_scatter_tensor_helper(
            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
        ):
            """SUM-reduce-scatter ``tensor_in`` into ``tensor_out`` and return the output.

            When ``cuda`` is true both tensors are first moved to
            ``rank_to_GPU[rank][0]``, so the returned tensor is the moved copy.
            The op is issued synchronously through ``call_dist_op`` with
            profiler-event checking disabled.
            """
            if cuda:
                device = rank_to_GPU[rank][0]
                tensor_in = tensor_in.cuda(device)
                tensor_out = tensor_out.cuda(device)
            self.call_dist_op(
                ":reduce_scatter_tensor",
                False,
                dist.reduce_scatter_tensor,
                tensor_out,
                tensor_in,
                dist.ReduceOp.SUM,
                group_id,
                False,
                expect_event=False,
                tensor_shapes=[tensor_out.shape],
            )
            return tensor_out

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA reduce_scatter_tensor"
        )
        @skip_if_no_gpu
        def test_reduce_scatter_tensor_cuda(self):
            """Check ``reduce_scatter_tensor`` for concatenated and stacked inputs.

            Both input layouts hold the same values, so they must produce the
            same per-rank output.
            """
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            size = 2
            num_ranks = len(group)
            tensor_out = torch.zeros(size, dtype=torch.int64)

            concatenated = torch.arange(num_ranks * size)
            stacked = torch.reshape(concatenated, (num_ranks, size))
            # Every rank feeds the same arange, so this rank's reduced slice is
            # its slice of the arange scaled by the number of ranks.
            expected_tensor = torch.arange(rank * size, (rank + 1) * size) * num_ranks

            for tensor_in in (concatenated, stacked):
                tensor_out = self._reduce_scatter_tensor_helper(
                    tensor_out,
                    tensor_in,
                    group_id,
                    rank,
                    cuda=True,
                    rank_to_GPU=rank_to_GPU,
                )
                self.assertEqual(tensor_out, expected_tensor)
                self._barrier()

        def call_dist_op(
            self,
            profiling_title_postfix,
            is_async,
            op,
            *args,
            expect_event=True,
            secondary_op_call=None,
            profile_cuda=False,
            tensor_shapes=None,
            **kwargs,
        ):
            """Run ``op(*args, **kwargs)`` under the autograd profiler and check its events.

            Args:
                profiling_title_postfix: suffix appended to the backend name to
                    form the profiler event name that is looked up.
                is_async: when true, ``wait()`` is called on every returned work.
                op: callable to invoke with ``*args`` / ``**kwargs``.
                expect_event: when false, no profiler validation is performed.
                secondary_op_call: optional zero-argument callable run after ``op``
                    inside the same profiling context.
                profile_cuda: forwarded as ``use_cuda`` to the profiler.
                tensor_shapes: if given, every event's ``input_shapes`` must
                    equal it (skipped at DETAIL debug level).
            """

            def primary_op_call():
                return op(*args, **kwargs)

            op_calls = [primary_op_call]
            if secondary_op_call is not None:
                op_calls.append(secondary_op_call)

            # TODO: move this test to use torch.profiler once kineto issues are
            # fixed internally.
            autograd_profiler_ctx = torch.autograd.profiler.profile(
                use_cuda=profile_cuda, record_shapes=True
            )
            with autograd_profiler_ctx:
                works = []
                for op_call in op_calls:
                    works.append(op_call())
                if is_async:
                    for work in works:
                        work.wait()

            if not expect_event:
                return
            backend = dist.get_backend()
            if backend not in PROFILING_SUPPORTED_BACKENDS:
                return

            # We are only interested in the backend's implementation not the dispatcher wrapper.
            events = get_profiling_event(
                backend + profiling_title_postfix, autograd_profiler_ctx
            )
            # DETAIL debug mode can use a pg wrapper that issues more collectives
            # under the hood, so count and shape checks only apply outside of it.
            strict = dist.get_debug_level() != dist.DebugLevel.DETAIL
            if strict:
                self.assertEqual(len(events), len(op_calls))
            check_shapes = strict and tensor_shapes is not None
            for event in events:
                self.assertTrue(event.is_async)
                self.assertEqual(event.count, 1)
                self.assertGreaterEqual(event.cpu_time, 0)
                if check_shapes:
                    self.assertEqual(
                        event.input_shapes,
                        tensor_shapes,
                        f"event shape: {event.input_shapes} vs tensor {tensor_shapes}",
                    )

        # ALL REDUCE
        def _test_all_reduce_helper(
            self,
            group,
            group_id,
            rank,
            op,
            master_value,
            worker_value,
            expected_value,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            async_op=False,
        ):
            """Run ``dist.all_reduce`` once per rank in ``group`` and verify the result.

            For each ``src`` in ``group`` the rank equal to ``src`` contributes
            ``master_value`` and every other rank contributes ``worker_value``.
            After the collective the tensor must equal ``expected_value``.

            Args:
                group: iterable of ranks taking part in the collective.
                group_id: process group handle passed to ``dist.all_reduce``.
                rank: rank of the calling process.
                op: reduction operator passed to ``dist.all_reduce``.
                master_value: fill value used when ``rank == src``.
                worker_value: fill value used otherwise.
                expected_value: value every element must hold after the reduction.
                cuda: move the tensor to ``rank_to_GPU[rank][0]`` first.
                rank_to_GPU: rank -> GPU list mapping; only read when ``cuda``.
                dtype: dtype of the tensor being reduced.
                async_op: issue the collective asynchronously and wait on it.
            """
            for src in group:
                curr_value = master_value if rank == src else worker_value

                tensor = _build_tensor(src + 1, dtype=dtype).fill_(curr_value)
                if cuda:
                    tensor = tensor.cuda(rank_to_GPU[rank][0])
                # Complex tensors are profiled through their real view.
                if tensor.dtype == torch.complex64:
                    tensor_shapes = [torch.view_as_real(tensor).shape]
                else:
                    tensor_shapes = [tensor.shape]
                self.call_dist_op(
                    ":all_reduce",
                    async_op,
                    dist.all_reduce,
                    tensor,
                    op,
                    group_id,
                    async_op=async_op,
                    tensor_shapes=tensor_shapes,
                )
                # Check the reduced value now: the optional profiling run below
                # reduces the same tensor in place a second time.
                expected_tensor = _build_tensor(src + 1, dtype=dtype).fill_(
                    expected_value
                )
                self.assertEqual(tensor, expected_tensor)
                # Currently, only Gloo backend has profiling tested with CUDA enabled.
                # Only run cuda profiling test for one rank to speed up since
                # running with different src_rank does not affect the correctness.
                if (
                    src == 0
                    and cuda
                    and dist.get_backend() in CUDA_PROFILING_SUPPORTED_BACKENDS
                ):
                    self.call_dist_op(
                        ":all_reduce",
                        async_op,
                        dist.all_reduce,
                        tensor,
                        op,
                        group_id,
                        async_op=async_op,
                        profile_cuda=True,
                        tensor_shapes=tensor_shapes,
                    )

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_sum(self):
            """SUM all_reduce with the helper's default (non-CUDA, float) settings."""
            group, group_id, rank = self._init_global_test()
            master_value, worker_value = 2, 10
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + worker_value * (len(group) - 1),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_sum_async(self):
            """SUM all_reduce issued with ``async_op=True`` (non-CUDA, float)."""
            group, group_id, rank = self._init_global_test()
            master_value, worker_value = 2, 10
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + worker_value * (len(group) - 1),
                async_op=True,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo" and BACKEND != "nccl",
            "Only Gloo and NCCL backends will have CUDA allReduce tested",
        )
        @skip_if_no_gpu
        def test_all_reduce_sum_cuda(self):
            """SUM all_reduce on CUDA tensors, issued synchronously."""
            torch.cuda.set_device(self.rank)
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            master_value, worker_value = 2, 10
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + worker_value * (len(group) - 1),
                cuda=True,
                rank_to_GPU=rank_to_GPU,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo" and BACKEND != "nccl",
            "Only Gloo and NCCL backends will have CUDA allReduce tested",
        )
        @skip_if_no_gpu
        def test_all_reduce_sum_cuda_async(self):
            """SUM all_reduce on CUDA tensors, issued with ``async_op=True``."""
            torch.cuda.set_device(self.rank)
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            master_value, worker_value = 2, 10
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + worker_value * (len(group) - 1),
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                async_op=True,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_sum_complex(self):
            """SUM all_reduce of ``torch.cfloat`` tensors (non-CUDA)."""
            group, group_id, rank = self._init_global_test()
            master_value = complex(2, 3)
            worker_value = complex(10, 11)
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + (worker_value * (len(group) - 1)),
                dtype=torch.cfloat,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_complex_unsupported_ops(self):
            """Each listed reduce op must raise ``ValueError`` for a complex tensor."""
            _group, group_id, _rank = self._init_global_test()
            rejected_ops = (
                dist.ReduceOp.MAX,
                dist.ReduceOp.MIN,
                dist.ReduceOp.PRODUCT,
                dist.ReduceOp.BAND,
                dist.ReduceOp.BOR,
                dist.ReduceOp.BXOR,
            )
            for reduce_op in rejected_ops:
                with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
                    complex_tensor = _build_tensor(1, dtype=torch.cfloat)
                    dist.all_reduce(complex_tensor, reduce_op, group_id)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo" and BACKEND != "nccl",
            "Only Gloo and NCCL backends will have CUDA allReduce tested",
        )
        @skip_if_no_gpu
        def test_all_reduce_sum_cuda_complex(self):
            """SUM all_reduce of ``torch.cfloat`` tensors on CUDA."""
            torch.cuda.set_device(self.rank)
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            master_value = complex(2, 3)
            worker_value = complex(10, 11)
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + (worker_value * (len(group) - 1)),
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.cfloat,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_product(self):
            """PRODUCT all_reduce with the helper's default (non-CUDA, float) settings."""
            group, group_id, rank = self._init_global_test()
            # Fold one factor of 10 per non-master rank onto the master's 2.
            worker_factors = [10] * (len(group) - 1)
            expected_product = reduce(operator.mul, worker_factors, 2)
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.PRODUCT,
                master_value=2,
                worker_value=10,
                expected_value=expected_product,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_min(self):
            """MIN all_reduce: the worker value 1 is below the master value 1010."""
            group, group_id, rank = self._init_global_test()
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.MIN,
                master_value=1010,
                worker_value=1,
                expected_value=1,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_max(self):
            """MAX all_reduce: the worker value 10 is above the master value -1."""
            group, group_id, rank = self._init_global_test()
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.MAX,
                master_value=-1,
                worker_value=10,
                expected_value=10,
            )

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_group_sum(self):
            """SUM all_reduce over the group returned by ``_init_group_test``."""
            group, group_id, rank = self._init_group_test()
            master_value, worker_value = 2, 10
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + worker_value * (len(group) - 1),
            )

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_group_product(self):
            """PRODUCT all_reduce over the group returned by ``_init_group_test``."""
            group, group_id, rank = self._init_group_test()
            # Fold one factor of 10 per non-master rank onto the master's 2.
            worker_factors = [10] * (len(group) - 1)
            expected_product = reduce(operator.mul, worker_factors, 2)
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.PRODUCT,
                master_value=2,
                worker_value=10,
                expected_value=expected_product,
            )

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_group_min(self):
            """MIN all_reduce over the group returned by ``_init_group_test``."""
            group, group_id, rank = self._init_group_test()
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.MIN,
                master_value=1010,
                worker_value=1,
                expected_value=1,
            )

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_group_max(self):
            """MAX all_reduce over the group returned by ``_init_group_test``."""
            group, group_id, rank = self._init_group_test()
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.MAX,
                master_value=-1,
                worker_value=10,
                expected_value=10,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_full_group_sum(self):
            """SUM all_reduce over the group returned by ``_init_full_group_test``."""
            group, group_id, rank = self._init_full_group_test()
            master_value, worker_value = 2, 10
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.SUM,
                master_value,
                worker_value,
                expected_value=master_value + worker_value * (len(group) - 1),
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_full_group_product(self):
            """PRODUCT all_reduce over the group returned by ``_init_full_group_test``."""
            group, group_id, rank = self._init_full_group_test()
            # Fold one factor of 10 per non-master rank onto the master's 2.
            worker_factors = [10] * (len(group) - 1)
            expected_product = reduce(operator.mul, worker_factors, 2)
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.PRODUCT,
                master_value=2,
                worker_value=10,
                expected_value=expected_product,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_full_group_min(self):
            """MIN all_reduce over the group returned by ``_init_full_group_test``."""
            group, group_id, rank = self._init_full_group_test()
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.MIN,
                master_value=1010,
                worker_value=1,
                expected_value=1,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_full_group_max(self):
            """MAX all_reduce over the group returned by ``_init_full_group_test``."""
            group, group_id, rank = self._init_full_group_test()
            self._test_all_reduce_helper(
                group,
                group_id,
                rank,
                dist.ReduceOp.MAX,
                master_value=-1,
                worker_value=10,
                expected_value=10,
            )

        # SPARSE ALL REDUCE
        def _test_sparse_all_reduce_sum(self, fn):
            """SUM all_reduce the cases produced by ``simple_sparse_reduce_tests``.

            ``fn`` is applied to every input tensor before the collective, for
            example to move it to another device. Only the first tensor of each
            case is reduced and compared with the first expected output.
            """
            _group, group_id, rank = self._init_global_test()

            world_size = dist.get_world_size()
            cases = simple_sparse_reduce_tests(rank, world_size, num_inputs=1)
            for case_inputs, case_outputs in cases:
                prepared = [fn(case_input) for case_input in case_inputs]
                reduced = prepared[0]
                dist.all_reduce(reduced, dist.ReduceOp.SUM, group_id)
                self.assertEqual(reduced, case_outputs[0])

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
        )
        def test_sparse_all_reduce_sum(self):
            """Sparse SUM all_reduce on the input tensors exactly as generated."""

            def keep_as_is(t):
                return t

            self._test_sparse_all_reduce_sum(keep_as_is)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only Gloo backend support sparse all reduce"
        )
        @skip_if_no_gpu
        def test_sparse_all_reduce_sum_cuda(self):
            """Sparse SUM all_reduce on CUDA copies of the generated input tensors."""

            def clone_to_cuda(t):
                return t.clone().cuda()

            self._test_sparse_all_reduce_sum(clone_to_cuda)

        # ALL REDUCE - COALESCED
        @staticmethod
        def _all_reduce_coalesced_sum_test_cases(group_size):
            """Return ``(master_values, worker_values, expected_values, dtypes)`` for SUM.

            Each expected value is the master value plus the worker value once
            per remaining rank, i.e. ``master + worker * (group_size - 1)``.
            """
            num_workers = group_size - 1
            master_values = [2, 3, complex(2, 3)]
            worker_values = [10, 11, complex(10, 11)]
            expected_values = [
                master + worker * num_workers
                for master, worker in zip(master_values, worker_values)
            ]
            dtypes = [torch.float, torch.float, torch.cfloat]
            return (master_values, worker_values, expected_values, dtypes)

        @staticmethod
        def _all_reduce_coalesced_product_test_cases(group_size):
            """Return ``(master_values, worker_values, expected_values, dtypes)`` for PRODUCT.

            Each expected value is the master value times the worker value
            raised to the number of remaining ranks.
            """
            num_workers = group_size - 1
            master_values = [1, 2]
            worker_values = [3, 4]
            expected_values = [
                master * worker**num_workers
                for master, worker in zip(master_values, worker_values)
            ]
            dtypes = [torch.float, torch.float]
            return (master_values, worker_values, expected_values, dtypes)

        @staticmethod
        def _all_reduce_coalesced_min_test_cases(group_size):
            """Return ``(master_values, worker_values, expected_values, dtypes)`` for MIN.

            ``group_size`` is accepted for signature parity with the other
            test-case builders and does not affect the result.
            """
            master_values = [1, 4]
            worker_values = [2, 3]
            # Element-wise minimum of the master and worker values.
            expected_values = [1, 3]
            dtypes = [torch.float, torch.float]
            return (master_values, worker_values, expected_values, dtypes)

        @staticmethod
        def _all_reduce_coalesced_max_test_cases(group_size):
            """Return ``(master_values, worker_values, expected_values, dtypes)`` for MAX.

            ``group_size`` is accepted for signature parity with the other
            test-case builders and does not affect the result.
            """
            master_values = [1, 4]
            worker_values = [2, 3]
            # Element-wise maximum of the master and worker values.
            expected_values = [2, 4]
            dtypes = [torch.float, torch.float]
            return (master_values, worker_values, expected_values, dtypes)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_reduce_coalesced_max_complex_unsupported(self):
            """``all_reduce_coalesced`` with MAX must raise ``ValueError`` for complex input."""
            _group, group_id, _rank = self._init_global_test()
            with self.assertRaisesRegex(ValueError, "all_reduce does not support"):
                complex_tensors = [_build_tensor(1, dtype=torch.cfloat)]
                dist.all_reduce_coalesced(complex_tensors, dist.ReduceOp.MAX, group_id)

        def _test_all_reduce_coalesced_helper(
            self,
            group,
            group_id,
            rank,
            op,
            cuda=False,
            rank_to_GPU=None,
        ):
            """Run ``dist.all_reduce_coalesced`` once per rank in ``group`` and verify it.

            The per-op values come from the ``_all_reduce_coalesced_*_test_cases``
            builders. For each ``src`` the rank equal to ``src`` contributes the
            master values and every other rank the worker values. Only SUM,
            PRODUCT, MIN and MAX are supported; any other op raises ``KeyError``.
            """
            case_builders = {
                dist.ReduceOp.SUM: self._all_reduce_coalesced_sum_test_cases,
                dist.ReduceOp.PRODUCT: self._all_reduce_coalesced_product_test_cases,
                dist.ReduceOp.MIN: self._all_reduce_coalesced_min_test_cases,
                dist.ReduceOp.MAX: self._all_reduce_coalesced_max_test_cases,
            }
            build_cases = case_builders[op]
            master_values, worker_values, expected_values, dtypes = build_cases(
                len(group)
            )

            for src in group:
                size = src + 1
                curr_values = worker_values
                if rank == src:
                    curr_values = master_values

                tensors = []
                for dtype, val in zip(dtypes, curr_values):
                    tensors.append(_build_tensor(size, val, dtype=dtype))
                if cuda:
                    device = rank_to_GPU[rank][0]
                    tensors = [t.cuda(device) for t in tensors]

                # Complex tensors are profiled through their real view.
                tensor_shapes = [
                    torch.view_as_real(t).shape
                    if t.dtype == torch.complex64
                    else t.shape
                    for t in tensors
                ]
                self.call_dist_op(
                    ":all_reduce",
                    False,
                    dist.all_reduce_coalesced,
                    tensors,
                    op,
                    group_id,
                    tensor_shapes=tensor_shapes,
                )

                expected_tensors = []
                for dtype, expected_value in zip(dtypes, expected_values):
                    expected_tensors.append(
                        _build_tensor(size, expected_value, dtype=dtype)
                    )
                self.assertEqual(tensors, expected_tensors)

            self._barrier()

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_sum(self):
            """Coalesced SUM all_reduce over the ``_init_global_test`` group, without CUDA."""
            group, group_id, rank = self._init_global_test()
            reduce_op = dist.ReduceOp.SUM
            self._test_all_reduce_coalesced_helper(
                group, group_id, rank, reduce_op, cuda=False, rank_to_GPU=None
            )

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_product(self):
            """Coalesced PRODUCT all_reduce over the ``_init_global_test`` group, without CUDA."""
            group, group_id, rank = self._init_global_test()
            reduce_op = dist.ReduceOp.PRODUCT
            self._test_all_reduce_coalesced_helper(
                group, group_id, rank, reduce_op, cuda=False, rank_to_GPU=None
            )

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_min(self):
            """Coalesced MIN all_reduce over the ``_init_global_test`` group, without CUDA."""
            group, group_id, rank = self._init_global_test()
            reduce_op = dist.ReduceOp.MIN
            self._test_all_reduce_coalesced_helper(
                group, group_id, rank, reduce_op, cuda=False, rank_to_GPU=None
            )

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_max(self):
            """Coalesced MAX all_reduce over the ``_init_global_test`` group, without CUDA."""
            group, group_id, rank = self._init_global_test()
            reduce_op = dist.ReduceOp.MAX
            self._test_all_reduce_coalesced_helper(
                group,
                group_id,
                rank,
                reduce_op,
                cuda=False,
                rank_to_GPU=None,
            )

        @skip_if_small_worldsize
        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_group_sum(self):
            """Coalesced SUM all_reduce over the ``_init_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_group_test()
            reduce_op = dist.ReduceOp.SUM
            self._test_all_reduce_coalesced_helper(
                group,
                group_id,
                rank,
                reduce_op,
                cuda=False,
                rank_to_GPU=None,
            )

        @skip_if_small_worldsize
        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_group_product(self):
            """Coalesced PRODUCT all_reduce over the ``_init_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_group_test()
            reduce_op = dist.ReduceOp.PRODUCT
            self._test_all_reduce_coalesced_helper(
                group, group_id, rank, reduce_op, cuda=False, rank_to_GPU=None
            )

        @skip_if_small_worldsize
        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_group_min(self):
            """Coalesced MIN all_reduce over the ``_init_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_group_test()
            reduce_op = dist.ReduceOp.MIN
            self._test_all_reduce_coalesced_helper(
                group,
                group_id,
                rank,
                reduce_op,
                cuda=False,
                rank_to_GPU=None,
            )

        @skip_if_small_worldsize
        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_group_max(self):
            """Coalesced MAX all_reduce over the ``_init_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_group_test()
            reduce_op = dist.ReduceOp.MAX
            self._test_all_reduce_coalesced_helper(
                group,
                group_id,
                rank,
                reduce_op,
                cuda=False,
                rank_to_GPU=None,
            )

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_full_group_sum(self):
            """Coalesced SUM all_reduce over the ``_init_full_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_full_group_test()
            reduce_op = dist.ReduceOp.SUM
            self._test_all_reduce_coalesced_helper(
                group,
                group_id,
                rank,
                reduce_op,
                cuda=False,
                rank_to_GPU=None,
            )

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_full_group_product(self):
            """Coalesced PRODUCT all_reduce over the ``_init_full_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_full_group_test()
            reduce_op = dist.ReduceOp.PRODUCT
            self._test_all_reduce_coalesced_helper(
                group, group_id, rank, reduce_op, cuda=False, rank_to_GPU=None
            )

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_full_group_min(self):
            """Coalesced MIN all_reduce over the ``_init_full_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_full_group_test()
            reduce_op = dist.ReduceOp.MIN
            self._test_all_reduce_coalesced_helper(
                group, group_id, rank, reduce_op, cuda=False, rank_to_GPU=None
            )

        @require_backend_is_available({"gloo"})
        def test_all_reduce_coalesced_full_group_max(self):
            """Coalesced MAX all_reduce over the ``_init_full_group_test`` group, without CUDA."""
            group, group_id, rank = self._init_full_group_test()
            reduce_op = dist.ReduceOp.MAX
            self._test_all_reduce_coalesced_helper(
                group,
                group_id,
                rank,
                reduce_op,
                cuda=False,
                rank_to_GPU=None,
            )

        # SCATTER
        def _test_scatter_helper(
            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
        ):
            """Scatter from every rank in ``group`` in turn and check the received tensor.

            For each source rank ``dest`` the scatter list holds one tensor per
            rank ``i`` in ``group``, built from the value ``i``. After the
            collective this rank's tensor must equal the one built from ``rank``.
            Ranks other than the source pass an empty scatter list.
            """
            is_complex = dtype == torch.complex64
            for dest in group:
                size = dest + 1
                tensor = _build_tensor(size, -1, dtype=dtype)
                expected_tensor = _build_tensor(size, rank, dtype=dtype)
                tensors = []
                if rank == dest:
                    tensors = [_build_tensor(size, i, dtype=dtype) for i in group]
                if cuda:
                    device = rank_to_GPU[rank][0]
                    tensor = tensor.cuda(device)
                    tensors = [t.cuda(device) for t in tensors]
                # Complex tensors are described through their real view.
                tensor_shapes = [
                    torch.view_as_real(t).shape if is_complex else t.shape
                    for t in tensors
                ]
                self.call_dist_op(
                    ":scatter",
                    False,
                    dist.scatter,
                    tensor,
                    src=dest,
                    scatter_list=tensors,
                    group=group_id,
                    expect_event=False,
                    tensor_shapes=tensor_shapes,
                )
                self.assertEqual(tensor, expected_tensor)

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        def test_scatter_checks(self):
            """Check ``dist.scatter`` argument handling.

            Two variants are run: ``src=0`` passed explicitly, then ``src``
            omitted. In both, ``scatter_list`` is supplied on rank 0 only and
            every rank must end up with ``rank`` as its value.
            """
            group, _group_id, rank = self._init_global_test()
            one = torch.ones([1])

            # First with an explicit source rank, then relying on the default.
            for src_kwargs in ({"src": 0}, {}):
                output = one.clone() * -1
                if rank == 0:
                    # Specify scatter_list argument only on source rank.
                    scatter_list = [one.clone() * i for i in group]
                    dist.scatter(output, scatter_list=scatter_list, **src_kwargs)
                else:
                    dist.scatter(output, **src_kwargs)
                self.assertEqual(output, one * rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        def test_scatter(self):
            """Scatter with the helper's defaults over the ``_init_global_test`` group."""
            test_group, pg, my_rank = self._init_global_test()
            self._test_scatter_helper(test_group, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA gather"
        )
        @skip_if_no_gpu
        def test_scatter_cuda(self):
            """Scatter CUDA tensors over the ``_init_global_test`` group."""
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_scatter_helper(
                group, group_id, rank, cuda=True, rank_to_GPU=rank_to_GPU
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        def test_scatter_complex(self):
            """Scatter ``torch.cfloat`` tensors over the ``_init_global_test`` group."""
            test_group, pg, my_rank = self._init_global_test()
            self._test_scatter_helper(test_group, pg, my_rank, dtype=torch.cfloat)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA gather"
        )
        @skip_if_no_gpu
        def test_scatter_cuda_complex(self):
            """Scatter ``torch.cfloat`` CUDA tensors over the ``_init_global_test`` group."""
            group, group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_scatter_helper(
                group,
                group_id,
                rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.cfloat,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        @skip_if_small_worldsize
        def test_scatter_group(self):
            """Scatter with the helper's defaults over the ``_init_group_test`` group."""
            test_group, pg, my_rank = self._init_group_test()
            self._test_scatter_helper(test_group, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        def test_scatter_full_group(self):
            """Scatter with the helper's defaults over the ``_init_full_group_test`` group."""
            test_group, pg, my_rank = self._init_full_group_test()
            self._test_scatter_helper(test_group, pg, my_rank)

        # GATHER
        def _test_gather_helper(
            self, group, group_id, rank, cuda=False, rank_to_GPU=None
        ):
            """Gather to every rank in ``group`` in turn and check the collected tensors.

            Each rank sends a tensor built from its own ``rank``. The destination
            rank supplies one placeholder tensor per group member and afterwards
            expects slot ``i`` to match the tensor built from ``i``. Other ranks
            pass an empty gather list.
            """
            for dest in group:
                size = dest + 1
                is_destination = rank == dest
                tensor = _build_tensor(size, rank)
                tensors = []
                if is_destination:
                    tensors = [_build_tensor(size, -1) for _ in group]
                if cuda:
                    device = rank_to_GPU[rank][0]
                    tensor = tensor.cuda(device)
                    tensors = [t.cuda(device) for t in tensors]
                tensor_shapes = [tensors[0].shape] if tensors else None
                self.call_dist_op(
                    ":gather",
                    False,
                    dist.gather,
                    tensor,
                    dst=dest,
                    gather_list=tensors,
                    group=group_id,
                    expect_event=False,
                    tensor_shapes=tensor_shapes,
                )
                if is_destination:
                    expected_tensors = [_build_tensor(size, i) for i in group]
                    for gathered, expected in zip(tensors, expected_tensors):
                        self.assertEqual(gathered, expected)

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        def test_gather_checks(self):
            # Two rounds of gather towards rank 0: first with ``dst=0`` spelled
            # out, then with ``dst`` omitted. ``gather_list`` is supplied only
            # on rank 0 in both rounds.
            group, _group_id, rank = self._init_global_test()
            unit = torch.ones([1])

            for dst_kwargs in ({"dst": 0}, {}):
                if rank != 0:
                    dist.gather(unit * rank, **dst_kwargs)
                    continue
                received = [unit.clone() for _ in group]
                dist.gather(unit * rank, gather_list=received, **dst_kwargs)
                for src in group:
                    self.assertEqual(received[src], unit * src)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        def test_gather(self):
            # Gather checks on the values produced by ``_init_global_test``,
            # without moving tensors to a GPU.
            members, pg, my_rank = self._init_global_test()
            self._test_gather_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA gather"
        )
        @skip_if_no_gpu
        def test_gather_cuda(self):
            # Gather checks with tensors moved to the device selected from the
            # ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_gather_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        @skip_if_small_worldsize
        def test_gather_group(self):
            # Gather checks on the values produced by ``_init_group_test``.
            members, pg, my_rank = self._init_group_test()
            self._test_gather_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc", "CPU tensor ops not supported by UCP TL"
        )
        def test_gather_full_group(self):
            # Gather checks on the values produced by ``_init_full_group_test``.
            members, pg, my_rank = self._init_full_group_test()
            self._test_gather_helper(members, pg, my_rank)

        # ALL GATHER
        def _test_all_gather_helper(
            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
        ):
            """
            Run ``dist.all_gather`` once per entry of ``group`` (the entry only
            sets the tensor size) and check every gathered slot.

            Args:
                group: iterable of ranks taking part.
                group_id: passed positionally to ``dist.all_gather`` as the group.
                rank: rank of the calling process.
                cuda: when true, tensors are moved to ``rank_to_GPU[rank][0]``.
                rank_to_GPU: mapping indexed as ``rank_to_GPU[rank][0]``; only
                    read when ``cuda`` is set.
                dtype: dtype handed to ``_build_tensor``.
            """
            for dest in group:
                side = dest + 1
                contribution = _build_tensor(side, rank, dtype=dtype)
                gathered = [_build_tensor(side, -1, dtype=dtype) for _ in group]
                if cuda:
                    device = rank_to_GPU[rank][0]
                    contribution = contribution.cuda(device)
                    gathered = [buf.cuda(device) for buf in gathered]
                first = gathered[0]
                # For complex64 the shape handed to ``call_dist_op`` is that of
                # the real view of the buffer.
                if first.dtype == torch.complex64:
                    shape = torch.view_as_real(first).shape
                else:
                    shape = first.shape
                self.call_dist_op(
                    ":all_gather",
                    False,
                    dist.all_gather,
                    gathered,
                    contribution,
                    group_id,
                    False,
                    tensor_shapes=[shape],
                )

                wanted = [_build_tensor(side, src, dtype=dtype) for src in group]
                for received, reference in zip(gathered, wanted):
                    self.assertEqual(received, reference)

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_gather(self):
            # all_gather checks on the values produced by ``_init_global_test``,
            # with the helper's default dtype and no GPU placement.
            members, pg, my_rank = self._init_global_test()
            self._test_all_gather_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
        )
        @skip_if_no_gpu
        def test_all_gather_cuda(self):
            # all_gather checks with tensors moved to the device selected from
            # the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_gather_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_gather_complex(self):
            # all_gather checks with ``torch.cfloat`` tensors.
            members, pg, my_rank = self._init_global_test()
            self._test_all_gather_helper(members, pg, my_rank, dtype=torch.cfloat)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all gather"
        )
        @skip_if_no_gpu
        def test_all_gather_cuda_complex(self):
            # all_gather checks with ``torch.cfloat`` tensors moved to the
            # device selected from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_gather_helper(
                members,
                pg,
                my_rank,
                cuda=True,
                rank_to_GPU=gpu_map,
                dtype=torch.cfloat,
            )

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_gather_group(self):
            # all_gather checks on the values produced by ``_init_group_test``.
            members, pg, my_rank = self._init_group_test()
            self._test_all_gather_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "Nccl does not support CPU tensors"
        )
        def test_all_gather_full_group(self):
            # all_gather checks on the values produced by
            # ``_init_full_group_test``.
            members, pg, my_rank = self._init_full_group_test()
            self._test_all_gather_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports all_gather_v"
        )
        @skip_if_no_gpu
        def test_all_gather_v_cuda(self):
            # all_gather with unequal per-rank sizes: the k-th group entry
            # contributes ``k + 1`` slices along dim 0, and the output list is
            # a split of a single preallocated tensor along that same dim.
            self._barrier()
            group, group_id, rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            device = gpu_map[rank][0]

            split_sizes = [member + 1 for member in group]
            total = sum(split_sizes)
            fill = 2

            # Exercise the async path first, then the blocking one.
            for use_async in (True, False):
                local = torch.empty(split_sizes[rank], total, total, dtype=torch.float)
                local = local.fill_(fill).cuda(device)
                gathered = _build_tensor(total, -1, device_id=device)
                # Splitting is expected to yield views, so the collective's
                # writes are checked through ``gathered`` below.
                out_chunks = list(torch.split(gathered, split_sizes))

                work = dist.all_gather(out_chunks, local, group_id, use_async)
                if use_async:
                    work.wait()

                # Every rank sent the same constant, so the whole output is uniform.
                reference = _build_tensor(total, fill, device_id=device)
                self.assertEqual(gathered, reference)
            self._barrier()

        # Test all_gather accepting single tensor as output
        def _all_gather_into_tensor_helper(
            self, tensor_out, tensor_in, group_id, rank, cuda=True, rank_to_GPU=None
        ):
            """
            Run ``dist.all_gather_into_tensor`` through ``call_dist_op`` and
            return the (possibly device-moved) output tensor.

            Args:
                tensor_out: tensor receiving the gathered data.
                tensor_in: this rank's contribution.
                group_id: passed positionally to the collective as the group.
                rank: rank of the calling process.
                cuda: when true, both tensors are moved to
                    ``rank_to_GPU[rank][0]`` first.
                rank_to_GPU: mapping indexed as ``rank_to_GPU[rank][0]``; only
                    read when ``cuda`` is set.
            """
            if cuda:
                device = rank_to_GPU[rank][0]
                tensor_in = tensor_in.cuda(device)
                tensor_out = tensor_out.cuda(device)
            # The dtype test looks at the output tensor, while the reported
            # shape always comes from the input tensor.
            is_complex = tensor_out.dtype == torch.complex64
            reported = torch.view_as_real(tensor_in) if is_complex else tensor_in
            self.call_dist_op(
                ":all_gather_into_tensor",
                False,
                dist.all_gather_into_tensor,
                tensor_out,
                tensor_in,
                group_id,
                False,
                expect_event=False,
                tensor_shapes=[reported.shape],
            )
            return tensor_out

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
        )
        @skip_if_no_gpu
        def test_all_gather_into_cat_tensor_cuda(self):
            # all_gather_into_tensor with an output laid out as the
            # concatenation of every rank's block along dim 0.
            group, group_id, rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            side = 2
            block_shape = [side, side]
            local_block = torch.ones(block_shape) * rank
            # Output starts at -1 everywhere so stale entries would be visible.
            gathered = torch.ones([len(group) * side, side]) * (-1)
            gathered = self._all_gather_into_tensor_helper(
                gathered, local_block, group_id, rank, True, gpu_map
            )

            # The k-th block of the result must be filled with k's rank value.
            per_rank_blocks = [torch.ones(block_shape) * src for src in group]
            self.assertEqual(gathered, torch.cat(per_rank_blocks))
            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_gather_into_tensor"
        )
        @skip_if_no_gpu
        def test_all_gather_into_stack_tensor_cuda(self):
            # all_gather_into_tensor with an output laid out as every rank's
            # block stacked along a new leading dim.
            group, group_id, rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            side = 2
            block_shape = [side, side]
            local_block = torch.ones(block_shape) * rank
            # Output starts at -1 everywhere so stale entries would be visible.
            gathered = torch.ones([len(group), side, side]) * (-1)
            gathered = self._all_gather_into_tensor_helper(
                gathered, local_block, group_id, rank, True, gpu_map
            )

            # Entry k of the result must be the block filled with k's rank value.
            per_rank_blocks = [torch.ones(block_shape) * src for src in group]
            self.assertEqual(gathered, torch.stack(per_rank_blocks))
            self._barrier()

        def _run_all_gather_coalesced_and_verify(
            self, output_tensor_lists, input_tensors, expected_tensors, group_id
        ):
            """
            Run ``dist.all_gather_coalesced`` via ``call_dist_op`` and compare
            the outputs with ``expected_tensors``.

            Returns:
                ``True`` when every paired output/expected tensor satisfies
                ``torch.equal``, ``False`` at the first pair that does not.
            """
            # complex64 inputs are described by the shape of their real view.
            shapes = [
                torch.view_as_real(t).shape if t.dtype == torch.complex64 else t.shape
                for t in input_tensors
            ]
            self.call_dist_op(
                ":all_gather",
                False,
                dist.all_gather_coalesced,
                output_tensor_lists,
                input_tensors,
                group_id,
                tensor_shapes=shapes,
            )

            return all(
                torch.equal(got, want)
                for got_list, want_list in zip(output_tensor_lists, expected_tensors)
                for got, want in zip(got_list, want_list)
            )

        def _test_all_gather_coalesced_helper(
            self, group, group_id, rank, dtype=torch.float
        ):
            """
            Run all_gather_coalesced on batches of 1, 2 and 3 tensors and check
            the gathered values, then hit the barrier.

            The collective part is skipped entirely when ``group_id`` is
            ``None``; the barrier is still executed.
            """
            # TODO: Instead we should probably go through _rank_not_in_group
            # mechanism to disable sending tensors
            if group_id is not None:
                for upper in range(2, 5):
                    # Make sure we create tensors of incompatible sizes, e.g.
                    # [1], [2x2], [3x3x3] ... to be sent in one batch
                    tensor_ids = range(1, upper)
                    to_send = [
                        _build_multidim_tensor(tid, tid, rank + tid, dtype=dtype)
                        for tid in tensor_ids
                    ]
                    # One placeholder batch (all -1) per group entry.
                    to_fill = [
                        [
                            _build_multidim_tensor(tid, tid, -1, dtype=dtype)
                            for tid in tensor_ids
                        ]
                        for _ in group
                    ]
                    # What each group entry is expected to have contributed.
                    wanted = [
                        [
                            _build_multidim_tensor(tid, tid, peer + tid, dtype=dtype)
                            for tid in tensor_ids
                        ]
                        for peer in group
                    ]
                    matched = self._run_all_gather_coalesced_and_verify(
                        to_fill, to_send, wanted, group_id
                    )
                    assert matched, "output tensors do not match expected outputs"

            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
            f"{BACKEND} does not support all_gather_coalesced",
        )
        def test_all_gather_coalesced_simple(self):
            # Coalesced all_gather checks on the values produced by
            # ``_init_global_test``, with the helper's default dtype.
            members, pg, my_rank = self._init_global_test()
            self._test_all_gather_coalesced_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
            f"{BACKEND} does not support all_gather_coalesced",
        )
        def test_all_gather_coalesced_complex(self):
            # Coalesced all_gather checks with ``torch.cfloat`` tensors.
            members, pg, my_rank = self._init_global_test()
            complex_dtype = torch.cfloat
            self._test_all_gather_coalesced_helper(
                members, pg, my_rank, dtype=complex_dtype
            )

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
            f"{BACKEND} does not support all_gather_coalesced",
        )
        def test_all_gather_coalesced_group(self):
            # Coalesced all_gather checks on the values produced by
            # ``_init_group_test``.
            members, pg, my_rank = self._init_group_test()
            self._test_all_gather_coalesced_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
            f"{BACKEND} does not support all_gather_coalesced",
        )
        def test_all_gather_coalesced_full_group(self):
            # Coalesced all_gather checks on the values produced by
            # ``_init_full_group_test``.
            members, pg, my_rank = self._init_full_group_test()
            self._test_all_gather_coalesced_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["allgather_coalesced"],
            f"{BACKEND} does not support all_gather_coalesced",
        )
        def test_all_gather_coalesced_with_empty(self):
            # Coalesced all_gather where the batch mixes a 2x2 tensor, a 3x3
            # tensor and three zero-element tensors.
            group, group_id, rank = self._init_global_test()

            def _batch(small_fill, large_fill):
                # Layout: [2x2, empty, 3x3, empty, empty]. The empty tensors
                # carry no values, so no fill applies to them.
                return [
                    small_fill * torch.ones([2, 2]),
                    torch.ones([0]),
                    large_fill * torch.ones([3, 3]),
                    torch.ones([0]),
                    torch.ones([0]),
                ]

            to_send = _batch(rank, rank + 1)
            # One placeholder batch (filled with -1) per group entry.
            to_fill = [_batch(-1, -1) for _ in group]
            wanted = [_batch(peer, peer + 1) for peer in group]
            assert self._run_all_gather_coalesced_and_verify(
                to_fill, to_send, wanted, group_id
            )
            self._barrier()

        # AllToAll
        def _test_all_to_all_single_equal_split_helper(
            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
        ):
            """
            Run ``dist.all_to_all_single`` without explicit split sizes and
            check the received rows, then hit the barrier.

            Each rank sends a ``len(group) x len(group)`` tensor filled with
            its own rank and expects row k of the result to hold the k-th
            group entry. Nothing but the barrier runs when ``group_id`` is
            ``None``.
            """
            if group_id is not None:
                n = len(group)
                outgoing = torch.ones([n, n], dtype=dtype) * rank
                incoming = torch.ones([n, n], dtype=dtype) * -1
                wanted = torch.cat(
                    [torch.ones([1, n], dtype=dtype) * src for src in group]
                )
                if cuda:
                    device = rank_to_GPU[rank][0]
                    outgoing = outgoing.cuda(device)
                    wanted = wanted.cuda(device)
                    incoming = incoming.cuda(device)
                # complex64 inputs are described by the shape of their real view.
                described = (
                    torch.view_as_real(outgoing) if dtype == torch.complex64 else outgoing
                )
                self.call_dist_op(
                    ":all_to_all",
                    False,
                    dist.all_to_all_single,
                    incoming,
                    outgoing,
                    group=group_id,
                    tensor_shapes=[described.shape],
                )
                self.assertEqual(incoming, wanted)
            self._barrier()

        def _test_all_to_all_single_unequal_split_helper(
            self, group, group_id, rank, cuda=False, rank_to_GPU=None, dtype=torch.float
        ):
            """
            Run ``dist.all_to_all_single`` with explicit, unequal split sizes
            and check the received rows, then hit the barrier.

            The k-th input split has ``group[k] + 1`` rows; every output split
            has ``rank + 1`` rows. Nothing but the barrier runs when
            ``group_id`` is ``None``.
            """
            if group_id is not None:
                n = len(group)
                rows_here = rank + 1
                send_splits = [member + 1 for member in group]
                recv_splits = [rows_here for _ in group]
                outgoing = torch.ones([sum(send_splits), n], dtype=dtype) * rank
                incoming = torch.ones([rows_here * n, n], dtype=dtype)
                # Chunk k of the result should be filled with the k-th group entry.
                wanted = torch.cat(
                    [torch.ones([rows_here, n], dtype=dtype) * src for src in group]
                )
                if cuda:
                    device = rank_to_GPU[rank][0]
                    outgoing = outgoing.cuda(device)
                    wanted = wanted.cuda(device)
                    incoming = incoming.cuda(device)
                dist.all_to_all_single(
                    incoming, outgoing, recv_splits, send_splits, group=group_id
                )
                self.assertEqual(incoming, wanted)
            self._barrier()

        def _test_all_to_all_helper(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
        ):
            """
            Run ``dist.all_to_all`` with per-peer tensor lists of unequal
            sizes and check each received tensor, then hit the barrier.

            The k-th input tensor has ``group[k] + 1`` rows filled with
            ``rank``; every output tensor has ``rank + 1`` rows. Nothing but
            the barrier runs when ``group_id`` is ``None``.
            """
            if group_id is not None:
                n = len(group)
                rows_here = rank + 1
                send_rows = [member + 1 for member in group]
                outgoing = [
                    torch.ones([rows, n], dtype=dtype) * rank for rows in send_rows
                ]
                incoming = [torch.ones([rows_here, n], dtype=dtype) for _ in group]
                # Tensor k of the result should be filled with the k-th group entry.
                wanted = [
                    torch.ones([rows_here, n], dtype=dtype) * src for src in group
                ]
                if cuda:
                    device = rank_to_GPU[rank][0]
                    outgoing = [t.cuda(device) for t in outgoing]
                    wanted = [t.cuda(device) for t in wanted]
                    incoming = [t.cuda(device) for t in incoming]
                dist.all_to_all(incoming, outgoing, group=group_id)
                for received, reference in zip(incoming, wanted):
                    self.assertEqual(received, reference)
            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        def test_all_to_all_single_equal_split(self):
            # Equal-split all_to_all_single checks on the values produced by
            # ``_init_global_test``, without GPU placement.
            members, pg, my_rank = self._init_global_test()
            self._test_all_to_all_single_equal_split_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        def test_all_to_all_single_equal_split_cuda(self):
            # Equal-split all_to_all_single checks with tensors moved to the
            # device selected from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_equal_split_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        def test_all_to_all_single_equal_split_complex(self):
            # Equal-split all_to_all_single checks with ``torch.cfloat`` tensors.
            members, pg, my_rank = self._init_global_test()
            complex_dtype = torch.cfloat
            self._test_all_to_all_single_equal_split_helper(
                members, pg, my_rank, dtype=complex_dtype
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        def test_all_to_all_single_equal_split_cuda_complex(self):
            # Equal-split all_to_all_single checks with ``torch.cfloat``
            # tensors moved to the device selected from the
            # ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_equal_split_helper(
                members,
                pg,
                my_rank,
                cuda=True,
                rank_to_GPU=gpu_map,
                dtype=torch.cfloat,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        def test_all_to_all_single_unequal_split(self):
            # Unequal-split all_to_all_single checks on the values produced by
            # ``_init_global_test``, without GPU placement.
            members, pg, my_rank = self._init_global_test()
            self._test_all_to_all_single_unequal_split_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        def test_all_to_all_single_unequal_split_cuda(self):
            # Unequal-split all_to_all_single checks with tensors moved to the
            # device selected from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_unequal_split_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        def test_all_to_all_single_unequal_split_complex(self):
            # Unequal-split all_to_all_single checks with ``torch.cfloat`` tensors.
            members, pg, my_rank = self._init_global_test()
            complex_dtype = torch.cfloat
            self._test_all_to_all_single_unequal_split_helper(
                members, pg, my_rank, dtype=complex_dtype
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        def test_all_to_all_single_unequal_split_cuda_complex(self):
            # Unequal-split all_to_all_single checks with ``torch.cfloat``
            # tensors moved to the device selected from the
            # ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_unequal_split_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map, dtype=torch.cfloat
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports all_to_all"
        )
        def test_all_to_all(self):
            # all_to_all checks on the values produced by ``_init_global_test``,
            # without GPU placement.
            members, pg, my_rank = self._init_global_test()
            self._test_all_to_all_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
        )
        @skip_if_rocm_multiprocess
        def test_all_to_all_cuda(self):
            # all_to_all checks with tensors moved to the device selected from
            # the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports all_to_all"
        )
        def test_all_to_all_complex(self):
            # all_to_all checks with ``torch.cfloat`` tensors.
            members, pg, my_rank = self._init_global_test()
            self._test_all_to_all_helper(members, pg, my_rank, dtype=torch.cfloat)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
        )
        @skip_if_rocm_multiprocess
        def test_all_to_all_cuda_complex(self):
            # all_to_all checks with ``torch.cfloat`` tensors moved to the
            # device selected from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_helper(
                members,
                pg,
                my_rank,
                cuda=True,
                rank_to_GPU=gpu_map,
                dtype=torch.cfloat,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        @skip_if_small_worldsize
        def test_all_to_all_single_equal_split_group(self):
            # Equal-split all_to_all_single checks on the values produced by
            # ``_init_group_test``.
            members, pg, my_rank = self._init_group_test()
            self._test_all_to_all_single_equal_split_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        @skip_if_small_worldsize
        def test_all_to_all_single_equal_split_group_cuda(self):
            # Equal-split all_to_all_single checks on the values produced by
            # ``_init_group_test``, with tensors moved to the device selected
            # from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_group_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_equal_split_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        @skip_if_small_worldsize
        def test_all_to_all_single_unequal_split_group(self):
            # Unequal-split all_to_all_single checks on the values produced by
            # ``_init_group_test``.
            members, pg, my_rank = self._init_group_test()
            self._test_all_to_all_single_unequal_split_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        @skip_if_small_worldsize
        def test_all_to_all_single_unequal_split_group_cuda(self):
            # Unequal-split all_to_all_single checks on a sub-group, with
            # tensors moved to the device selected from the
            # ``init_multigpu_helper`` mapping.
            #
            # The group must come from ``_init_group_test`` (as in the other
            # ``*_group*`` tests); using ``_init_global_test`` here would only
            # repeat ``test_all_to_all_single_unequal_split_cuda``.
            group, group_id, rank = self._init_group_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_unequal_split_helper(
                group,
                group_id,
                rank,
                True,
                rank_to_GPU,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports all_to_all"
        )
        @skip_if_small_worldsize
        def test_all_to_all_group(self):
            # all_to_all checks on the values produced by ``_init_group_test``.
            members, pg, my_rank = self._init_group_test()
            self._test_all_to_all_helper(members, pg, my_rank)

        # The skip reason names ``all_to_all`` (not ``all_to_all_single``)
        # because that is the collective this test exercises.
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
        )
        @skip_if_small_worldsize
        @skip_if_rocm_multiprocess
        def test_all_to_all_group_cuda(self):
            # all_to_all checks on the values produced by ``_init_group_test``,
            # with tensors moved to the device selected from the
            # ``init_multigpu_helper`` mapping.
            group, group_id, rank = self._init_group_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_helper(group, group_id, rank, True, rank_to_GPU)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        def test_all_to_all_single_equal_split_full_group(self):
            # Equal-split all_to_all_single checks on the values produced by
            # ``_init_full_group_test``.
            members, pg, my_rank = self._init_full_group_test()
            self._test_all_to_all_single_equal_split_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        def test_all_to_all_single_equal_split_full_group_cuda(self):
            # Equal-split all_to_all_single checks on the values produced by
            # ``_init_full_group_test``, with tensors moved to the device
            # selected from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_full_group_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_equal_split_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports CPU all_to_all_single"
        )
        def test_all_to_all_single_unequal_split_full_group(self):
            # Unequal-split all_to_all_single checks on the values produced by
            # ``_init_full_group_test``.
            members, pg, my_rank = self._init_full_group_test()
            self._test_all_to_all_single_unequal_split_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only Nccl supports CUDA all_to_all_single"
        )
        @skip_if_no_gpu
        def test_all_to_all_single_unequal_split_full_group_cuda(self):
            # Unequal-split all_to_all_single checks on the values produced by
            # ``_init_full_group_test``, with tensors moved to the device
            # selected from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_full_group_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_single_unequal_split_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi", "Only MPI supports all_to_all"
        )
        def test_all_to_all_full_group(self):
            # all_to_all checks on the values produced by
            # ``_init_full_group_test``.
            members, pg, my_rank = self._init_full_group_test()
            self._test_all_to_all_helper(members, pg, my_rank)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only NCCL supports CUDA all_to_all"
        )
        @skip_if_rocm_multiprocess
        def test_all_to_all_full_group_cuda(self):
            # all_to_all checks on the values produced by
            # ``_init_full_group_test``, with tensors moved to the device
            # selected from the ``init_multigpu_helper`` mapping.
            members, pg, my_rank = self._init_full_group_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_all_to_all_helper(
                members, pg, my_rank, cuda=True, rank_to_GPU=gpu_map
            )

        # BARRIER
        def _test_barrier_helper(
            self, group, group_id, rank, cuda=False, rank_to_GPU=None
        ):
            """
            Check that ``dist.barrier`` does not release waiting ranks early.

            Each entry of ``group`` takes a turn as the slow rank: it
            broadcasts a wall-clock deadline (now + ``WAIT_TIME``), sleeps
            slightly past it, and only then enters the barrier. Every other
            rank enters the barrier right after the broadcast and asserts
            that the deadline has passed once the barrier returns.

            Args:
                group: iterable of ranks; each one takes a turn as the slow rank.
                group_id: group passed to ``dist.broadcast`` / ``dist.barrier``.
                rank: rank of the calling process.
                cuda: when true, the deadline tensor is moved to
                    ``rank_to_GPU[rank][0]``.
                rank_to_GPU: mapping indexed as ``rank_to_GPU[rank][0]``; only
                    read when ``cuda`` is set.
            """
            WAIT_TIME = 0.3  # seconds

            for dest in group:
                deadline = torch.DoubleTensor(1).fill_(0.0)
                if cuda:
                    deadline = deadline.cuda(rank_to_GPU[rank][0])

                is_slow_rank = dest == rank
                if is_slow_rank:
                    deadline.fill_(time.time() + WAIT_TIME)
                dist.broadcast(deadline, dest, group_id)

                if is_slow_rank:
                    time.sleep(WAIT_TIME + 0.1)  # sleep a little bit longer
                    dist.barrier(group_id)
                    continue

                dist.barrier(group_id)
                failure_note = "destination rank: %d, my rank: %d" % (dest, rank)
                failure_note += " (if you see this failure, please report in #14554)"
                self.assertGreaterAlmostEqual(
                    float(time.time()),
                    float(deadline[0]),
                    msg=failure_note,
                )

            # Use higher timeout for the instance where the test runs
            # against a subgroup and uses a CUDA tensor for expected time.
            # The CUDA initialization for the participating processes can
            # take long enough for the barrier timeout to trigger on the
            # process that doesn't participate in the group.
            self._barrier(timeout=20)

        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
        )
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "ucc" and IS_SANDCASTLE, "Skipped internally"
        )
        def test_barrier_cuda(self):
            # Barrier check on the group from _init_global_test, with the
            # deadline tensor placed on this rank's first GPU.
            group, group_id, rank = self._init_global_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_barrier_helper(
                group, group_id, rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_if_small_worldsize
        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
        )
        def test_barrier_group_cuda(self):
            # Barrier check on the group from _init_group_test, with the
            # deadline tensor placed on this rank's first GPU.
            group, group_id, rank = self._init_group_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_barrier_helper(
                group, group_id, rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_if_small_worldsize
        @skip_if_no_gpu
        @skip_but_pass_in_sandcastle_if(
            BACKEND == "mpi", "MPI doesn't supports GPU barrier"
        )
        def test_barrier_full_group_cuda(self):
            # Barrier check on the group from _init_full_group_test, with the
            # deadline tensor placed on this rank's first GPU.
            group, group_id, rank = self._init_full_group_test()
            gpu_map = init_multigpu_helper(dist.get_world_size(), BACKEND)
            self._test_barrier_helper(
                group, group_id, rank, cuda=True, rank_to_GPU=gpu_map
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["cpu barrier"],
            f"{BACKEND} does not support CPU barrier",
        )
        def test_barrier(self):
            # CPU barrier check: forward (group, group_id, rank) from
            # _init_global_test straight to the helper.
            self._test_barrier_helper(*self._init_global_test())

        @skip_if_small_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["cpu barrier"],
            f"{BACKEND} does not support CPU barrier",
        )
        def test_barrier_group(self):
            # CPU barrier check: forward (group, group_id, rank) from
            # _init_group_test straight to the helper.
            self._test_barrier_helper(*self._init_group_test())

        @skip_but_pass_in_sandcastle_if(
            BACKEND in DistTestCases.skip_collective["cpu barrier"],
            f"{BACKEND} does not support CPU barrier",
        )
        def test_barrier_full_group(self):
            # CPU barrier check: forward (group, group_id, rank) from
            # _init_full_group_test straight to the helper.
            self._test_barrier_helper(*self._init_full_group_test())

        def _model_step(self, model):
            for param in model.parameters():
                if param.grad is not None:
                    with torch.no_grad():
                        param += param.grad
                    param.grad = None

        def _model_step_with_zero_grad(self, model):
            for param in model.parameters():
                if param.grad is not None:
                    with torch.no_grad():
                        param += param.grad
                    param.grad.requires_grad_(False)
                    param.grad.zero_()

        def _prepare_dummy_data(self, local_bs):
            # global_bs for DDP should be divisible by WORLD_SIZE
            world_size = int(os.environ["WORLD_SIZE"])
            global_bs = world_size * local_bs
            input_cpu = torch.randn(global_bs, 2)
            target = torch.randn(global_bs, 4)
            loss = nn.MSELoss()
            return global_bs, input_cpu, target, loss

        # END TO END TEST FOR DISTRIBUTEDDATAPARALLEL
        def _test_DDP_helper(
            self, model, input_var, target, loss, scale_factor=1.0, memory_format=None
        ):
            model.train()
            output = model(input_var)
            l = loss(output, target) * scale_factor
            l.backward()
            if memory_format is not None:
                self.assertTrue(output.is_contiguous(memory_format=memory_format))

        def _assert_equal_param(self, param_gpu, param_DDP):
            """Assert both parameter sequences have equal length and equal entries, pairwise."""
            self.assertEqual(len(param_gpu), len(param_DDP))
            for baseline_param, ddp_param in zip(param_gpu, param_DDP):
                self.assertEqual(baseline_param, ddp_param)

        def _test_DDP_niter(
            self,
            model_base,
            model_DDP,
            input,
            target,
            loss,
            local_bs,
            rank,
            batch_size,
            test_save,
            offset=None,
            world_size=0,
            zero_grad=False,
            memory_format=None,
            n_iter=5,
        ):
            """Train a baseline model and a DDP model side by side and compare them.

            Every iteration runs forward/backward on ``model_base`` with the
            full batch and on ``model_DDP`` with this rank's slice of the
            batch, applies a manual update step to both, asserts that their
            parameters match, and then shuffles ``input``. After the loop the
            DDP model is round-tripped through ``torch.save``/``torch.load``
            and the state dicts are compared entry by entry.

            Args:
                model_base: single-process reference model.
                model_DDP: DDP-wrapped model; ``model_DDP.module`` is compared
                    against ``model_base``.
                input: full input batch (name kept for caller compatibility
                    even though it shadows the builtin).
                target: full target batch.
                loss: loss callable taking ``(output, target)``.
                local_bs: number of samples this rank feeds to the DDP model.
                rank: rank of the calling process, used for the default offset.
                batch_size: size of the full batch; used for shuffling and for
                    the DDP loss scale.
                test_save: when true (and ``INIT_METHOD`` is file based), the
                    DDP model is saved and reloaded after the third iteration.
                offset: start index of this rank's slice; defaults to
                    ``rank * local_bs``.
                world_size: when non-zero, the DDP loss is scaled by
                    ``world_size * local_bs / batch_size``; otherwise by 1.
                zero_grad: zero gradients in place after each step instead of
                    setting them to ``None``.
                memory_format: forwarded to ``_test_DDP_helper``.
                n_iter: number of training iterations.
            """
            # None of these depend on the iteration, so compute them once
            # instead of on every pass through the loop.
            if offset is None:
                offset = rank * local_bs
            ddp_scale = world_size * local_bs / batch_size if world_size != 0 else 1
            # ``target`` is never shuffled, so this rank's slice of it is fixed.
            ddp_target = target[offset : offset + local_bs]
            apply_step = (
                self._model_step_with_zero_grad if zero_grad else self._model_step
            )

            for idx in range(n_iter):
                # single cpu/gpu training
                self._test_DDP_helper(
                    model_base, input, target, loss, memory_format=memory_format
                )

                # DDP training, DDP scatters subsets of input_cpu to nodes/GPUs
                self._test_DDP_helper(
                    model_DDP,
                    input[offset : offset + local_bs],
                    ddp_target,
                    loss,
                    ddp_scale,
                    memory_format=memory_format,
                )

                # Update weights and run a second iteration to shake out errors
                apply_step(model_base)
                apply_step(model_DDP)
                self._assert_equal_param(
                    list(model_base.parameters()), list(model_DDP.module.parameters())
                )

                # Shuffle the input so that DDP input is different
                input = input[torch.randperm(batch_size)]

                # save the model in the middle and reload
                if test_save and idx == 2 and INIT_METHOD.startswith("file://"):
                    with tempfile.NamedTemporaryFile() as tmp:
                        if sys.platform == "win32":
                            # On Windows the open handle is reused rather than
                            # reopening the file by name.
                            torch.save(model_DDP, tmp)
                            tmp.seek(0)
                            # weights_only=False as this is legacy code that saves the model
                            model_DDP = torch.load(tmp, weights_only=False)
                        else:
                            torch.save(model_DDP, tmp.name)
                            # weights_only=False as this is legacy code that saves the model
                            model_DDP = torch.load(tmp.name, weights_only=False)

            with tempfile.TemporaryFile() as tmp_file:
                torch.save(model_DDP, tmp_file)
                tmp_file.seek(0)
                # weights_only=False as this is legacy code that saves the model
                saved_model = torch.load(tmp_file, weights_only=False)

            # Build each state dict once rather than twice per key.
            live_state = model_DDP.state_dict()
            restored_state = saved_model.state_dict()
            for key, value in live_state.items():
                self.assertEqual(value, restored_state[key])

        def _test_DistributedDataParallel(
            self,
            gpu_subset,
            rank,
            output_device=None,
            gradient_as_bucket_view=False,
            static_graph=False,
            set_static_graph_twice=False,
        ):
            """End-to-end DDP check on GPU against a single-GPU baseline.

            Two copies of ``DDP_NET`` are placed on ``gpu_subset[0]``; one is
            wrapped in DDP. The DDP model is pickled and restored once up
            front, then both models are trained via ``_test_DDP_niter``.

            Args:
                gpu_subset: GPU ids for this rank; the first hosts both models
                    and the whole list is passed as ``device_ids``.
                rank: rank of the calling process.
                output_device: accepted but not used by this helper.
                gradient_as_bucket_view: forwarded to the DDP constructor.
                static_graph: forwarded to the DDP constructor.
                set_static_graph_twice: additionally call
                    ``_set_static_graph()`` on the wrapped model.
            """
            primary_gpu = gpu_subset[0]

            # cpu training setup
            template = DDP_NET

            # single gpu training setup
            model_gpu = copy.deepcopy(template)
            model_gpu.cuda(primary_gpu)

            # DDP training setup
            ddp_inner = copy.deepcopy(template)
            ddp_inner.cuda(primary_gpu)
            model_DDP = nn.parallel.DistributedDataParallel(
                ddp_inner,
                device_ids=gpu_subset,
                gradient_as_bucket_view=gradient_as_bucket_view,
                static_graph=static_graph,
            )

            if set_static_graph_twice:
                model_DDP._set_static_graph()

            # test serializable/unserializable
            on_windows = sys.platform == "win32"
            with tempfile.NamedTemporaryFile() as tmp:
                # Windows reuses the open handle; elsewhere go through the path.
                destination = tmp if on_windows else tmp.name
                torch.save(model_DDP, destination)
                if on_windows:
                    tmp.seek(0)
                # weights_only=False as this is legacy code that saves the model
                model_DDP = torch.load(destination, weights_only=False)

            # dummy data initialization
            local_bs = len(gpu_subset)
            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)

            # check two model parameters over 5 iterations
            self._test_DDP_niter(
                model_gpu,
                model_DDP,
                input_cpu.cuda(primary_gpu),
                target.cuda(primary_gpu),
                loss,
                local_bs,
                rank,
                global_bs,
                test_save=True,
            )
            self._barrier()

        def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False):
            """End-to-end DDP check on CPU against a single-process baseline.

            A baseline copy of ``DDP_NET`` and a DDP-wrapped copy are trained
            side by side through ``_test_DDP_niter`` (with in-place gradient
            zeroing and without the mid-run save/reload).

            Args:
                gradient_as_bucket_view: forwarded to the DDP constructor.

            Returns:
                The DDP-wrapped model after training.
            """
            # Run a simple end to end DDP-CPU model, use result of single node
            # model as baseline
            _group, _group_id, rank = self._init_global_test()

            # cpu training setup. Train a private copy so the module-level
            # DDP_NET is not mutated by the update steps (the GPU variant of
            # this helper deep-copies as well).
            model_base = copy.deepcopy(DDP_NET)

            # DDP-CPU training setup
            model_DDP = copy.deepcopy(model_base)
            model_DDP = nn.parallel.DistributedDataParallel(
                model_DDP, gradient_as_bucket_view=gradient_as_bucket_view
            )

            # dummy data initialization
            local_bs = 2
            global_bs, input_cpu, target, loss = self._prepare_dummy_data(local_bs)

            # check two model parameters over 5 iterations
            self._test_DDP_niter(
                model_base,
                model_DDP,
                input_cpu,
                target,
                loss,
                local_bs,
                rank,
                global_bs,
                False,
                zero_grad=True,
            )
            self._barrier()

            return model_DDP

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "nccl does not support DDP on CPU models"
        )
        def test_DistributedDataParallelCPU(self):
            # CPU DDP end-to-end check with the helper's default setting spelled out.
            self._test_DistributedDataParallelCPU(gradient_as_bucket_view=False)

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "nccl does not support DDP on CPU models"
        )
        def test_DistributedDataParallelCPU_grad_is_view(self):
            # Same CPU DDP end-to-end check, but with gradient_as_bucket_view
            # enabled on the DDP wrapper.
            self._test_DistributedDataParallelCPU(gradient_as_bucket_view=True)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        def test_DistributedDataParallel_requires_grad(self):
            # a module without gradients shouldn't be accepted
            with self.assertRaises(RuntimeError):
                nn.parallel.DistributedDataParallel(nn.Module())
            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_ddp_zero_output_features(self):
            # Constructing DDP around a module whose last layer has zero
            # output features (nn.Linear(10, 0)) must not raise. The module
            # is only wrapped, never run.
            class ToyModel(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.net1 = nn.Linear(10, 10)
                    self.relu = nn.ReLU()
                    self.net2 = nn.Linear(10, 0)

            device = self.rank
            toy = ToyModel().to(device)
            nn.parallel.DistributedDataParallel(toy, device_ids=[device])

        @skip_but_pass_in_sandcastle_if(BACKEND == "nccl", "Gloo-only test")
        def test_ddp_create_graph(self):
            # DDP must tolerate backward(create_graph=True) across several
            # iterations on a single-scalar-parameter model.
            class Model(nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.p = nn.Parameter(torch.tensor(1.0))

                def forward(self):
                    return self.p.pow(2)

            ddp_model = torch.nn.parallel.DistributedDataParallel(Model())
            for _ in range(6):
                # Verify DDP doesn't throw when ran with create_graph=True.
                # Although we do warn about potential issues, please see
                # https://github.com/pytorch/pytorch/issues/63929 for details.
                out = ddp_model()
                out.backward(create_graph=True)
                # grad tensors should require grad.
                # NOTE(review): the check below inspects param.requires_grad,
                # not param.grad.requires_grad -- confirm which one is intended.
                requires_grad_flags = [
                    param.requires_grad for param in ddp_model.parameters()
                ]
                self.assertTrue(all(requires_grad_flags))

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_DistributedDataParallel_non_default_stream(self):
            # Run many forward/backward passes of a DDP-wrapped 1x1 linear
            # layer inside a non-default CUDA stream and check, after every
            # pass, that the weight gradient seen on this rank is the
            # cross-rank average.
            stream = torch.cuda.Stream(self.rank)
            rank = self.rank
            # Neither value changes between iterations, so compute them once
            # instead of re-parsing the environment on every pass.
            world_size = int(os.environ["WORLD_SIZE"])
            expected_grad = sum(range(world_size)) / world_size
            with torch.cuda.stream(stream):
                net = torch.nn.parallel.DistributedDataParallel(
                    torch.nn.Linear(1, 1, bias=False).cuda(rank), device_ids=[rank]
                )
                for _ in range(1000):
                    # Clear gradients manually
                    grad = net.module.weight.grad
                    if grad is not None:
                        grad.requires_grad_(False)
                        grad.zero_()
                    # Forward + BW
                    batch = torch.tensor([rank]).float().cuda(rank)
                    loss = net(batch).sum()
                    loss.backward()
                    # For each worker, the gradient on the weight should be worker_rank.
                    grad = net.module.weight.grad
                    avg = grad.clone()
                    # All-reducing the gradient averages should give us the gradient
                    # average. If not, then one of the workers has not correctly
                    # written back the averaged gradient before this all-reduce call.
                    dist.all_reduce(avg)
                    avg.div_(world_size)
                    self.assertEqual(
                        avg[0, 0],
                        expected_grad,
                        msg=f"Expected gradient of {expected_grad} but got {avg} on rank {self.rank}",
                    )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["cuda"],
            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_ddp_comm_hook_logging(self):
            # The "comm_hook" entry of the DDP logging data must be absent
            # before a hook is registered and name the hook afterwards, for
            # Python hooks as well as built-in C++ hooks.
            def fresh_ddp():
                # Every scenario starts from a brand-new DDP-wrapped layer.
                return torch.nn.parallel.DistributedDataParallel(
                    torch.nn.Linear(1, 1, bias=False).cuda(self.rank),
                    device_ids=[self.rank],
                )

            python_hooks = [
                default.allreduce_hook,
                default.fp16_compress_hook,
                powerSGD.powerSGD_hook,
                powerSGD.batched_powerSGD_hook,
                quantization_hooks.quantization_pertensor_hook,
                quantization_hooks.quantization_perchannel_hook,
            ]

            builtin_hooks = [
                dist.BuiltinCommHookType.ALLREDUCE,
                dist.BuiltinCommHookType.FP16_COMPRESS,
            ]

            for py_hook in python_hooks:
                ddp_model = fresh_ddp()
                # Hook not registered yet, so should be empty
                before = ddp_model._get_ddp_logging_data()
                self.assertEqual(before.get("comm_hook"), None)
                ddp_model.register_comm_hook(None, py_hook)
                after = ddp_model._get_ddp_logging_data()
                self.assertEqual(after.get("comm_hook"), py_hook.__qualname__)

            for builtin_hook in builtin_hooks:
                ddp_model = fresh_ddp()
                # Hook not registered yet, so should be empty
                before = ddp_model._get_ddp_logging_data()
                self.assertEqual(before.get("comm_hook"), None)
                ddp_model._register_builtin_comm_hook(builtin_hook)
                after = ddp_model._get_ddp_logging_data()
                self.assertEqual(after.get("comm_hook"), str(builtin_hook))

            # No hook registered
            ddp_model = fresh_ddp()
            # Hook not registered yet, so should be empty
            before = ddp_model._get_ddp_logging_data()
            self.assertEqual(before.get("comm_hook"), None)
            # After second forward pass, hook should still be empty string
            for _ in range(2):
                ones = torch.ones(1, 1, device=self.rank)
                ddp_model(ones).sum().backward()

            after = ddp_model._get_ddp_logging_data()
            # Note: DETAIL debug mode logs DDP logging data to stdout and
            # thus accesses std::map, which fills in a default value for the
            # type if it didn't exist.
            self.assertEqual(after.get("comm_hook", ""), "")

        def _test_ddp_hook_with_optimizer_parity(
            self,
            grad_as_bucket_view,
            static_graph,
            optim_cls,
            optimize_subset,
            *functional_optim_args,
            **functional_optim_kwargs,
        ):
            """Check a fused-optimizer DDP model against a conventionally stepped one.

            For each test model, two identical DDP wrappers are built. One has
            an optimizer registered through ``_register_fused_optim`` so it
            runs as part of backward; the other is stepped explicitly with a
            regular ``optim_cls`` instance after each backward. After six
            iterations both must hold equal parameters, and the parameters
            must have moved away from their initial values.

            Args:
                grad_as_bucket_view: forwarded to both DDP constructors as
                    ``gradient_as_bucket_view``.
                static_graph: forwarded to both DDP constructors.
                optim_cls: optimizer class used for both variants.
                optimize_subset: when true, only the first parameter is
                    optimized and the rest are asserted to stay unchanged.
                *functional_optim_args: positional optimizer arguments.
                **functional_optim_kwargs: keyword optimizer arguments.
            """
            rank = self.rank
            torch.cuda.set_device(rank)
            # Seed CPU and CUDA generators with the rank.
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)
            models_to_test = [
                (LargeNet(), torch.randn(1, 1000).cuda()),
            ]
            if HAS_TORCHVISION:
                models_to_test.append(
                    (torchvision.models.resnet50(), torch.randn(1, 3, 3, 1000).cuda())
                )
            for (model, inp) in models_to_test:
                # Enable determinism in cudnn operators
                with torch.backends.cudnn.flags(
                    enabled=True, deterministic=True, benchmark=False
                ):
                    # Create DDP model that runs optimizer in fused fashion.
                    ddp_model_with_optimizer_hook = (
                        torch.nn.parallel.DistributedDataParallel(
                            copy.deepcopy(model).cuda(),
                            device_ids=[self.rank],
                            gradient_as_bucket_view=grad_as_bucket_view,
                            static_graph=static_graph,
                        )
                    )

                    # Create DDP model with no hook that does optimizer after
                    # backward.
                    ddp_model_with_no_hook = torch.nn.parallel.DistributedDataParallel(
                        copy.deepcopy(model).cuda(),
                        device_ids=[self.rank],
                        gradient_as_bucket_view=grad_as_bucket_view,
                        static_graph=static_graph,
                    )
                    hook_params = ddp_model_with_optimizer_hook.parameters()
                    no_hook_params = ddp_model_with_no_hook.parameters()
                    if optimize_subset:
                        # Restrict both optimizers to the first parameter only.
                        hook_params = list(hook_params)
                        no_hook_params = list(no_hook_params)
                        self.assertGreater(len(hook_params), 0)
                        hook_params = [hook_params[0]]
                        no_hook_params = [no_hook_params[0]]

                    # Register a fused optimizer that will run optimizer in step
                    # with allreduce.

                    if optimize_subset:
                        # API where optim_params is specified.
                        ddp_model_with_optimizer_hook._register_fused_optim(
                            optim_cls,
                            *functional_optim_args,
                            optim_params=hook_params,
                            **functional_optim_kwargs,
                        )
                    else:
                        # API where optim_params is omitted
                        ddp_model_with_optimizer_hook._register_fused_optim(
                            optim_cls,
                            *functional_optim_args,
                            **functional_optim_kwargs,
                        )

                    # Reference optimizer, stepped explicitly after backward.
                    optimizer_no_hook = optim_cls(
                        no_hook_params,
                        *functional_optim_args,
                        **functional_optim_kwargs,
                    )

                    # Verify parameters are equal initially.
                    for hook_param, allreduce_param in zip(
                        ddp_model_with_optimizer_hook.parameters(),
                        ddp_model_with_no_hook.parameters(),
                    ):
                        self.assertEqual(hook_param, allreduce_param)

                    # Save old parameters to later verify optimizer modified them.
                    opt_hook_init_params = copy.deepcopy(
                        list(ddp_model_with_optimizer_hook.parameters())
                    )

                    # Run optimizer with hook model.
                    # No explicit step() here: the registered fused optimizer
                    # is expected to update the parameters during backward.
                    for _ in range(6):
                        ddp_model_with_optimizer_hook.zero_grad()
                        out = ddp_model_with_optimizer_hook(inp)
                        loss = out.sum()
                        loss.backward()

                    dist.barrier()

                    # Run regular model.
                    for _ in range(6):
                        ddp_model_with_no_hook.zero_grad()
                        out = ddp_model_with_no_hook(inp)
                        loss = out.sum()
                        loss.backward()
                        optimizer_no_hook.step()

                    dist.barrier()

                    # Now verify parameters are equal.
                    for hook_param, allreduce_param in zip(
                        ddp_model_with_optimizer_hook.parameters(),
                        ddp_model_with_no_hook.parameters(),
                    ):
                        self.assertEqual(hook_param, allreduce_param)

                    # Verify optimizer modified appropriate parameter set,
                    # otherwise they'd be trivially equal above.
                    if optimize_subset:
                        self.assertNotEqual(
                            opt_hook_init_params[0],
                            next(iter(ddp_model_with_optimizer_hook.parameters())),
                        )
                        # Untouched params should be equal
                        self.assertEqual(
                            opt_hook_init_params[1:],
                            list(ddp_model_with_optimizer_hook.parameters())[1:],
                        )
                    else:
                        self.assertNotEqual(
                            opt_hook_init_params,
                            list(ddp_model_with_optimizer_hook.parameters()),
                        )
                    dist.barrier()

        """
        # Commenting out the following 3 tests as they cause Sandcastle jobs to fail
        # Failure signature:
        # AttributeError: type object 'TestDistBackendWithSpawn' has no attribute 'test_ddp_hook_with_optimizer_parity_adamw

        from torch.testing._internal.common_utils import parametrize

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl" or BACKEND == "ucc",
            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
        )
        @skip_if_lt_x_gpu(2)
        @parametrize("grad_as_bucket_view", [True, False])
        @parametrize("static_graph", [True, False])
        @parametrize("optimize_subset", [True, False])
        def test_ddp_hook_with_optimizer_parity_adamw(
            self,
            grad_as_bucket_view,
            static_graph,
            optimize_subset,
        ):
            adamw_lr = 1e-2
            adamw_betas = (0.9, 0.99)
            adamw_eps = 1e-6
            self._test_ddp_hook_with_optimizer_parity(
                grad_as_bucket_view,
                static_graph,
                torch.optim.AdamW,
                optimize_subset,
                adamw_lr,
                betas=adamw_betas,
                eps=adamw_eps,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl" or BACKEND == "ucc",
            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
        )
        @skip_if_lt_x_gpu(2)
        @parametrize("optimize_subset", [True, False])
        def test_ddp_hook_with_optimizer_parity_adam(self, optimize_subset):
            adam_lr = 1e-2
            adam_betas = (0.9, 0.99)
            adam_eps = 1e-6
            self._test_ddp_hook_with_optimizer_parity(
                True,  # grad as bucket view
                False,  # static graph
                torch.optim.Adam,
                optimize_subset,
                adam_lr,
                betas=adam_betas,
                eps=adam_eps,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl" or BACKEND == "ucc",
            "Issues with async error handling, see https://github.com/pytorch/pytorch/issues/73259",
        )
        @skip_if_lt_x_gpu(2)
        @parametrize("optimize_subset", [True, False])
        def test_ddp_hook_with_optimizer_parity_sgd(self, optimize_subset):
            sgd_lr = 1e-2
            sgd_momentum = 0.9
            sgd_weight_decay = 0.01
            # Not testing grad_as_bucket_view and static_graph as they are
            # tested in AdamW test above.
            self._test_ddp_hook_with_optimizer_parity(
                True,  # grad as bucket view
                False,  # static_graph
                torch.optim.SGD,
                optimize_subset,
                sgd_lr,
                momentum=sgd_momentum,
                weight_decay=sgd_weight_decay,
            )
        """

        @skip_if_lt_x_gpu(2)
        def test_get_data_parallel_params(self):
            # Parameters marked as ignored must not be reported by
            # DistributedDataParallel._get_data_parallel_params.
            ddp_cls = torch.nn.parallel.DistributedDataParallel

            torch.cuda.set_device(self.rank)
            model = TwoLinLayerNet().cuda()
            # Parameters to ignore are in the format {module_name}.{param_name}
            params_to_ignore = ["a.weight"]
            ddp_cls._set_params_and_buffers_to_ignore_for_model(model, params_to_ignore)
            # The wrapper itself is not needed afterwards; only its
            # construction matters here.
            ddp_cls(model, device_ids=[self.rank])

            dp_params = ddp_cls._get_data_parallel_params(model, named_params=True)
            # NOTE(review): the names are enumerated from the unwrapped
            # ``model``; confirm they really carry the "module." prefix used
            # below, otherwise this inequality can never fail.
            for name, _ in dp_params:
                self.assertNotEqual(f"module.{params_to_ignore[0]}", name)

            # test named_params=False, just check if returns the expected
            # no of parameters.
            num_ddp_params = len(list(model.parameters())) - 1
            dp_params = ddp_cls._get_data_parallel_params(model, named_params=False)
            count = sum(1 for _ in dp_params)
            self.assertEqual(count, num_ddp_params)

        def _test_ddp_apply_optim_in_backward(
            self,
            optim_cls,
            optim_kwargs,
            init_before,
            gradient_as_bucket_view=True,
        ):
            """Check optimizer-in-backward DDP against DDP with an explicit optimizer step.

            For each test model a reference DDP model is stepped with a regular
            ``optim_cls`` optimizer, while a deep copy has the same optimizer
            installed via ``_apply_optimizer_in_backward``. After every
            iteration both must hold equal parameters, and the in-backward
            model must have no gradients left.

            Args:
                optim_cls: optimizer class used for both variants.
                optim_kwargs: keyword arguments for the optimizer.
                init_before: install the in-backward optimizer before (true)
                    or after (false) wrapping the copy in DDP.
                gradient_as_bucket_view: forwarded to both DDP constructors.
            """
            # Need to seed to ensure inputs are unique across rank. Otherwise,
            # allreduce won't have any effect.
            torch.manual_seed(self.rank)
            torch.cuda.manual_seed(self.rank)
            torch.cuda.set_device(self.rank)

            def install_optim_in_backward(module):
                # Attach the optimizer so that it runs as part of backward.
                _apply_optimizer_in_backward(
                    optimizer_class=optim_cls,
                    params=module.parameters(),
                    optimizer_kwargs=optim_kwargs,
                )

            def wrap_in_ddp(module):
                return nn.parallel.DistributedDataParallel(
                    module,
                    device_ids=[self.rank],
                    gradient_as_bucket_view=gradient_as_bucket_view,
                )

            # Test a simple linear as well as a ResNet model.
            models_to_test = [
                nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda()
            ]
            if HAS_TORCHVISION:
                models_to_test.append(torchvision.models.resnet50().cuda())

            for j, base_model in enumerate(models_to_test):
                model_optim_in_bwd = copy.deepcopy(base_model)
                model = wrap_in_ddp(base_model)
                optim = optim_cls(model.parameters(), **optim_kwargs)
                if init_before:
                    install_optim_in_backward(model_optim_in_bwd)
                model_optim_in_bwd = wrap_in_ddp(model_optim_in_bwd)
                if not init_before:
                    install_optim_in_backward(model_optim_in_bwd)

                for p1, p2 in zip(model.parameters(), model_optim_in_bwd.parameters()):
                    self.assertEqual(p1, p2, "Parameters not initially equal!")

                # The second entry (index 1) is the ResNet and gets an
                # image-shaped input; the first gets a (10, 3) batch.
                input_shape = (1, 3, 1000, 1000) if j == 1 else (10, 3)

                # Enable determinism in cudnn operators
                with torch.backends.cudnn.flags(
                    enabled=True, deterministic=True, benchmark=False
                ):
                    for i in range(8):
                        inp = torch.randn(*input_shape, device="cuda")
                        model(inp).sum().backward()
                        optim.step()
                        # runs optimizer as well
                        model_optim_in_bwd(inp).sum().backward()
                        param_pairs = zip(
                            model.parameters(), model_optim_in_bwd.parameters()
                        )
                        for p1, p2 in param_pairs:
                            self.assertEqual(
                                p1, p2, f"Params not equal at iteration {i}"
                            )
                            self.assertTrue(
                                p2.grad is None,
                                f"Optim in backward grad is not None at {i}",
                            )

                        # set_to_none for regular optimizer to match in backward
                        # case.
                        optim.zero_grad(set_to_none=True)

        @skip_if_lt_x_gpu(2)
        def test_ddp_apply_optim_in_backward(self):
            # Cover SGD and Adam, each with the in-backward optimizer
            # installed before and after DDP wrapping.
            for optim_cls, init_before in itertools.product(
                [torch.optim.SGD, torch.optim.Adam], [True, False]
            ):
                # Label the subtest with both varying parameters so a failure
                # report identifies the exact combination.
                with self.subTest(optim_cls=optim_cls, init_before=init_before):
                    self._test_ddp_apply_optim_in_backward(
                        optim_cls=optim_cls,
                        optim_kwargs={"lr": 0.03},
                        init_before=init_before,
                    )

        @skip_if_lt_x_gpu(2)
        def test_ddp_apply_optim_in_backward_grad_as_bucket_view_false(self):
            """Run the optimizer-in-backward parity check with bucket views disabled.

            SGD (lr 0.03) is exercised for both values of ``init_before`` while
            ``gradient_as_bucket_view`` is forced to ``False``.
            """
            for init_before in [True, False]:
                # Use a sub-test, like the sibling tests do, so a failure names
                # the ``init_before`` value that triggered it.
                with self.subTest(init_before=init_before):
                    self._test_ddp_apply_optim_in_backward(
                        optim_cls=torch.optim.SGD,
                        optim_kwargs={"lr": 0.03},
                        init_before=init_before,
                        gradient_as_bucket_view=False,
                    )

        @skip_if_lt_x_gpu(2)
        def test_ddp_apply_optim_in_backward_ignored_params(self):
            """Check optimizer-in-backward when DDP is told to ignore ``a.weight``.

            The in-backward SGD optimizer is registered either before or after
            the DDP wrapper is built. After one backward pass the models of all
            ranks are gathered and compared against rank 0: ``a.weight`` must
            differ, every other parameter must match.
            """
            torch.cuda.set_device(self.rank)
            ddp_cls = torch.nn.parallel.DistributedDataParallel

            def register_sgd_in_backward(module):
                # Attach an SGD optimizer that steps during the backward pass.
                _apply_optimizer_in_backward(
                    optimizer_class=torch.optim.SGD,
                    params=module.parameters(),
                    optimizer_kwargs={"lr": 0.03},
                )

            for init_before in (True, False):
                with self.subTest(init_before=init_before):
                    # Seed per rank so that local gradients differ across ranks.
                    torch.manual_seed(self.rank)
                    torch.cuda.manual_seed(self.rank)
                    model = TwoLinLayerNet()
                    # Ignored names use the {module_name}.{param_name} format.
                    ddp_cls._set_params_and_buffers_to_ignore_for_model(
                        model, ["a.weight"]
                    )

                    if init_before:
                        register_sgd_in_backward(model)
                    net = ddp_cls(model.cuda(self.rank), device_ids=[self.rank])
                    if not init_before:
                        register_sgd_in_backward(model)

                    a, b = net(torch.randn(1, 10))
                    (a.transpose(0, 1) @ b).sum().backward()

                    # a.weight skipped allreduce, so its update came from the
                    # local gradient and is expected to differ between ranks;
                    # all remaining parameters are expected to agree.
                    gathered = [None] * dist.get_world_size()
                    dist.all_gather_object(gathered, model)
                    rank0_model = gathered[0]
                    for other in gathered[1:]:
                        self.assertNotEqual(rank0_model.a.weight, other.a.weight)
                        self.assertEqual(
                            list(rank0_model.b.parameters()),
                            list(other.b.parameters()),
                        )
                        self.assertEqual(rank0_model.a.bias, other.a.bias)

        def _get_fp16_config(self) -> _MixedPrecision:
            return _MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )

        @skip_if_lt_x_gpu(2)
        def test_ddp_native_mixed_precision_ignored_params(self):
            """Check that ignored params/buffers get no mixed-precision attributes.

            ``a.weight`` and an extra buffer are marked as ignored before the
            model is wrapped in DDP with the fp16 config. Ignored tensors must
            lack ``_mp_param``/``_fp_param``; all other tensors must carry both.
            """
            rank = self.rank
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)
            torch.cuda.set_device(rank)

            ddp_cls = torch.nn.parallel.DistributedDataParallel
            model = TwoLinLayerNet()
            model.register_buffer("buffer", torch.ones(5))
            # Ignored names use the {module_name}.{param_name} format.
            ignored_names = ["a.weight", "buffer"]
            ddp_cls._set_params_and_buffers_to_ignore_for_model(
                model,
                ignored_names,
            )
            mp_config = self._get_fp16_config()
            net = ddp_cls(
                model.to(rank),
                device_ids=[rank],
                mixed_precision=mp_config,
                gradient_as_bucket_view=True,
            )

            # Names reported by the DDP wrapper carry a "module." prefix.
            prefixed_ignored = [f"module.{name}" for name in ignored_names]
            seen_ignored = 0
            named_tensors = itertools.chain(
                net.named_parameters(), net.named_buffers()
            )
            for name, tensor in named_tensors:
                if name not in prefixed_ignored:
                    self.assertEqual(mp_config.param_dtype, tensor._mp_param.dtype)
                    self.assertEqual(torch.float32, tensor._fp_param.dtype)
                    continue
                # Ignored tensors should not have _mp_param or _fp_param fields.
                seen_ignored += 1
                self.assertFalse(hasattr(tensor, '_mp_param'))
                self.assertFalse(hasattr(tensor, '_fp_param'))

            # Every ignored name must have been encountered exactly once.
            self.assertEqual(len(prefixed_ignored), seen_ignored)

        def _test_ddp_native_mixed_precision(
            self, gradient_as_bucket_view, set_grad_to_none
        ):
            """Train a small DDP model under the fp16 mixed-precision config.

            Inside ``forward`` the linear layer's parameters, the buffer and the
            input are checked to be in the configured low-precision dtypes.
            After each backward pass the parameters and gradients seen from
            outside are checked to be float32, and gradients are gathered to
            confirm they are identical on every rank.

            Args:
                gradient_as_bucket_view: forwarded to the DDP constructor.
                set_grad_to_none: forwarded to ``net.zero_grad(set_to_none=...)``
                    at the end of every iteration.
            """
            rank = self.rank
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)
            torch.cuda.set_device(rank)
            inp = torch.randn(10, 1)
            mp_config = self._get_fp16_config()

            class MyModel(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.m = torch.nn.Linear(1, 5)
                    self.register_buffer('buffer', torch.randn(1, 2))
                    # Frozen parameter: it never receives a gradient.
                    self.p = torch.nn.Parameter(
                        torch.randn(10, 5), requires_grad=False
                    )

                # `self_` is the module; `self` is the enclosing test case and
                # is used for the assertions below.
                def forward(self_, x):  # noqa: B902
                    params = self_.m.parameters()
                    for p in params:
                        self.assertEqual(mp_config.param_dtype, p.dtype)

                    self.assertEqual(self_.buffer.dtype, mp_config.buffer_dtype)

                    self.assertEqual(mp_config.param_dtype, x.dtype)
                    return self_.m(x) + self_.p

            m = MyModel()

            net = torch.nn.parallel.DistributedDataParallel(
                m.to(rank),
                device_ids=[rank],
                mixed_precision=mp_config,
                gradient_as_bucket_view=gradient_as_bucket_view,
            )
            # Buffers are casted in constructor.
            self.assertEqual(net.module.buffer.dtype, mp_config.buffer_dtype)
            # Each param should have an mp_param in the lower precision, and
            # an fp_param in the higher precision.
            for p in net.parameters():
                self.assertEqual(mp_config.param_dtype, p._mp_param.dtype)
                self.assertEqual(torch.float32, p._fp_param.dtype)

            for _ in range(6):
                loss = net(inp).sum()
                loss.backward()
                # Verify gradient synchronization and params and grads are fp32.
                for n, param in net.named_parameters():
                    self.assertEqual(param.dtype, torch.float32)
                    if param.grad is None:
                        # Only the frozen parameter may lack a gradient. Use a
                        # test assertion rather than a bare `assert`, which is
                        # stripped under `python -O` and carries no message.
                        self.assertEqual(
                            n,
                            "module.p",
                            msg=f"Parameter {n} unexpectedly has no gradient",
                        )
                    else:
                        self.assertEqual(param.grad.dtype, torch.float32)
                        tensor_list = [
                            torch.zeros_like(param.grad)
                            for _ in range(dist.get_world_size(net.process_group))
                        ]
                        dist.all_gather(tensor_list, param.grad)
                        g, rest = tensor_list[0], tensor_list[1:]
                        self.assertEqual(g.dtype, torch.float32)
                        for g_ in rest:
                            self.assertEqual(g_.dtype, torch.float32)
                            self.assertEqual(g, g_)
                net.zero_grad(set_to_none=set_grad_to_none)

        @skip_if_lt_x_gpu(2)
        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_no_set_grad_none(self):
            """Mixed-precision check: bucket views off, ``zero_grad(set_to_none=False)``."""
            options = {"gradient_as_bucket_view": False, "set_grad_to_none": False}
            self._test_ddp_native_mixed_precision(**options)

        @skip_if_lt_x_gpu(2)
        def test_ddp_native_mixed_precision_grad_as_bucket_view_no_set_grad_none(self):
            """Mixed-precision check: bucket views on, ``zero_grad(set_to_none=False)``."""
            options = {"gradient_as_bucket_view": True, "set_grad_to_none": False}
            self._test_ddp_native_mixed_precision(**options)

        @skip_if_lt_x_gpu(2)
        def test_ddp_native_mixed_precision_grad_as_bucket_view_set_grad_to_none(self):
            """Mixed-precision check: bucket views on, ``zero_grad(set_to_none=True)``."""
            options = {"gradient_as_bucket_view": True, "set_grad_to_none": True}
            self._test_ddp_native_mixed_precision(**options)

        @skip_if_lt_x_gpu(2)
        def test_ddp_native_mixed_precision_no_grad_as_bucket_view_set_grad_to_none(self):
            """Mixed-precision check: bucket views off, ``zero_grad(set_to_none=True)``."""
            # The test name asks for *no* bucket views. Passing True here would
            # merely repeat the grad_as_bucket_view_set_grad_to_none test and
            # leave the (False, True) combination uncovered.
            self._test_ddp_native_mixed_precision(
                gradient_as_bucket_view=False, set_grad_to_none=True
            )

        def _test_ddp_hook_parity(self, state, hook, num_validated_iters=100):
            """Compare gradients of a DDP model using ``hook`` with a hook-free one.

            Two DDP replicas of the same ``Linear(1, 5)`` are trained for 100
            iterations on a batch holding this worker's rank. During the first
            ``num_validated_iters`` iterations the hooked model's weight
            gradient is checked against both the expected average of ranks and
            the gradient of the hook-free model.

            Args:
                state: hook state; either an object exposing ``process_group``
                    or the process group itself (which may be ``None``).
                hook: communication hook registered on one of the two models.
                num_validated_iters: number of leading iterations to validate.
            """
            rank = self.rank
            base_module = torch.nn.Linear(1, 5)
            # Fall back to `state` itself when it has no `process_group`.
            process_group = getattr(state, "process_group", state)

            def build_replica():
                return torch.nn.parallel.DistributedDataParallel(
                    copy.deepcopy(base_module).to(rank),
                    device_ids=[rank],
                    process_group=process_group,
                )

            net_with_hook = build_replica()
            net_with_hook.register_comm_hook(state=state, hook=hook)
            net_without_hook = build_replica()

            for iteration in range(100):
                # Clear gradients manually.
                stale_grads = (
                    net_without_hook.module.weight.grad,
                    net_with_hook.module.weight.grad,
                )
                for stale in stale_grads:
                    if stale is None:
                        continue
                    stale.requires_grad_(False)
                    stale.zero_()

                # Forward + backward on the hook-free model.
                batch = torch.tensor([rank]).float().cuda(rank)
                net_without_hook(batch).sum().backward()
                # For each worker, the gradient on the weight should be worker_rank.
                avg = net_without_hook.module.weight.grad.clone()
                expected_grad = (
                    sum(range(dist.get_world_size())) / dist.get_world_size()
                )

                # Same step on the hooked model.
                net_with_hook(batch).sum().backward()
                avg_hook = net_with_hook.module.weight.grad.clone()

                if iteration >= num_validated_iters:
                    continue
                # Verify hook grad with expected.
                self.assertEqual(
                    avg_hook[0, 0].item(),
                    expected_grad,
                    msg=f"Expected hook grad of {expected_grad} but got {avg_hook[0, 0]}",
                )
                # Verify hook grad with vanilla allreduce
                self.assertEqual(
                    avg_hook[0, 0],
                    avg[0, 0],
                    msg=f"Expected hook grad to be close to allreduce {avg[0, 0]}, but got {avg_hook[0, 0]}",
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["cuda"],
            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_ddp_hook_parity_allreduce(self):
            """Run the hook parity check for the default allreduce hook with no state."""
            allreduce = default.allreduce_hook
            self._test_ddp_hook_parity(state=None, hook=allreduce)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["cuda"],
            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_ddp_hook_parity_allreduce_process_group(self):
            """Run the allreduce hook parity check with an explicit process group."""
            # The same process group is handed to both DDP and the comm. hook.
            num_ranks = dist.get_world_size()
            rank_to_GPU = init_multigpu_helper(num_ranks, BACKEND)
            # First GPU assigned to each rank.
            first_gpus = []
            for r in range(num_ranks):
                first_gpus.append(rank_to_GPU[int(r)][0])
            group = torch.distributed.new_group(first_gpus)
            self._test_ddp_hook_parity(state=group, hook=default.allreduce_hook)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["cuda"],
            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_ddp_hook_parity_powerSGD(self):
            """Run the hook parity check for PowerSGD with and without warm start."""
            for warm_start in (True, False):
                # A fresh state per configuration: rank-1 approximation that
                # starts at iteration 2.
                state_kwargs = {
                    "process_group": None,
                    "matrix_approximation_rank": 1,
                    "start_powerSGD_iter": 2,
                    "warm_start": warm_start,
                }
                self._test_ddp_hook_parity(
                    state=powerSGD.PowerSGDState(**state_kwargs),
                    hook=powerSGD.powerSGD_hook,
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["cuda"],
            f"The {BACKEND} backend does not support DDP communication hook on CUDA devices",
        )
        @skip_but_pass_in_sandcastle_if(
            NO_MULTIPROCESSING_SPAWN,
            "Disabled for environments that \
                         don't support multiprocessing with spawn start method",
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_ddp_hook_parity_post_localSGD(self):
            """Run the hook parity check for several post-localSGD configurations."""
            make_state = post_localSGD.PostLocalSGDState

            def run_parity(state, **parity_kwargs):
                self._test_ddp_hook_parity(
                    state=state,
                    hook=post_localSGD.post_localSGD_hook,
                    **parity_kwargs,
                )

            # Local SGD starts at iteration 10, but because it runs on the
            # global process group, post-LocalSGD still allreduces gradients
            # globally for the remaining iterations.
            run_parity(
                make_state(
                    process_group=None,
                    subgroup=dist.group.WORLD,
                    start_localSGD_iter=10,
                )
            )

            # Only validate the warmup iterations before local SGD is applied,
            # because when `post_local_gradient_allreduce` is disabled, the
            # gradients will not be synchronized at all.
            # Note that in practice a model averager has to be applied to run
            # model averaging, so local gradient averaging is not necessary.
            start_localSGD_iter = 10
            run_parity(
                make_state(
                    process_group=None,
                    subgroup=dist.group.WORLD,
                    start_localSGD_iter=start_localSGD_iter,
                    post_local_gradient_allreduce=False,
                ),
                num_validated_iters=start_localSGD_iter,
            )

            # When `subgroup` is None, it is equivalent to the subgroup on each
            # node. For this single-node test environment, the intra-node
            # process group is equivalent to the global process group.
            if self.world_size == dist.get_world_size():
                run_parity(
                    make_state(
                        process_group=None, subgroup=None, start_localSGD_iter=10
                    )
                )

            # Local SGD would start later than the total number of 100
            # iterations, so it never actually executes and no subgroup needs
            # to be provided for this case.
            run_parity(
                make_state(
                    process_group=None, subgroup=None, start_localSGD_iter=1000
                )
            )

        def _prepare_single_device_module(
            self,
            rank,
            process_group,
            devices,
            device_ids,
            global_batch_size,
            gradient_as_bucket_view=False,
        ):
            """Create a ``Net``, a DDP copy of it and random data on one device.

            The device is ``devices[0]`` when ``devices`` is non-empty, otherwise
            ``cuda:<rank>``.

            Returns:
                Tuple ``(model, ddp_model, input, target)``; ``input`` has shape
                ``(global_batch_size, 2)`` and ``target`` has shape
                ``(global_batch_size, 4)``.
            """
            model = Net()
            if devices:
                device = devices[0]
            else:
                device = torch.device("cuda:%d" % rank)

            # DDP wraps an independent copy so the local model stays a baseline.
            replica = copy.deepcopy(model).to(device)
            ddp_model = DistributedDataParallel(
                replica,
                device_ids=device_ids,
                process_group=process_group,
                bucket_cap_mb=0.001,
                gradient_as_bucket_view=gradient_as_bucket_view,
            )

            model.to(device)

            batch_input = torch.randn(global_batch_size, 2).to(device)
            batch_target = torch.randn(global_batch_size, 4).to(device)

            return model, ddp_model, batch_input, batch_target

        def _prepare_cpu_module(
            self,
            process_group,
            global_batch_size,
            gradient_as_bucket_view=False,
        ):
            """Create a ``Net``, a DDP copy of it and random data without moving to a device.

            Returns:
                Tuple ``(model, ddp_model, input, target)``; ``input`` has shape
                ``(global_batch_size, 2)`` and ``target`` has shape
                ``(global_batch_size, 4)``.
            """
            model = Net()
            # DDP wraps an independent copy so the local model stays a baseline.
            replica = copy.deepcopy(model)
            ddp_model = DistributedDataParallel(
                replica,
                process_group=process_group,
                bucket_cap_mb=0.001,
                gradient_as_bucket_view=gradient_as_bucket_view,
            )
            batch_input = torch.randn(global_batch_size, 2)
            batch_target = torch.randn(global_batch_size, 4)
            return model, ddp_model, batch_input, batch_target

        def _test_accumulate_gradients_no_sync(
            self, num_iters=2, ddp_comm_hook=None, gradient_as_bucket_view=False
        ):
            """
            This is the recommended way to implement accumulate grads.
            If ``ddp_comm_hook`` input was specified, it will also register that hook
            to the ``ddp_model``. The hook fed into this function should not change
            the resulting gradients.

            On even iterations the DDP model steps inside ``no_sync()`` and its
            gradients are expected to differ from the local model's; on odd
            iterations it steps normally and the gradients are expected to match.

            Args:
                num_iters: number of training iterations to run.
                ddp_comm_hook: optional communication hook registered on the
                    DDP model with the global group as its state.
                gradient_as_bucket_view: forwarded to the module factory.

            Raises:
                ValueError: if ``BACKEND`` is not one of mpi, gloo or nccl.
            """
            _group, group_id, rank = self._init_global_test()
            world_size = get_world_size()

            # FIXME: Add testing for gloo/CUDA
            if BACKEND == "mpi" or BACKEND == "gloo":
                global_batch_size = world_size
                local_batch_size = 1
                model, ddp_model, input, target = self._prepare_cpu_module(
                    group_id, global_batch_size, gradient_as_bucket_view
                )
            elif BACKEND == "nccl":
                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
                # Use only the first GPU assigned to this rank.
                int_devices = rank_to_GPU[rank][:1]
                devices = [torch.device("cuda:" + str(i)) for i in int_devices]
                global_batch_size = world_size
                local_batch_size = len(devices)
                model, ddp_model, input, target = self._prepare_single_device_module(
                    rank,
                    group_id,
                    devices,
                    devices,
                    global_batch_size,
                    gradient_as_bucket_view,
                )
            else:
                # Fail with a clear message instead of an UnboundLocalError on
                # `ddp_model` further down.
                raise ValueError(
                    f"_test_accumulate_gradients_no_sync does not support the {BACKEND} backend"
                )

            if ddp_comm_hook is not None:
                ddp_model.register_comm_hook(group_id, ddp_comm_hook)

            def step_model(model, input, target):
                model.train()
                output = model(input)
                loss = F.mse_loss(output, target.to(output.device))
                loss.backward()

            # ensure accumulate grads works with no_grad => no grads are accumulated.
            with torch.no_grad():
                with ddp_model.no_sync():
                    ddp_model.train()
                    ddp_model(input)

            # check two model parameters over num_iters iterations
            for iteration in range(num_iters):
                # The local model sees the full global batch.
                step_model(model, input, target)

                # The DDP model sees only this rank's slice of the batch.
                ddp_input = input[
                    rank * local_batch_size : (rank + 1) * local_batch_size
                ]
                ddp_target = target[
                    rank * local_batch_size : (rank + 1) * local_batch_size
                ]

                if iteration % 2 == 0:
                    # accumulate grads locally
                    with ddp_model.no_sync():
                        step_model(ddp_model, ddp_input, ddp_target)
                else:
                    # sync grads
                    step_model(ddp_model, ddp_input, ddp_target)

                for i, j in zip(model.parameters(), ddp_model.parameters()):
                    if not i.requires_grad:
                        continue
                    if iteration % 2 == 0:
                        self.assertNotEqual(i.grad, j.grad)
                    else:
                        self.assertEqual(i.grad, j.grad)

                # Shuffle the input so that DDP input is different
                torch.manual_seed(1337 + iteration)
                input = input[torch.randperm(global_batch_size)]

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
            "get_future is only supported on mpi, nccl and gloo",
        )
        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
        def test_accumulate_gradients_no_sync(self):
            """
            Exercises _test_accumulate_gradients_no_sync with every argument
            left at its default value.
            """
            run_check = self._test_accumulate_gradients_no_sync
            run_check()

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
            "get_future is only supported on mpi, nccl and gloo",
        )
        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
        def test_accumulate_gradients_no_sync_grad_is_view(self):
            """
            Exercises _test_accumulate_gradients_no_sync with
            ``gradient_as_bucket_view=True`` and all other arguments at their
            defaults.
            """
            run_check = self._test_accumulate_gradients_no_sync
            run_check(gradient_as_bucket_view=True)

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
            "get_future is only supported on mpi, nccl and gloo",
        )
        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
        def test_accumulate_gradients_no_sync_allreduce_hook(self):
            """
            Runs four iterations of _test_accumulate_gradients_no_sync with a
            hand-written allreduce hook, validating that the future's result is
            what ends up as the gradients in the reducer.
            """

            num_ranks = get_world_size()

            def allreduce_hook(
                group_id: object, bucket: dist.GradBucket
            ) -> torch.futures.Future[torch.Tensor]:
                def first_tensor(fut):
                    # The allreduce future resolves to a list; return its only entry.
                    return fut.value()[0]

                # Pre-divide so that the summing allreduce yields an average.
                scaled = bucket.buffer() / num_ranks
                work = group_id.allreduce([scaled])
                return work.get_future().then(first_tensor)

            self._test_accumulate_gradients_no_sync(
                num_iters=4, ddp_comm_hook=allreduce_hook
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
            "get_future is only supported on mpi, nccl and gloo",
        )
        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
        def test_accumulate_gradients_no_sync_allreduce_with_then_hook(self):
            """
            Runs four iterations of _test_accumulate_gradients_no_sync with an
            allreduce hook that chains two ``then`` callbacks: the first doubles
            the allreduced tensor, the second divides it by 2 * world_size.
            Validates that the final value is what the reducer uses as gradients.
            """

            num_ranks = get_world_size()

            def allreduce_with_then_hook(
                group_id: object, bucket: dist.GradBucket
            ) -> torch.futures.Future[torch.Tensor]:
                def double(prev):
                    # First stage: take the single reduced tensor and multiply by 2.
                    return 2 * prev.wait()[0]

                def normalize(prev):
                    # Second stage: undo the doubling and average over ranks.
                    return prev.wait() / (2 * num_ranks)

                reduced = group_id.allreduce([bucket.buffer()]).get_future()
                doubled = reduced.then(double)
                return doubled.then(normalize)

            self._test_accumulate_gradients_no_sync(
                num_iters=4, ddp_comm_hook=allreduce_with_then_hook
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND != "mpi" and BACKEND != "nccl" and BACKEND != "gloo",
            "get_future is only supported on mpi, nccl and gloo",
        )
        @nccl_skip_if_lt_x_gpu(BACKEND, 2)
        def test_get_future(self):
            """Chain two callbacks on an allreduce future and check the final tensor.

            Every rank contributes a tensor built with value 2; the reduced
            result is tripled and then incremented by one.
            """

            def triple(fut):
                return [t * 3 for t in fut.wait()]

            def plus_one(fut):
                return [t + 1 for t in fut.wait()]

            group, group_id, rank = self._init_global_test()
            tensor = _build_tensor(3, 2)
            if BACKEND == "nccl":
                # Move the tensor to this rank's first GPU.
                rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
                tensor = tensor.to(rank_to_GPU[rank][0])

            reduced = group_id.allreduce([tensor]).get_future()
            res = reduced.then(triple).then(plus_one).wait()

            # Allreduce of 2 over len(group) ranks, times 3, plus 1.
            expected = _build_tensor(3, 2 * len(group) * 3 + 1)
            self.assertEqual(res[0], expected)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel(self):
            """Run the end-to-end DDP check over several constructor configurations.

            For every combination of ``gradient_as_bucket_view`` and
            ``static_graph`` four variants are run in order: plain, static graph
            set twice, explicit ``output_device``, and ``torch.device`` objects
            as the GPU subset.
            """
            _group, _group_id, rank = self._init_global_test()
            rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND)
            gpus = list(rank_to_GPU[rank])

            for use_bucket_view in (False, True):
                for static_graph in (False, True):
                    shared = {
                        "rank": rank,
                        "gradient_as_bucket_view": use_bucket_view,
                        "static_graph": static_graph,
                    }
                    # Same GPUs expressed as torch.device objects.
                    gpus_list = [torch.device("cuda:" + str(i)) for i in gpus]
                    variants = (
                        # baseline
                        {"gpu_subset": gpus},
                        # test set static graph twice
                        {"gpu_subset": gpus, "set_static_graph_twice": True},
                        # test output_device
                        {"gpu_subset": gpus, "output_device": torch.device("cuda")},
                        # test device_ids
                        {
                            "gpu_subset": gpus_list,
                            "output_device": torch.device("cuda"),
                        },
                    )
                    for variant in variants:
                        self._test_DistributedDataParallel(**variant, **shared)

        def _test_DistributedDataParallel_with_amp(self, grad_is_view=False):
            """Train a DDP copy of ``DDP_NET`` for 20 steps with autocast and GradScaler.

            Gradients must be ``None`` before training and present, finite and
            non-NaN after every backward pass.

            Args:
                grad_is_view: forwarded to DDP as ``gradient_as_bucket_view``.

            Returns:
                The trained DDP-wrapped model.
            """
            torch.manual_seed(31415)
            # Model and optimizer are created in default precision.
            model = copy.deepcopy(DDP_NET).cuda()
            optimizer = torch.optim.SGD(model.parameters(), lr=0.03)

            # A single GradScaler is reused for the whole run.
            scaler = GradScaler()

            ddp_model = nn.parallel.DistributedDataParallel(
                model, device_ids=[self.rank], gradient_as_bucket_view=grad_is_view
            )

            num_samples = dist.get_world_size() * 2
            input = torch.randn(num_samples, 2).cuda()
            target = torch.randn(num_samples, 4).cuda()
            loss_fn = nn.MSELoss()

            # No parameter may have a gradient before training starts.
            for param in ddp_model.parameters():
                self.assertTrue(param is not None)
                self.assertTrue(param.grad is None)

            for step in range(20):
                optimizer.zero_grad()
                # Forward pass and loss run under autocast.
                with autocast():
                    loss = loss_fn(ddp_model(input), target)

                # Backward runs outside autocast on the scaled loss, producing
                # scaled gradients. Backward ops run in the same dtype autocast
                # chose for the corresponding forward ops.
                scaler.scale(loss).backward()

                # Every trainable parameter must now hold a valid gradient.
                for param in ddp_model.parameters():
                    if not param.requires_grad:
                        continue
                    self.assertTrue(param.grad is not None)
                    self.assertFalse(param.grad.isnan().any())
                    self.assertFalse(param.grad.isinf().any())

                # scaler.step() unscales the optimizer's gradients first and
                # skips optimizer.step() if they contain infs or NaNs.
                scaler.step(optimizer)

                # Adjust the scale for the next iteration.
                scaler.update()

                # Shuffle the input so that DDP input is different
                torch.manual_seed(1337 + step)
                input = input[torch.randperm(dist.get_world_size() * 2)]

            return ddp_model

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_with_amp_and_grad_is_view(self):
            """AMP training must yield equal parameters with and without bucket-view grads."""
            torch.cuda.set_device(self.rank)
            train = self._test_DistributedDataParallel_with_amp
            ddp_model_grad_not_view = train(grad_is_view=False)
            ddp_model_grad_is_view = train(grad_is_view=True)

            param_pairs = zip(
                ddp_model_grad_not_view.parameters(),
                ddp_model_grad_is_view.parameters(),
            )
            for plain_param, view_param in param_pairs:
                self.assertEqual(plain_param, view_param)

        def _test_DistributedDataParallel_SyncBatchNorm(
            self,
            gpu_subset,
            rank,
            local_bs,
            global_bs,
            offset,
            output_device=None,
            affine=True,
        ):
            """Compare a SyncBatchNorm DDP model against a single-GPU baseline.

            The base network is ``BN_NET`` or ``BN_NET_NO_AFFINE`` depending on
            ``affine``. The DDP model is round-tripped through ``torch.save`` /
            ``torch.load`` before both models are handed to ``_test_DDP_niter``
            (5 iterations with affine, 2 without).

            ``output_device`` is accepted but not used in this body.
            """
            # Choose the base network.
            base_model = BN_NET if affine else BN_NET_NO_AFFINE
            primary_gpu = gpu_subset[0]

            # Single-GPU baseline.
            model_gpu = copy.deepcopy(base_model)
            model_gpu.cuda(primary_gpu)

            # DDP model with batch-norm layers converted to SyncBatchNorm.
            sync_bn_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                copy.deepcopy(base_model)
            )
            sync_bn_model.cuda(primary_gpu)
            model_DDP = nn.parallel.DistributedDataParallel(
                sync_bn_model, device_ids=gpu_subset
            )

            # Serialization round trip. On win32 the open file object is used;
            # elsewhere the file is addressed by name.
            with tempfile.NamedTemporaryFile() as tmp:
                on_windows = sys.platform == "win32"
                destination = tmp if on_windows else tmp.name
                torch.save(model_DDP, destination)
                if on_windows:
                    tmp.seek(0)
                # weights_only=False as this is legacy code that saves the model
                model_DDP = torch.load(destination, weights_only=False)

            # Random data for the whole global batch.
            input_cpu = torch.randn(global_bs, 2)
            target = torch.randn(global_bs, 4)
            loss = nn.MSELoss()

            # Compare the two models over a few training iterations.
            num_iters = 5 if affine else 2
            self._test_DDP_niter(
                model_gpu,
                model_DDP,
                input_cpu.cuda(primary_gpu),
                target.cuda(primary_gpu),
                loss,
                local_bs,
                rank,
                global_bs,
                True,
                offset,
                dist.get_world_size(),
                num_iters,
            )
            self._barrier()

        def _test_post_localSGD_optimizer_parity(self, create_averager, grad_is_view):
            """Check PostLocalSGDOptimizer against SGD plus explicit averaging.

            Two DDP copies of ``DDP_NET`` are trained on the same batch for 20
            steps. One uses plain SGD followed by a manual
            ``averager.average_parameters`` call, the other uses the optimizer
            from ``self._create_post_localSGD_optimizer``. Parameters must match
            after every step and both averagers must end on the same ``step``.

            Args:
                create_averager: zero-argument callable returning a new averager;
                    called once per model.
                grad_is_view: passed to DDP as ``gradient_as_bucket_view``.
            """
            lr = 0.03

            def build_ddp_net():
                # Each call wraps a fresh deep copy of the shared template.
                return torch.nn.parallel.DistributedDataParallel(
                    copy.deepcopy(DDP_NET).cuda(),
                    device_ids=[self.rank],
                    gradient_as_bucket_view=grad_is_view,
                )

            # Reference path: SGD step, then explicit parameter averaging.
            reference_net = build_ddp_net()
            reference_averager = create_averager()
            reference_opt = torch.optim.SGD(reference_net.parameters(), lr=lr)

            # Path under test. Process group cannot be pickled in some
            # environments, so the averager is created anew instead of being
            # deep-copied. See:
            # https://github.com/pytorch/pytorch/pull/74737#pullrequestreview-922487496
            wrapped_net = build_ddp_net()
            wrapped_averager = create_averager()
            wrapped_opt = self._create_post_localSGD_optimizer(
                wrapped_net, lr, wrapped_averager
            )

            num_rows = dist.get_world_size() * 2
            inputs = torch.randn(num_rows, 2).cuda()
            targets = torch.randn(num_rows, 4).cuda()
            criterion = nn.MSELoss()

            for _ in range(20):
                self._perform_a_train_step(
                    reference_opt, reference_net, criterion, inputs, targets
                )
                reference_averager.average_parameters(reference_net.parameters())

                self._perform_a_train_step(
                    wrapped_opt, wrapped_net, criterion, inputs, targets
                )

                param_pairs = zip(reference_net.parameters(), wrapped_net.parameters())
                for expected, actual in param_pairs:
                    self.assertEqual(expected.data, actual.data)

            # Also check if the built-in step counters are the same to prevent a bug like #74737.
            self.assertEqual(reference_averager.step, wrapped_averager.step)

        def _create_periodic_model_averager(self):
            """Return a PeriodicModelAverager with period 4 and 10 warm-up steps."""
            averager_config = {"period": 4, "warmup_steps": 10}
            return averagers.PeriodicModelAverager(**averager_config)

        def _create_post_localSGD_optimizer(self, net, learning_rate, averager):
            """Wrap an SGD optimizer over ``net``'s parameters in PostLocalSGDOptimizer.

            Args:
                net: module whose ``parameters()`` are optimized.
                learning_rate: ``lr`` given to the inner ``torch.optim.SGD``.
                averager: averager handed to ``PostLocalSGDOptimizer``.
            """
            inner_sgd = torch.optim.SGD(net.parameters(), lr=learning_rate)
            return post_localSGD_optimizer.PostLocalSGDOptimizer(
                optim=inner_sgd, averager=averager
            )

        def _perform_a_train_step(self, optimizer, net, loss_fn, input, target):
            optimizer.zero_grad()
            output = net(input)
            loss = loss_fn(output, target)
            loss.backward()
            optimizer.step()

        def _test_post_localSGD_optimizer_step_reload(
            self, create_averager, chkpt_file
        ):
            """Check that the averager step counter survives a state-dict round trip.

            One PostLocalSGDOptimizer is trained for 20 steps, its state dict is
            written to ``chkpt_file`` by rank 0, and a second optimizer (over the
            same network, with its own averager) loads it. The loaded step must
            be non-zero and equal to the trained one. Loading a state dict with
            the ``"step"`` entry removed must emit a ``UserWarning`` and reset
            the counter to 0.

            Args:
                create_averager: zero-argument callable returning a new averager.
                chkpt_file: path used for ``torch.save`` / ``torch.load``.
            """
            lr = 0.03

            ddp_net = torch.nn.parallel.DistributedDataParallel(
                copy.deepcopy(DDP_NET).cuda(), device_ids=[self.rank]
            )

            # Optimizer that is actually trained.
            trained_averager = create_averager()
            trained_opt = self._create_post_localSGD_optimizer(
                ddp_net, lr, trained_averager
            )

            # Untrained optimizer that only receives the checkpoint.
            reloaded_averager = create_averager()
            reloaded_opt = self._create_post_localSGD_optimizer(
                ddp_net, lr, reloaded_averager
            )

            num_rows = dist.get_world_size() * 2
            inputs = torch.randn(num_rows, 2).cuda()
            targets = torch.randn(num_rows, 4).cuda()
            criterion = nn.MSELoss()

            for _ in range(20):
                self._perform_a_train_step(
                    trained_opt, ddp_net, criterion, inputs, targets
                )

            if self.rank == 0:
                payload = {"optimizer_state_dict": trained_opt.state_dict()}
                torch.save(payload, chkpt_file)

            # Every rank waits here before reading the file written above.
            dist.barrier()
            saved_device = "cuda:%d" % 0
            local_device = "cuda:%d" % self.rank
            checkpoint = torch.load(
                chkpt_file, map_location={saved_device: local_device}
            )
            opt_state = checkpoint["optimizer_state_dict"]
            reloaded_opt.load_state_dict(opt_state)

            # Check that we didn't hit the trivial case
            self.assertNotEqual(reloaded_averager.step, 0)
            # Check if dummy averager was initialized to a correct value
            self.assertEqual(trained_averager.step, reloaded_averager.step)

            # Remove 'step' entry from a checkpoint.
            # And make sure it is not in the state dictionary
            del opt_state["step"]
            self.assertNotIn("step", opt_state)

            # Check if checkpoint without a 'step' entry invokes a warning
            with self.assertWarnsRegex(
                expected_warning=UserWarning,
                expected_regex="Loaded state dict does not contain a step counter for an averager. "
                "Setting step counter to 0.",
            ):
                reloaded_opt.load_state_dict(opt_state)

            self.assertEqual(reloaded_averager.step, 0)

        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        def test_post_localSGD_optimizer_parity(self):
            # Parity check with the periodic averager and
            # gradient_as_bucket_view disabled.
            torch.cuda.set_device(self.rank)
            averager_factory = self._create_periodic_model_averager
            self._test_post_localSGD_optimizer_parity(
                averager_factory, grad_is_view=False
            )

        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        def test_post_localSGD_optimizer_parity_grad_is_view(self):
            # Parity check with the periodic averager and
            # gradient_as_bucket_view enabled.
            torch.cuda.set_device(self.rank)
            averager_factory = self._create_periodic_model_averager
            self._test_post_localSGD_optimizer_parity(
                averager_factory, grad_is_view=True
            )

        def _create_hierarchical_model_averager(self):
            """Return a HierarchicalModelAverager with a two-level schedule.

            The period -> group-size mapping is ``{2: 2, 4: world_size}`` (in that
            insertion order) and the averager uses 4 warm-up steps.
            """
            schedule = OrderedDict()
            schedule[2] = 2
            schedule[4] = dist.get_world_size()
            return hierarchicalSGD.HierarchicalModelAverager(
                period_group_size_dict=schedule, warmup_steps=4
            )

        @skip_if_lt_x_gpu(4)
        @skip_if_odd_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
            # Parity check with the hierarchical averager and
            # gradient_as_bucket_view disabled.
            torch.cuda.set_device(self.rank)
            averager_factory = self._create_hierarchical_model_averager
            self._test_post_localSGD_optimizer_parity(
                averager_factory, grad_is_view=False
            )

        @skip_if_lt_x_gpu(4)
        @skip_if_odd_worldsize
        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(
            self,
        ):
            # Parity check with the hierarchical averager and
            # gradient_as_bucket_view enabled.
            torch.cuda.set_device(self.rank)
            averager_factory = self._create_hierarchical_model_averager
            self._test_post_localSGD_optimizer_parity(
                averager_factory, grad_is_view=True
            )

        @skip_if_lt_x_gpu(2)
        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        def test_post_localSGD_optimizer_step_reload(self):
            # Step-counter reload check with the periodic averager; the
            # checkpoint path comes from the _rank_temp_file() context manager.
            torch.cuda.set_device(self.rank)
            averager_factory = self._create_periodic_model_averager
            with _rank_temp_file() as checkpoint_path:
                self._test_post_localSGD_optimizer_step_reload(
                    averager_factory, checkpoint_path
                )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self):
            # Run the memory-format helper for the 2d layout first, then for
            # the 3d layout.
            for layout in (torch.channels_last, torch.channels_last_3d):
                self._test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
                    layout
                )

        def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format(
            self, memory_format
        ):
            """Run ``_test_DDP_niter`` on ``ONLY_SBN_NET`` with a given memory format.

            Args:
                memory_format: format applied to the random input and target via
                    ``Tensor.to(memory_format=...)``. ``torch.channels_last``
                    yields 4-d tensors; any other value appends a fifth
                    dimension of size 4.
            """
            _group, _group_id, rank = self._init_global_test()
            world_size = dist.get_world_size()
            per_rank_bs = 2
            rank_offset = int(rank * 2)
            total_bs = int(world_size * 2)

            # NOTE(review): the DDP wrapper is built around sbn_gpu itself, not
            # around a second copy, so the "baseline" and the DDP model share
            # their parameters -- confirm this is intended.
            sbn_gpu = copy.deepcopy(ONLY_SBN_NET).cuda(rank)
            sbn_ddp = nn.parallel.DistributedDataParallel(sbn_gpu, device_ids=[rank])

            dims = [total_bs, 2, 4, 4]
            if memory_format is not torch.channels_last:
                dims.append(4)

            def random_batch():
                # Fresh random tensor on this rank's GPU in the requested layout.
                cpu_tensor = torch.randn(*dims, dtype=torch.float)
                return cpu_tensor.cuda(rank).to(memory_format=memory_format)

            input_gpu = random_batch()
            target_gpu = random_batch()
            criterion = nn.MSELoss()

            # compare the two models via _test_DDP_niter
            self._test_DDP_niter(
                sbn_gpu,
                sbn_ddp,
                input_gpu,
                target_gpu,
                criterion,
                per_rank_bs,
                rank,
                total_bs,
                True,
                rank_offset,
                world_size,
                memory_format=memory_format,
            )
            self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_SyncBatchNorm(self):
            # Runs the SyncBatchNorm helper three times: with integer device ids,
            # with an explicit output_device, and with torch.device device ids.
            _group, _group_id, rank = self._init_global_test()
            # DDP does not support replicating BN layers within a process, hence
            # testing with one module replica per process
            int_gpus = [rank]

            # Arguments shared by all three helper invocations.
            shared_args = {
                "rank": rank,
                "local_bs": 2,
                "global_bs": int(dist.get_world_size() * 2),
                "offset": int(rank * 2),
            }

            self._test_DistributedDataParallel_SyncBatchNorm(
                gpu_subset=int_gpus, **shared_args
            )

            # test output_device
            self._test_DistributedDataParallel_SyncBatchNorm(
                gpu_subset=int_gpus,
                output_device=torch.device("cuda"),
                **shared_args,
            )

            # test device_ids
            device_gpus = [torch.device("cuda:" + str(i)) for i in int_gpus]
            self._test_DistributedDataParallel_SyncBatchNorm(
                gpu_subset=device_gpus,
                output_device=torch.device("cuda"),
                **shared_args,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self):
            # Same as the basic SyncBatchNorm test, but with affine=False so the
            # helper picks its non-affine template network.
            _group, _group_id, rank = self._init_global_test()
            # DDP does not support replicating BN layers within a process, hence
            # testing with one module replica per process
            helper_args = {
                "gpu_subset": [rank],
                "rank": rank,
                "local_bs": 2,
                "global_bs": int(dist.get_world_size() * 2),
                "offset": int(rank * 2),
                "affine": False,
            }
            self._test_DistributedDataParallel_SyncBatchNorm(**helper_args)

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self):
            # Compares BatchNorm1d(2) on one GPU with its SyncBatchNorm/DDP
            # counterpart on 2-d (batch, features) input, two rows per process.
            _group, _group_id, rank = self._init_global_test()
            # DDP does not support replicating BN layers within a process, hence
            # testing with one module replica per process
            gpus = [rank]
            device = gpus[0]

            template = nn.BatchNorm1d(2)

            # single gpu training setup
            baseline = copy.deepcopy(template).cuda(device)

            # DDP training setup
            synced = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(template))
            synced.cuda(device)
            ddp_model = nn.parallel.DistributedDataParallel(synced, device_ids=gpus)

            per_rank_bs = len(gpus) * 2
            total_bs = dist.get_world_size() * per_rank_bs
            cpu_input = torch.randn(total_bs, 2)
            cpu_target = torch.randn(total_bs, 2)
            criterion = nn.MSELoss()

            # disabling cudnn.
            # SyncBatchNorm goes through native_batch_norm kernel, this avoids the
            # numerical issue created by the divergent code path.
            with torch.backends.cudnn.flags(False):
                # compare the two models via _test_DDP_niter
                self._test_DDP_niter(
                    baseline,
                    ddp_model,
                    cpu_input.cuda(device),
                    cpu_target.cuda(device),
                    criterion,
                    per_rank_bs,
                    rank,
                    total_bs,
                    True,
                )
                self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        @require_world_size(2)
        def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self):
            # Compares BatchNorm1d(2) on one GPU with its SyncBatchNorm/DDP
            # counterpart when every process contributes a single input row.
            _group, _group_id, rank = self._init_global_test()
            # DDP does not support replicating BN layers within a process, hence
            # testing with one module replica per process
            gpus = [rank]
            device = gpus[0]

            template = nn.BatchNorm1d(2)

            # single gpu training setup
            baseline = copy.deepcopy(template).cuda(device)

            # DDP training setup
            synced = nn.SyncBatchNorm.convert_sync_batchnorm(copy.deepcopy(template))
            synced.cuda(device)
            ddp_model = nn.parallel.DistributedDataParallel(synced, device_ids=gpus)

            per_rank_bs = 1
            total_bs = dist.get_world_size()
            cpu_input = torch.randn(total_bs, 2)
            cpu_target = torch.randn(total_bs, 2)
            criterion = nn.MSELoss()

            # disabling cudnn.
            # SyncBatchNorm goes through native_batch_norm kernel, this avoids the
            # numerical issue created by the divergent code path.
            with torch.backends.cudnn.flags(False):
                # compare the two models via _test_DDP_niter
                self._test_DDP_niter(
                    baseline,
                    ddp_model,
                    cpu_input.cuda(device),
                    cpu_target.cuda(device),
                    criterion,
                    per_rank_bs,
                    rank,
                    total_bs,
                    True,
                )
                self._barrier()

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value(
            self,
        ):
            # Every rank feeds a differently sized input through a DDP-wrapped
            # SyncBatchNorm; afterwards the module's running_mean / running_var
            # must be close to the per-channel mean / variance of all ranks'
            # inputs concatenated together.
            _group, _group_id, rank = self._init_global_test()
            # Work on a private copy: nn.Module.cuda() moves a module in place,
            # so wrapping ONLY_SBN_NET directly would relocate the shared
            # module-level instance and accumulate this test's running
            # statistics on it.
            sbn = copy.deepcopy(ONLY_SBN_NET).cuda(rank)
            model = nn.parallel.DistributedDataParallel(sbn, device_ids=[rank])

            # Rank i gets a (2, 2, 10 ** (i + 1)) tensor whose two channels are
            # filled with 0.1 ** (i - 1) and 0.3 ** (i - 1) respectively.
            input_var = []
            for i in range(dist.get_world_size()):
                input_var_rank = torch.cat(
                    [
                        torch.ones(2, 1, 10 ** (i + 1)) * (0.1 ** (i - 1)),
                        torch.ones(2, 1, 10 ** (i + 1)) * (0.3 ** (i - 1)),
                    ],
                    dim=1,
                )
                input_var.append(input_var_rank)

            # Flatten each rank's input to (num_features, -1) and concatenate
            # along the sample axis to get the reference statistics.
            all_input_var = torch.cat(
                [
                    x.permute(1, 0, 2).contiguous().view(sbn.num_features, -1)
                    for x in input_var
                ],
                dim=1,
            ).cuda(rank)

            # Repeated forward/backward passes on the same per-rank input.
            for _ in range(100):
                y = model(input_var[rank].cuda(rank))
                y.mean().backward()

            running_mean, running_var = (
                model.module.running_mean,
                model.module.running_var,
            )
            torch.testing.assert_close(running_mean, all_input_var.mean(1))
            torch.testing.assert_close(running_var, all_input_var.var(1))

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self):
            # Runs the SyncBatchNorm helper with a per-rank batch size that
            # grows with the rank (rank r uses r + 2 rows).
            _group, _group_id, rank = self._init_global_test()
            world_size = dist.get_world_size()

            # With r + 2 rows on rank r, the rows before rank r add up to
            # r * (r + 3) / 2 and all ranks together to n * (n + 3) / 2.
            # r and r + 3 differ in parity, so the product is always even.
            rows_here = rank + 2
            rows_before = (rank + 3) * rank // 2
            rows_total = (world_size + 3) * world_size // 2

            # only do single GPU per process
            self._test_DistributedDataParallel_SyncBatchNorm(
                gpu_subset=[rank],
                rank=rank,
                local_bs=rows_here,
                global_bs=rows_total,
                offset=rows_before,
            )

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_DistributedDataParallel_SyncBatchNorm_half(self):
            # A float16 copy of BN_NET, converted to SyncBatchNorm and wrapped in
            # DDP, must produce float16 outputs and float16 gradients.
            _group, _group_id, rank = self._init_global_test()

            half_net = copy.deepcopy(BN_NET).half()
            half_net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(half_net)
            ddp_model = nn.parallel.DistributedDataParallel(
                half_net.cuda(rank), device_ids=[rank]
            )
            half_input = torch.randn(
                2, 2, dtype=torch.float16, device=torch.device(rank)
            )
            # Check that forward/backward do not error with dtype mismatch
            output = ddp_model(half_input)
            self.assertEqual(output.dtype, torch.float16)
            output.sum().backward()
            for parameter in ddp_model.parameters():
                self.assertEqual(parameter.grad.dtype, torch.float16)

        def _test_ddp_logging_data(self, is_gpu):
            """Train a DDP copy of ``DDP_NET`` for 20 iterations and check its runtime stats.

            The runtime logging sample rate is set to 2. After each iteration
            past the first, the logging data is inspected: for ``idx < 10`` or
            even ``idx`` the timing fields must be populated and ``iteration``
            must equal ``idx``; otherwise ``iteration`` must differ from ``idx``.

            Args:
                is_gpu: when true, the model and data are moved to
                    ``cuda(rank)`` and DDP is built with ``device_ids=[rank]``;
                    otherwise DDP wraps the CPU model.

            Returns:
                The DDP-wrapped model after training.
            """
            rank = dist.get_rank()
            net = copy.deepcopy(DDP_NET)
            if is_gpu:
                ddp_model = nn.parallel.DistributedDataParallel(
                    net.cuda(rank), device_ids=[rank]
                )
            else:
                ddp_model = nn.parallel.DistributedDataParallel(net)

            # dummy data initialization
            local_bs = 2
            batch_size, inputs, targets, criterion = self._prepare_dummy_data(local_bs)
            if is_gpu:
                inputs = inputs.cuda(rank)
                targets = targets.cuda(rank)

            ddp_model._set_ddp_runtime_logging_sample_rate(2)

            # This rank's slice of the global batch; it is the same every
            # iteration.
            start = rank * local_bs
            stop = start + local_bs
            overlap_key = "backward_compute_comm_overlap_time"

            for idx in range(20):
                # DDP training, DDP scatters subsets of input to nodes/GPUs
                self._test_DDP_helper(
                    ddp_model,
                    inputs[start:stop],
                    targets[start:stop],
                    criterion,
                    1,
                )

                self._model_step_with_zero_grad(ddp_model)

                # Verify DDP logging data is sampled as expected: the first
                # ten iterations and every even iteration after that are
                # expected to carry runtime stats for this very iteration.
                stats = ddp_model._get_ddp_logging_data()
                if idx > 0:
                    if idx < 10 or idx % 2 == 0:
                        for timing_key in (
                            "forward_compute_time",
                            "backward_compute_time",
                            "backward_comm_time",
                        ):
                            self.assertGreaterEqual(stats.get(timing_key), 1)
                        # The overlap can exceed neither of the two phases.
                        for timing_key in (
                            "backward_compute_time",
                            "backward_comm_time",
                        ):
                            self.assertGreaterEqual(
                                stats.get(timing_key), stats.get(overlap_key)
                            )
                        self.assertEqual(stats.get("iteration"), idx)
                    else:
                        # if the idx-th iteration is not sampled to set runtime
                        # stats, the logged iteration is not updated to the
                        # current iteration.
                        self.assertNotEqual(stats.get("iteration"), idx)

                # Shuffle the input so that DDP input is different
                inputs = inputs[torch.randperm(batch_size)]

            return ddp_model

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "nccl does not support DDP on CPU models"
        )
        def test_ddp_logging_data_cpu(self):
            # Validates the construction-time and runtime fields of the DDP
            # logging data for a CPU model, then the bucket layout of a larger
            # mixed-dtype model.
            def parse_env(var):
                # Value of the environment variable, or "N/A" when it is unset.
                return os.environ.get(var, "N/A")

            dist.set_debug_level(dist.DebugLevel.INFO)
            _, group_id, _ = self._init_global_test()
            model_DDP = self._test_ddp_logging_data(is_gpu=False)

            log = model_DDP._get_ddp_logging_data()

            # Construction-time fields.
            # output_device is -1 in default if it is not set, e.g.
            # output_device of CPU training is -1.
            construction_expectations = [
                ("world_size", dist.get_world_size()),
                ("rank", dist.get_rank()),
                ("module_name", "Net"),
                ("device_ids", ""),
                ("output_device", -1),
                ("broadcast_buffers", 1),
                ("bucket_cap_bytes", 25 * 1024 * 1024),
                ("find_unused_parameters", 0),
                ("gradient_as_bucket_view", 0),
                ("backend_name", dist.get_backend(group_id)),
                ("iteration", 18),
            ]
            for field, expected in construction_expectations:
                self.assertEqual(log.get(field), expected)

            # Parameter bookkeeping only counts parameters that require grad.
            trainable = [p for p in model_DDP.parameters() if p.requires_grad]
            num_params = len(trainable)
            param_size = sum(p.numel() * p.element_size() for p in trainable)
            self.assertEqual(log.get("dtypes"), "float")
            self.assertEqual(log.get("total_parameter_size_bytes"), param_size)
            self.assertEqual(log.get("num_parameter_tensors"), num_params)
            self.assertEqual(log.get("bucket_sizes"), str(param_size))

            # Fields mirrored from environment variables.
            for field, env_var in (
                ("master_port", "MASTER_PORT"),
                ("master_addr", "MASTER_ADDR"),
                ("torch_distributed_debug", "TORCH_DISTRIBUTED_DEBUG"),
                ("cuda_visible_devices", "CUDA_VISIBLE_DEVICES"),
            ):
                self.assertEqual(log.get(field), parse_env(env_var))

            if log.get("backend_name") == "gloo":
                for field, env_var in (
                    ("gloo_socket_ifname", "GLOO_SOCKET_IFNAME"),
                    ("gloo_device_transport", "GLOO_DEVICE_TRANSPORT"),
                ):
                    self.assertEqual(log.get(field), parse_env(env_var))
                default_gloo_threads = 2
                self.assertEqual(log.get("gloo_num_threads"), default_gloo_threads)

            # None of the nccl-specific fields may be present here.
            for field in (
                "nccl_socket_ifname",
                "nccl_blocking_wait",
                "nccl_async_error_handling",
                "nccl_debug",
                "nccl_nthreads",
                "nccl_ib_timeout",
            ):
                self.assertEqual(log.get(field), None)

            # test runtime logging fields
            # Note: DETAIL debug mode logs DDP logging data to stdout and
            # thus accesses std::map, which fills in a default value for the
            # type if it didn't exist.
            self.assertEqual(log.get("unused_parameter_size", 0), 0)
            self.assertEqual(log.get("has_rebuilt_buckets"), 1)
            self.assertEqual(log.get("rebuilt_bucket_sizes"), str(param_size))
            # Expected parameter order is "2", "1", "0".
            expected_order = [str(x) for x in reversed(range(3))]
            self.assertEqual(
                log.get("prev_iteration_grad_ready_order_indices"),
                ", ".join(expected_order),
            )
            self.assertEqual(
                log.get("rebuilt_per_bucket_param_indices"),
                " ".join(expected_order),
            )

            # It is hard to test accurate latency, but it can test whether the latency is
            # a valid value and in the expected range.
            for field in (
                "avg_forward_compute_time",
                "avg_backward_compute_time",
                "avg_backward_comm_time",
            ):
                self.assertGreaterEqual(log.get(field), 1)
            for field in ("avg_backward_compute_time", "avg_backward_comm_time"):
                self.assertGreaterEqual(
                    log.get(field),
                    log.get("avg_backward_compute_comm_overlap_time"),
                )

            # Test host-side times are roughly in the order that we expect
            fwd_start = log.get("forward_compute_time_start")
            bwd_comp_start = log.get("backward_compute_time_start")
            bwd_comp_end = log.get("backward_compute_time_end")
            bwd_comm_start = log.get("backward_comm_time_start")
            bwd_comm_end = log.get("backward_comm_time_end")
            for later, earlier in (
                (bwd_comm_end, bwd_comm_start),
                (bwd_comm_start, bwd_comp_start),
                (bwd_comp_end, bwd_comp_start),
                (bwd_comp_start, fwd_start),
            ):
                self.assertGreaterEqual(later, earlier)

            # test larger net with mixed data types, verify multiple bucket sizes
            large_net = LargeNet()
            large_net.float()
            large_net.fc1.double()
            large_ddp = nn.parallel.DistributedDataParallel(
                large_net, bucket_cap_mb=1.5
            )
            large_log = large_ddp._get_ddp_logging_data()
            large_params = list(large_ddp.parameters())
            self.assertEqual(
                large_log.get("bucket_cap_bytes"), int(1.5 * 1024 * 1024)
            )
            # Buckets are expected in the order parameter 1, then parameter 0.
            expected_bucket_sizes = ", ".join(
                str(p.numel() * p.element_size())
                for p in (large_params[1], large_params[0])
            )
            self.assertEqual(large_log.get("bucket_sizes"), expected_bucket_sizes)
            self.assertEqual(large_log.get("dtypes"), "double, float")

        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        @skip_if_no_gpu
        def test_ddp_logging_data_gpu(self):
            # Checks the DDP logging data of the model returned by
            # `_test_ddp_logging_data(is_gpu=True)`: device fields, gradient-ready
            # order, rebuilt bucket indices, runtime latency fields and the relative
            # order of the host-side timestamps.
            _group, _group_id, rank = self._init_global_test()
            ddp_model = self._test_ddp_logging_data(is_gpu=True)
            logged = ddp_model._get_ddp_logging_data()

            # Device fields are compared against this process's rank.
            self.assertEqual(logged.get("device_ids"), str(rank))
            self.assertEqual(logged.get("output_device"), rank)

            # Both index fields are expected to list the indices 2, 1, 0; they
            # differ only in the separator used.
            descending_indices = [str(idx) for idx in reversed(range(3))]
            self.assertEqual(
                logged.get("prev_iteration_grad_ready_order_indices"),
                ", ".join(descending_indices),
            )
            self.assertEqual(
                logged.get("rebuilt_per_bucket_param_indices"),
                " ".join(descending_indices),
            )

            # Runtime logging fields: accurate latencies are hard to test, so only
            # check that each value is valid and in the expected range.
            overlap_time = logged.get("avg_backward_compute_comm_overlap_time")
            self.assertGreaterEqual(logged.get("avg_forward_compute_time"), 1)
            self.assertGreaterEqual(overlap_time, 1)
            for total_time_key in (
                "avg_backward_compute_time",
                "avg_backward_comm_time",
            ):
                self.assertGreaterEqual(logged.get(total_time_key), overlap_time)

            # Host-side times should be roughly in the order that we expect: for
            # each (later, earlier) pair the later timestamp must not precede the
            # earlier one.
            host_time_order = (
                ("backward_comm_time_end", "backward_comm_time_start"),
                ("backward_comm_time_start", "backward_compute_time_start"),
                ("backward_compute_time_end", "backward_compute_time_start"),
                ("backward_compute_time_start", "forward_compute_time_start"),
            )
            for later_key, earlier_key in host_time_order:
                self.assertGreaterEqual(logged.get(later_key), logged.get(earlier_key))

        @skip_but_pass_in_sandcastle_if(
            BACKEND == "nccl", "nccl does not support DDP on CPU models"
        )
        def test_static_graph_api_cpu(self):
            # Calling `_set_static_graph()` after `_test_DDP_helper` has already run
            # on the model must raise a RuntimeError, and that error must be
            # recorded in the DDP logging data.
            model_DDP = nn.parallel.DistributedDataParallel(DDP_NET)
            expected_err = "should be called before training loop starts"

            local_bs = 2
            _batch_size, inputs, target, loss = self._prepare_dummy_data(local_bs)
            offset = dist.get_rank() * local_bs

            # DDP training, DDP scatters subsets of input to nodes/GPUs.
            # This deliberately runs outside the assertRaisesRegex block: a failure
            # during data preparation or training should surface as-is instead of
            # being matched against `expected_err`.
            self._test_DDP_helper(
                model_DDP,
                inputs[offset : offset + local_bs],
                target[offset : offset + local_bs],
                loss,
                1,
            )

            # Only the late `_set_static_graph()` call is expected to raise.
            with self.assertRaisesRegex(RuntimeError, expected_err):
                model_DDP._set_static_graph()

            # Verify error was logged in ddp_logging_data.
            verify_ddp_error_logged(model_DDP, expected_err)

        @skipIfNoTorchVision
        def test_SyncBatchNorm_process_group(self):
            # `convert_sync_batchnorm` has to pass `process_group` down recursively
            # so that it also reaches `SyncBatchNorm` layers nested in a sub-module
            # or sub-sub-module (e.g. resnet50 in torchvision.models). This test
            # converts a resnet50 copy and checks the group on a nested BN layer.
            single_rank_group = torch.distributed.new_group([0])
            base_model = torchvision.models.resnet50()
            model_copy = copy.deepcopy(base_model)
            synced_model = nn.SyncBatchNorm.convert_sync_batchnorm(
                model_copy, single_rank_group
            )

            # `layer1[0].bn1` is a batch-norm layer nested inside a sub-module.
            nested_bn = synced_model.layer1[0].bn1
            self.assertEqual(nested_bn.process_group, single_rank_group)

        def _run_reduction_test(
            self, tensor, expected_tensor, op, reduction_fn=dist.all_reduce, dst=None
        ):
            if reduction_fn != dist.all_reduce and dst is None:
                raise ValueError(f"Reduction fn {reduction_fn} must specify dst!")
            if dst is not None:
                reduction_fn(tensor, dst, op)
                # Only destination rank tensor is expected to have final result.
                if dist.get_rank() == dst:
                    self.assertEqual(tensor, expected_tensor)
            else:
                reduction_fn(tensor, op)
                self.assertEqual(tensor, expected_tensor)

        @require_backend_is_available({"nccl"})
        @skip_if_lt_x_gpu(2)
        def test_nccl_backend_bool_allreduce(self):
            torch.cuda.set_device(self.rank)
            # Run all_reduce with PRODUCT
            element = self.rank % 2 == 0
            for op in [dist.ReduceOp.PRODUCT, dist.ReduceOp.MIN]:
                input_tensor = torch.tensor([element, element]).to(self.rank)
                self._run_reduction_test(
                    input_tensor, torch.tensor([False, False]).to(self.rank), op
                )
                # Ensure that all ranks contributing True (cast to 1) results in the
                # correct reduction.
                input_tensor = torch.tensor([True, True]).to(self.rank)
                expected_tensor = input_tensor.clone()
                self._run_reduction_test(input_tensor, expected_tensor, op)

            # Run all_reduce with SUM
            for op in [dist.ReduceOp.SUM, dist.ReduceOp.MAX]:
                input_tensor = torch.tensor([element, element]).to(self.rank)
                self._run_reduction_test(
                    input_tensor, torch.tensor([True, True]).to(self.rank), op
                )
            # TODO: NCCL backend does not work correctly for bitwise reduction ops
            # (see https://github.com/pytorch/pytorch/issues/41362). Add tests for
            # these once i