import torch
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriConv2dTP(BaseModule):
    """Tensor-parallel wrapper around ``nn.Conv2d`` that shards input channels.

    Each rank keeps only its contiguous block of the weight's input channels
    and convolves the matching block of the input. The partial outputs are
    then summed across ``distri_config.batch_group`` with an all-reduce.
    Convolution is linear in the input channels, so that sum equals the full
    convolution. The bias is added once, after the reduction, so it is not
    counted ``n_device_per_batch`` times.

    Grouped convolution (``groups != 1``) is only accepted when there is a
    single shard. With several shards, the grouped weight layout does not match
    the input-channel slicing used here.
    """

    def __init__(self, module: nn.Conv2d, distri_config: DistriConfig):
        """Build this rank's input-channel shard of ``module``.

        Args:
            module: The full (unsharded) convolution to split.
            distri_config: Parallel configuration. ``n_device_per_batch`` is
                the number of shards and ``split_idx()`` selects this rank's
                shard.

        Raises:
            AssertionError: If ``module.in_channels`` is not divisible by
                ``n_device_per_batch``.
            NotImplementedError: If ``module.groups != 1`` while more than one
                shard is requested.
        """
        super().__init__(module, distri_config)
        n_shards = distri_config.n_device_per_batch
        assert module.in_channels % n_shards == 0
        if module.groups != 1 and n_shards > 1:
            # The weight's dim 1 is in_channels // groups, so slicing it by
            # input-channel offsets would not select the right filters.
            raise NotImplementedError(
                "DistriConv2dTP only supports groups == 1 when sharding across "
                f"more than one device (got groups={module.groups})."
            )

        shard_in = module.in_channels // n_shards
        sharded_module = nn.Conv2d(
            shard_in,
            module.out_channels,
            module.kernel_size,
            module.stride,
            module.padding,
            module.dilation,
            module.groups,
            module.bias is not None,
            module.padding_mode,
            device=module.weight.device,
            dtype=module.weight.dtype,
        )
        # With groups == 1 the weight is (out_channels, in_channels, kH, kW),
        # so slicing dim 1 picks this rank's contiguous input-channel block.
        # With a single shard, the slice simply covers the whole weight.
        start_idx = distri_config.split_idx() * shard_in
        end_idx = start_idx + shard_in
        sharded_module.weight.data.copy_(module.weight.data[:, start_idx:end_idx])
        if module.bias is not None:
            # The bias is kept whole; forward() adds it once after the all-reduce.
            sharded_module.bias.data.copy_(module.bias.data)

        self.module = sharded_module
        # Drops only this function's local name for the unsharded module; the
        # caller may still hold a reference to it.
        del module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve this rank's channel slice of ``x`` and all-reduce the result.

        Args:
            x: 4-D input carrying the *full* channel dimension, i.e.
                ``n_device_per_batch * self.module.in_channels`` channels.
                Assumption: every rank in ``batch_group`` receives the same
                tensor; the reduced sum is the full convolution only if it does.

        Returns:
            The sum of all ranks' partial convolutions (via all-reduce), plus
            the bias if the module has one.

        Raises:
            ValueError: If ``x`` does not have the expected channel count.
        """
        distri_config = self.distri_config
        conv = self.module
        shard_in = conv.in_channels
        expected_c = shard_in * distri_config.n_device_per_batch

        _, c, _, _ = x.shape
        if c != expected_c:
            # Deriving the slice from c // n would silently drop trailing
            # channels when c is not an exact multiple; fail loudly instead.
            raise ValueError(
                f"DistriConv2dTP expected {expected_c} input channels, got {c}."
            )

        start_idx = distri_config.split_idx() * shard_in
        x_shard = x[:, start_idx:start_idx + shard_in]

        # _conv_forward applies the module's padding_mode (reflect / replicate /
        # circular) before convolving. Calling F.conv2d(padding=...) directly
        # would always zero-pad. Padding acts per channel, so padding each
        # shard independently still sums to the full result. Bias is withheld
        # here and added after the reduction.
        output = conv._conv_forward(x_shard, conv.weight, None)
        dist.all_reduce(output, op=dist.ReduceOp.SUM, group=distri_config.batch_group, async_op=False)
        if conv.bias is not None:
            output = output + conv.bias.view(1, -1, 1, 1)

        # Call counter; presumably initialised by BaseModule (not visible here).
        self.counter += 1
        return output
