import torch
from torch import distributed as dist
from torch import nn

from ..base_module import BaseModule
from ...utils import DistriConfig


class DistriGroupNorm(BaseModule):
    """``nn.GroupNorm`` wrapper whose group statistics can be shared across devices.

    How statistics are shared depends on ``distri_config.mode``:

    * ``stale_gn`` / ``corrected_async_gn``: per-group moments (E[x], E[x^2]) are
      all-gathered synchronously while ``counter <= warmup_steps``. Afterwards the
      buffers filled by earlier exchanges (via ``comm_manager``) are reused, and
      this call's moments are enqueued for a later exchange.
    * ``sync_gn`` / ``full_sync`` (and any mode during warmup): moments are
      all-reduced over ``distri_config.batch_group`` on every call.
    * ``separate_gn`` / ``no_sync`` (after warmup): the wrapped module is called
      directly on the local tensor.

    Normalization uses the biased (population) variance, the same estimator
    ``nn.GroupNorm`` uses, so every mode normalizes consistently.
    """

    def __init__(self, module: nn.GroupNorm, distri_config: DistriConfig):
        assert isinstance(module, nn.GroupNorm)
        super(DistriGroupNorm, self).__init__(module, distri_config)

    @staticmethod
    def _group_moments(x: torch.Tensor) -> torch.Tensor:
        """Return per-group ``E[x]`` and ``E[x^2]`` stacked on a new leading axis.

        ``x`` has shape ``[n, num_groups, group_size, h, w]``; the result has shape
        ``[2, n, num_groups, 1, 1, 1]`` (index 0 is the mean, index 1 the mean square).
        """
        x_mean = x.mean(dim=[2, 3, 4], keepdim=True)
        x2_mean = (x**2).mean(dim=[2, 3, 4], keepdim=True)
        return torch.stack([x_mean, x2_mean], dim=0)

    def _normalize(self, x: torch.Tensor, mean: torch.Tensor, var: torch.Tensor) -> torch.Tensor:
        """Normalize grouped ``x`` with the given statistics and apply the affine transform.

        Args:
            x: tensor of shape ``[n, num_groups, group_size, h, w]``.
            mean: per-group mean broadcastable to ``x``.
            var: per-group biased variance broadcastable to ``x``.

        Returns:
            Tensor of shape ``[n, num_groups * group_size, h, w]``.
        """
        module = self.module
        n, num_groups, group_size, h, w = x.shape
        # E[x^2] - E[x]^2 can dip below zero from rounding; a true variance cannot,
        # and a negative value here would turn into NaN under sqrt.
        std = (var.clamp(min=0) + module.eps).sqrt()
        output = ((x - mean) / std).reshape([n, num_groups * group_size, h, w])
        if module.affine:
            output = output * module.weight.view([1, -1, 1, 1])
            output = output + module.bias.view([1, -1, 1, 1])
        return output

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply group normalization to a 4-D ``[n, c, h, w]`` tensor.

        Raises:
            NotImplementedError: if ``distri_config.mode`` is not a supported mode
                (only reached after warmup).
        """
        module = self.module
        assert isinstance(module, nn.GroupNorm)
        distri_config = self.distri_config

        # Block on any pending communication handle for this layer before its
        # buffers are read (presumably started by an earlier ``enqueue`` call).
        if self.comm_manager is not None and self.comm_manager.handles is not None and self.idx is not None:
            if self.comm_manager.handles[self.idx] is not None:
                self.comm_manager.handles[self.idx].wait()
                self.comm_manager.handles[self.idx] = None

        assert x.ndim == 4
        n, c, h, w = x.shape
        num_groups = module.num_groups
        group_size = c // num_groups

        if distri_config.mode in ["stale_gn", "corrected_async_gn"]:
            if self.buffer_list is None:
                if self.comm_manager.buffer_list is None:
                    # Buffers not allocated yet: register this layer's moment tensor.
                    self.idx = self.comm_manager.register_tensor(
                        shape=[2, n, num_groups, 1, 1, 1], torch_dtype=x.dtype, layer_type="gn"
                    )
                else:
                    self.buffer_list = self.comm_manager.get_buffer_list(self.idx)
            x = x.reshape([n, num_groups, group_size, h, w])
            slice_mean = self._group_moments(x)

            if self.buffer_list is None:
                # No shared buffers available: use this device's statistics only.
                full_mean = slice_mean
            elif self.counter <= distri_config.warmup_steps:
                # Warmup: exchange fresh moments synchronously.
                dist.all_gather(self.buffer_list, slice_mean, group=distri_config.batch_group, async_op=False)
                full_mean = sum(self.buffer_list) / distri_config.n_device_per_batch
            else:
                # NOTE(review): split_idx() is assumed to be this device's slot in buffer_list.
                own_idx = distri_config.split_idx()
                if distri_config.mode == "corrected_async_gn":
                    # Shift the stale average by how much this device's moments changed.
                    correction = slice_mean - self.buffer_list[own_idx]
                    full_mean = sum(self.buffer_list) / distri_config.n_device_per_batch + correction
                else:
                    # Stale moments for the other devices, fresh moments for this one.
                    new_buffer_list = list(self.buffer_list)
                    new_buffer_list[own_idx] = slice_mean
                    full_mean = sum(new_buffer_list) / distri_config.n_device_per_batch
                self.comm_manager.enqueue(self.idx, slice_mean)

            full_x_mean, full_x2_mean = full_mean[0], full_mean[1]
            var = full_x2_mean - full_x_mean**2
            if distri_config.mode == "corrected_async_gn":
                slice_var = slice_mean[1] - slice_mean[0] ** 2
                # The corrected estimate can go negative; fall back to the local variance there.
                var = torch.where(var < 0, slice_var, var)
            output = self._normalize(x, full_x_mean, var)
        elif self.counter <= distri_config.warmup_steps or distri_config.mode in ["sync_gn", "full_sync"]:
            x = x.reshape([n, num_groups, group_size, h, w])
            moments = self._group_moments(x)
            dist.all_reduce(moments, op=dist.ReduceOp.SUM, group=distri_config.batch_group)
            moments = moments / distri_config.n_device_per_batch
            x_mean, x2_mean = moments[0], moments[1]
            output = self._normalize(x, x_mean, x2_mean - x_mean**2)
        elif distri_config.mode in ["separate_gn", "no_sync"]:
            output = module(x)
        else:
            raise NotImplementedError
        self.counter += 1
        return output
