import torch
import torch.distributed as dist

from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.testing._internal.dist_utils import INIT_METHOD_TEMPLATE, dist_init
from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import (
    RpcAgentTestFixture,
)
from torch.testing._internal.common_distributed import (
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm,
)
from torch.distributed.pipeline.sync import Pipe

class PipeWithDDPTest(RpcAgentTestFixture):
    """Tests wrapping a two-stage ``Pipe`` in ``DistributedDataParallel``.

    Each of the ``world_size`` (2) ranks builds a pipe spanning two GPUs
    (``2 * rank`` and ``2 * rank + 1``), so every test needs 4 GPUs. DDP then
    replicates that pipe across the ranks, and each test checks that the
    parameter gradients are identical on all ranks after a few training
    steps, for a given backend / checkpoint mode / DDP option combination.
    """

    @property
    def world_size(self) -> int:
        # Two ranks, each driving a pipe over two GPUs (4 GPUs in total).
        return 2

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_never(self):
        self._run_basic_test("nccl", "never")

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_never_find_unused(self):
        self._run_basic_test("nccl", "never", find_unused_parameters=True)

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_always(self):
        self._run_basic_test("nccl", "always", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    @dist_init
    @skip_if_rocm
    def test_basic_nccl_ckpt_except_last(self):
        self._run_basic_test("nccl", "except_last", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_never(self):
        self._run_basic_test("gloo", "never")

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_never_find_unused(self):
        self._run_basic_test("gloo", "never", find_unused_parameters=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_always(self):
        self._run_basic_test("gloo", "always", static_graph=True)

    @skip_if_lt_x_gpu(4)
    @requires_gloo()
    @dist_init
    @skip_if_rocm
    def test_basic_gloo_ckpt_except_last(self):
        self._run_basic_test("gloo", "except_last", static_graph=True)

    def _run_basic_test(self, backend, checkpoint, find_unused_parameters=False, static_graph=False):
        """Train a DDP-wrapped ``Pipe`` for a few steps and verify gradient sync.

        Args:
            backend: process-group backend passed to ``dist.init_process_group``
                (the tests in this class use ``"nccl"`` and ``"gloo"``).
            checkpoint: forwarded to ``Pipe(checkpoint=...)``; the tests in this
                class use ``"never"``, ``"always"`` and ``"except_last"``.
            find_unused_parameters: forwarded to ``DistributedDataParallel``.
                When true, the model's forward also skips ``fc3`` so its
                parameters really are unused, and ``fc3``'s grad is not checked.
            static_graph: when true, calls ``model._set_static_graph()`` on the
                DDP wrapper. The tests enable it for the checkpointing modes
                (presumably because checkpointing re-runs parts of forward
                during backward -- confirm).
        """
        dist.init_process_group(
            backend=backend,
            init_method=INIT_METHOD_TEMPLATE.format(file_name=self.file_name),
            world_size=self.world_size,
            rank=self.rank,
        )

        # Use 4 GPUs: rank 0's pipe spans GPUs 0 and 1, rank 1's pipe spans
        # GPUs 2 and 3. Both pipe replicas are kept in sync via DDP.
        first_device = 2 * self.rank
        second_device = first_device + 1
        fc1 = nn.Linear(16, 8, bias=False).cuda(first_device)

        class MyModule(nn.Module):
            def __init__(self, device):
                super().__init__()
                self.fc2 = nn.Linear(8, 4, bias=False).cuda(device)
                self.fc3 = nn.Linear(4, 2, bias=False).cuda(device)

            def forward(self, inp):
                # Skipping fc3 leaves its parameters unused, which exercises
                # DDP's find_unused_parameters handling.
                if find_unused_parameters:
                    return self.fc2(inp)
                return self.fc3(self.fc2(inp))

        layer2 = MyModule(second_device)
        model = Pipe(nn.Sequential(fc1, layer2), chunks=2, checkpoint=checkpoint)
        model = DistributedDataParallel(model, find_unused_parameters=find_unused_parameters)
        if static_graph:
            model._set_static_graph()

        # First training step.
        self._train_step(model)

        # Run forward again for find_unused_parameters to trigger any potential errors.
        if find_unused_parameters:
            self._train_step(model)

        # Run a few more iterations of fwd + bwd to ensure gradient synchronization
        # occurs properly across iterations via delay_all_reduce/bucketized allreduce.
        for _ in range(3):
            self._train_step(model)

        # Check grads: every parameter that took part in backward must hold the
        # same (synchronized) gradient on all ranks.
        self._assert_grad_same_on_all_ranks(fc1.weight.grad)
        self._assert_grad_same_on_all_ranks(layer2.fc2.weight.grad)
        if not find_unused_parameters:
            self._assert_grad_same_on_all_ranks(layer2.fc3.weight.grad)

    def _train_step(self, model):
        """Run one forward + backward pass of ``model`` on a fresh random batch.

        The batch is placed on this rank's first pipe device. It is scaled by
        ``rank + 1`` so that inputs differ across ranks, which means identical
        gradients afterwards show that gradient sync actually happened.
        Gradients are never zeroed here, so they accumulate across calls.
        """
        model_input = torch.rand(16, 16).cuda(2 * self.rank) * (self.rank + 1)
        # The pipe's output exposes ``local_value()`` (RRef-like), so unwrap
        # it before reducing it to a scalar loss.
        out = model(model_input).local_value()
        out.sum().backward()

    def _assert_grad_same_on_all_ranks(self, grad):
        """Assert that ``grad`` is identical on every rank of the default group.

        This is a collective call: ``grad`` is all-gathered from the
        ``self.world_size`` ranks, and every gathered copy is compared with
        the first one. All ranks must call it in the same order.
        """
        self.assertIsNotNone(grad, "expected a gradient to have been computed")
        gathered = [torch.empty_like(grad) for _ in range(self.world_size)]
        dist.all_gather(gathered, grad)
        for other in gathered[1:]:
            self.assertEqual(gathered[0], other)