import torch
from torch import distributed as dist
from torch import nn
from torch.nn import functional as F

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriConv2dPP(BaseModule):
    """Patch-parallel wrapper around an ``nn.Conv2d``.

    The work is split along the height dimension (dim 2) across the
    ``distri_config.n_device_per_batch`` devices of a batch group. Each device
    produces the output rows belonging to its split (``distri_config.split_idx()``):

    * with a single device per batch, the wrapped conv runs unchanged;
    * for the first layer (``is_first_layer=True``), ``x`` is the full-height
      input and this device convolves only the rows its output split needs;
    * otherwise ``x`` is this device's local slice. The ``padding[0]`` boundary
      ("halo") rows of the neighbouring splits are exchanged via the comm
      manager: synchronously with ``dist.all_gather`` in ``"full_sync"`` mode
      or while ``counter <= warmup_steps``. In any other mode the current
      contents of ``buffer_list`` are used, presumably filled by the exchange
      enqueued on an earlier call. The fresh boundary is enqueued unless the
      mode is ``"no_sync"``.

    ``comm_manager``, ``buffer_list``, ``idx`` and ``counter`` are not set in
    this class; presumably ``BaseModule`` initialises them — TODO confirm.

    NOTE(review): the sliced/halo paths apply explicit zero padding at the
    global image borders. They therefore match ``module(x)`` only for
    ``padding_mode="zeros"`` and numeric (non-string) ``padding``.
    """

    def __init__(self, module: nn.Conv2d, distri_config: DistriConfig, is_first_layer: bool = False):
        super().__init__(module, distri_config)
        # When True, forward() receives the full-height input instead of this device's slice.
        self.is_first_layer = is_first_layer

    def _conv_with_module_params(self, x: torch.Tensor, padding) -> torch.Tensor:
        """Convolve ``x`` with the wrapped conv's weight/bias using an explicit ``padding``.

        Stride, dilation and groups are taken from the wrapped module. The result
        therefore agrees with ``self.module`` for strided (including non-square),
        dilated, grouped and depthwise convolutions.
        """
        conv = self.module
        return F.conv2d(
            x,
            conv.weight,
            conv.bias,
            stride=conv.stride,
            padding=padding,
            dilation=conv.dilation,
            groups=conv.groups,
        )

    def naive_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped conv on ``x`` as-is (no splitting, no communication)."""
        #  x: [B, C, H, W]
        return self.module(x)

    def sliced_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute this device's output rows from the full-height input ``x``.

        Each of the ``n_device_per_batch`` devices produces
        ``H // stride_h // n_device_per_batch`` output rows. The input rows they
        need (plus ``padding_h`` rows of context on each side) are sliced out.
        Context rows that would lie outside the image become zero padding, as do
        ``padding_w`` columns on the left and right. The slice is then convolved
        without further padding.
        """
        config = self.distri_config
        h = x.shape[2]
        assert h % config.n_device_per_batch == 0

        stride_h = self.module.stride[0]
        pad_h, pad_w = self.module.padding

        output_h = h // stride_h // config.n_device_per_batch
        idx = config.split_idx()
        h_begin = output_h * idx * stride_h - pad_h
        h_end = output_h * (idx + 1) * stride_h + pad_h
        # F.pad order for the last two dims: [left, right, top, bottom].
        final_padding = [pad_w, pad_w, 0, 0]
        if h_begin < 0:
            h_begin = 0
            final_padding[2] = pad_h
        if h_end > h:
            h_end = h
            final_padding[3] = pad_h
        sliced_input = x[:, :, h_begin:h_end, :]
        padded_input = F.pad(sliced_input, final_padding, mode="constant")
        return self._conv_with_module_params(padded_input, padding=0)

    def _pad_with_halo(self, x: torch.Tensor, boundary_size: int) -> torch.Tensor:
        """Concatenate neighbouring halo rows from ``self.buffer_list`` around ``x`` along H.

        ``self.buffer_list[r][0]`` / ``[1]`` are the top / bottom boundary rows of
        split ``r`` (the layout ``_halo_forward`` stacks and gathers). The first
        and last splits have no neighbour on one side; ``boundary_size`` zero
        rows are padded there instead.
        """
        config = self.distri_config
        split_idx = config.split_idx()
        last_idx = config.n_device_per_batch - 1

        pieces = []
        if split_idx > 0:
            pieces.append(self.buffer_list[split_idx - 1][1])  # previous split's bottom rows
        pieces.append(x)
        if split_idx < last_idx:
            pieces.append(self.buffer_list[split_idx + 1][0])  # next split's top rows
        padded_x = torch.cat(pieces, dim=2)

        top = boundary_size if split_idx == 0 else 0
        bottom = boundary_size if split_idx == last_idx else 0
        if top or bottom:
            padded_x = F.pad(padded_x, [0, 0, top, bottom], mode="constant")
        return padded_x

    def _halo_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve this device's height slice ``x`` using neighbours' boundary rows."""
        distri_config = self.distri_config
        boundary_size = self.module.padding[0]

        if self.buffer_list is None:
            if self.comm_manager.buffer_list is None:
                # Comm manager has no buffers yet: register this layer's
                # [2, B, C, boundary_size, W] (top & bottom rows) tensor.
                self.idx = self.comm_manager.register_tensor(
                    shape=[2, x.shape[0], x.shape[1], boundary_size, x.shape[3]],
                    torch_dtype=x.dtype,
                    layer_type="conv2d",
                )
            else:
                self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
        if self.buffer_list is None:
            # No exchange buffers available: convolve the local slice alone (no neighbour rows).
            return self.naive_forward(x)

        # Top and bottom boundary rows of the local slice. Slice with an explicit
        # start: ``x[..., -boundary_size:, :]`` would select every row when
        # boundary_size == 0 and break the stack below.
        h = x.shape[2]
        boundary = torch.stack([x[:, :, :boundary_size, :], x[:, :, h - boundary_size :, :]], dim=0)

        sync = distri_config.mode == "full_sync" or self.counter <= distri_config.warmup_steps
        if sync:
            dist.all_gather(self.buffer_list, boundary, group=distri_config.batch_group, async_op=False)
        # The halo rows provide the height padding; only the width is zero-padded by the conv.
        output = self._conv_with_module_params(
            self._pad_with_halo(x, boundary_size),
            padding=(0, self.module.padding[1]),
        )
        if not sync and distri_config.mode != "no_sync":
            # Ship this step's boundary asynchronously; it is waited on at the start of a later forward().
            self.comm_manager.enqueue(self.idx, boundary)
        return output

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Patch-parallel convolution of ``x``; extra ``args``/``kwargs`` are ignored.

        Increments ``self.counter`` once per call.
        """
        distri_config = self.distri_config

        # Finish any outstanding boundary exchange registered for this layer
        # before its buffers are read.
        comm_manager = self.comm_manager
        if comm_manager is not None and comm_manager.handles is not None and self.idx is not None:
            if comm_manager.handles[self.idx] is not None:
                comm_manager.handles[self.idx].wait()
                comm_manager.handles[self.idx] = None

        if distri_config.n_device_per_batch == 1:
            output = self.naive_forward(x)
        elif self.is_first_layer:
            output = self.sliced_forward(x)
        else:
            output = self._halo_forward(x)

        self.counter += 1
        return output
