import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter
import torch
import itertools
from torch.nn.modules.batchnorm import _BatchNorm


class DSNormSync(Module):
    r"""Domain-specific batch normalization with optional cross-process statistics sync.

    The module keeps three independent sets of running statistics and picks one
    according to ``domain_label`` (set via :meth:`set_domain_label`):
    ``0`` = source, ``1`` = target, any other value = mid. The affine ``weight``
    and ``bias`` are shared by all domains.

    In training mode with a non-``None`` ``process_group``, the per-channel mean
    and (biased) variance are computed over the concatenation of the mini-batches
    of every rank in the group. The reduction is differentiable, so gradients
    also flow through the synchronized statistics. With ``process_group=None``
    (the default) or in eval mode the layer applies
    :func:`torch.nn.functional.batch_norm` with the active domain's statistics.

    Args:
        num_features: number of channels ``C`` (size of input dimension 1).
        eps: value added to the variance for numerical stability. Default: 1e-5
        momentum: factor for the running-statistics update; ``None`` selects a
            cumulative moving average. Default: 0.1
        affine: whether to learn per-channel ``weight`` and ``bias``. Default: ``True``
        track_running_stats: whether to keep running mean/variance buffers.
            Default: ``True``
        process_group: ``torch.distributed`` process group to synchronize over,
            or ``None`` to disable synchronization. Default: ``None``
    """

    class _AllReduceSum(torch.autograd.Function):
        """Sum all-reduce whose backward pass is also a sum all-reduce.

        A bare ``torch.distributed.all_reduce`` is not recorded by autograd; this
        wrapper lets the gradient w.r.t. the global sums reach every rank.
        """

        @staticmethod
        def forward(ctx, tensor, group):
            ctx.group = group
            reduced = tensor.clone(memory_format=torch.contiguous_format)
            torch.distributed.all_reduce(reduced, group=group)
            return reduced

        @staticmethod
        def backward(ctx, grad_output):
            grad = grad_output.clone(memory_format=torch.contiguous_format)
            torch.distributed.all_reduce(grad, group=ctx.group)
            return grad, None

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 process_group=None):
        super(DSNormSync, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # 0 selects the source statistics, 1 the target statistics, anything else the mid statistics.
        self.domain_label = 0
        self.process_group = process_group

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean_source', torch.zeros(num_features))
            self.register_buffer('running_mean_target', torch.zeros(num_features))
            self.register_buffer('running_var_source', torch.ones(num_features))
            self.register_buffer('running_var_target', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

            self.register_buffer('running_mean_mid', torch.zeros(num_features))
            self.register_buffer('running_var_mid', torch.ones(num_features))
        else:
            self.register_parameter('running_mean_source', None)
            self.register_parameter('running_mean_target', None)
            self.register_parameter('running_var_source', None)
            self.register_parameter('running_var_target', None)

            self.register_parameter('running_mean_mid', None)
            self.register_parameter('running_var_mid', None)
            # Keep the attribute defined (as None) so it can always be read safely.
            self.register_buffer('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset every domain's running mean to 0, running variance to 1 and the batch counter."""
        if self.track_running_stats:
            self.running_mean_source.zero_()
            self.running_var_source.fill_(1)
            self.running_mean_target.zero_()
            self.running_var_target.fill_(1)
            self.num_batches_tracked.zero_()

            self.running_mean_mid.zero_()
            self.running_var_mid.fill_(1)

    def set_domain_label(self, domain_label):
        """Select which domain's statistics are used: 0 source, 1 target, other values mid."""
        self.domain_label = domain_label

    def reset_parameters(self):
        """Reset running statistics and initialize ``weight`` to 1 and ``bias`` to 0."""
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def _check_input_dim(self, input):
        # No check in the base class; DSNorm1dSync / DSNorm2dSync override this.
        return NotImplemented

    def _get_running_stats(self):
        """Return ``(running_mean, running_var)`` of the active domain (may be ``(None, None)``)."""
        if self.domain_label == 0:
            return self.running_mean_source, self.running_var_source
        elif self.domain_label == 1:
            return self.running_mean_target, self.running_var_target
        else:
            return self.running_mean_mid, self.running_var_mid

    @staticmethod
    def _pack_moments(mean, var, count):
        """Pack local stats into one 1-D tensor ``[sum(x), sum(x^2), count]``.

        Sums (unlike means) can be combined across ranks by plain addition.
        Accumulation happens in at least float32 so large counts do not
        overflow in half precision.
        """
        acc_dtype = torch.float64 if mean.dtype == torch.float64 else torch.float32
        mean = mean.to(acc_dtype)
        var = var.to(acc_dtype)
        count = torch.as_tensor(count, dtype=acc_dtype, device=mean.device).reshape(1)
        # var is the biased variance, so E[x^2] = var + mean^2.
        return torch.cat([mean * count, (var + mean * mean) * count, count])

    @staticmethod
    def _stats_from_sums(packed, num_features):
        """Turn ``[sum(x), sum(x^2), count]`` back into ``(mean, biased var, count)``."""
        total = packed[-1]
        sum_x, sum_x2 = torch.split(packed[:-1], num_features)
        mean = sum_x / total
        # E[x^2] - E[x]^2 can dip slightly below zero through rounding.
        var = (sum_x2 / total - mean * mean).clamp(min=0)
        return mean, var, total

    def _sync_stats(self, mean, var, n):
        """Combine per-rank statistics into statistics of the whole distributed batch.

        Args:
            mean: local per-channel mean.
            var: local per-channel biased variance.
            n: local number of elements per channel (int or 1-element tensor).

        Returns:
            ``(mean, var, n)`` unchanged when ``process_group`` is ``None``;
            otherwise the global mean, global biased variance and global count.
        """
        if self.process_group is None:
            return mean, var, n
        packed = self._pack_moments(mean, var, n)
        packed = self._AllReduceSum.apply(packed, self.process_group)
        return self._stats_from_sums(packed, self.num_features)

    def _update_running_stats(self, mean, var, total, exponential_average_factor):
        """Blend batch statistics into the active domain's running buffers.

        The stored variance is bias-corrected by ``total / (total - 1)`` to match
        the running-variance update of :func:`torch.nn.functional.batch_norm`.
        """
        running_mean, running_var = self._get_running_stats()
        with torch.no_grad():
            total = torch.as_tensor(total, dtype=var.dtype, device=var.device)
            unbiased_var = var * (total / (total - 1).clamp(min=1))
            running_mean.mul_(1 - exponential_average_factor).add_(mean * exponential_average_factor)
            running_var.mul_(1 - exponential_average_factor).add_(unbiased_var * exponential_average_factor)

    def _normalize(self, input, mean, var):
        """Normalize ``input`` per channel (dim 1) and apply the affine transform; any rank >= 2."""
        shape = [1, -1] + [1] * (input.dim() - 2)
        output = (input - mean.reshape(shape)) / torch.sqrt(var.reshape(shape) + self.eps)
        if self.affine:
            output = output * self.weight.reshape(shape) + self.bias.reshape(shape)
        return output

    def _forward_sync(self, input, exponential_average_factor):
        """Training forward that normalizes with statistics gathered across ``process_group``."""
        reduce_dims = [0] + list(range(2, input.dim()))
        local_count = input.numel() // input.size(1)
        local_mean = input.mean(reduce_dims)
        local_var = input.var(reduce_dims, unbiased=False)

        mean, var, total = self._sync_stats(local_mean, local_var, local_count)

        if self.track_running_stats:
            self._update_running_stats(mean, var, total, exponential_average_factor)

        # Gradients flow through mean/var (and the all-reduce) into every rank's input.
        return self._normalize(input, mean.to(input.dtype), var.to(input.dtype))

    def forward(self, input):
        self._check_input_dim(input)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        if self.training and self.process_group is not None:
            return self._forward_sync(input, exponential_average_factor)

        # Unsynchronized training, or inference: batch norm with the active domain's buffers.
        running_mean, running_var = self._get_running_stats()
        return F.batch_norm(
            input, running_mean, running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Add a default ``num_batches_tracked`` for old checkpoints, then load via the DS loader."""
        version = metadata.get('version', None)
        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        self._load_from_state_dict_ds(
            state_dict, prefix, metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def _load_from_state_dict_ds(self, state_dict, prefix, local_metadata, strict,
                                 missing_keys, unexpected_keys, error_msgs):
        r"""Copy this module's own parameters and buffers from :attr:`state_dict`.

        A domain-specific buffer (``*_source``, ``*_target``, ``*_mid``) that is
        absent from :attr:`state_dict` falls back to the plain BatchNorm name
        (e.g. ``running_mean``). This allows a regular BatchNorm checkpoint to
        initialize all three domains.

        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): metadata for this module (unused here).
            strict (bool): when ``True``, keys that cannot be found are added to
                :attr:`missing_keys`
            missing_keys (list of str): receives keys that could not be found
            unexpected_keys (list of str): not populated by this loader
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
        local_state = {k: v.data for k, v in local_name_params if v is not None}
        for name, param in local_state.items():
            key = prefix + name
            if key not in state_dict:
                # Match the suffix on the local name only; matching on the full key
                # would misfire when the prefix itself contains 'source'/'target'/'mid'.
                for suffix in ('_source', '_target', '_mid'):
                    if name.endswith(suffix):
                        fallback = prefix + name[:-len(suffix)]
                        if fallback in state_dict:
                            key = fallback
                        break
            if key in state_dict:
                input_param = state_dict[key]

                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'
                                      .format(key, input_param.shape, param.shape))
                    continue

                if isinstance(input_param, Parameter):
                    # backwards compatibility for serialized parameters
                    input_param = input_param.data
                try:
                    param.copy_(input_param)
                except Exception:
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}.'
                                      .format(key, param.size(), input_param.size()))
            elif strict:
                missing_keys.append(key)

        # Unexpected keys are deliberately not reported: a plain BatchNorm
        # checkpoint's 'running_mean'/'running_var' have no exact local name and
        # would otherwise be flagged.

    @classmethod
    def convert_dsnorm_sync(cls, module, process_group=None):
        r"""Recursively replace BatchNorm / ``DSNorm`` layers in ``module`` with ``cls``.

        Affine parameters are copied. Running statistics are copied as
        independent tensors per domain: from a ``DSNorm``, each domain keeps
        its own values; from a plain BatchNorm, the single ``running_mean`` /
        ``running_var`` seed all three domains.

        Args:
            module (nn.Module): containing module
            process_group (optional): process group to synchronize statistics over

        Returns:
            A new ``cls`` instance if ``module`` itself was converted, otherwise
            ``module`` with its matching descendants replaced.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100)
            >>>          ).cuda()
            >>> # convert to DSNormSync
            >>> sync_bn_module = DSNormSync.convert_dsnorm_sync(module, process_group)
        """
        module_output = module
        if isinstance(module, (torch.nn.modules.batchnorm._BatchNorm, DSNorm)):
            module_output = cls(module.num_features,
                                module.eps, module.momentum,
                                module.affine,
                                module.track_running_stats,
                                process_group)
            if module.affine:
                module_output.weight.data = module.weight.data.clone().detach()
                module_output.bias.data = module.bias.data.clone().detach()
                # keep requires_grad unchanged
                module_output.weight.requires_grad = module.weight.requires_grad
                module_output.bias.requires_grad = module.bias.requires_grad

            if isinstance(module, DSNorm):
                stats = {
                    name: getattr(module, name)
                    for name in ('running_mean_source', 'running_var_source',
                                 'running_mean_target', 'running_var_target',
                                 'running_mean_mid', 'running_var_mid')
                }
            else:
                stats = {}
                for domain in ('source', 'target', 'mid'):
                    stats['running_mean_' + domain] = module.running_mean
                    stats['running_var_' + domain] = module.running_var

            # Clone per domain: running statistics are updated in place, so shared
            # storage would let one domain's updates overwrite the others.
            for name, value in stats.items():
                setattr(module_output, name, None if value is None else value.clone())

            module_output.num_batches_tracked = module.num_batches_tracked

        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_dsnorm_sync(child, process_group))
        return module_output


class DSNorm(Module):
    r"""Domain-specific batch normalization (no cross-process synchronization).

    The module keeps three independent sets of running statistics and picks one
    according to ``domain_label`` (set via :meth:`set_domain_label`):
    ``0`` = source, ``1`` = target, any other value = mid. The affine ``weight``
    and ``bias`` are shared by all domains. The base class performs no input
    rank check; ``DSNorm1d`` / ``DSNorm2d`` override :meth:`_check_input_dim`.

    Args:
        num_features: number of channels ``C`` (size of input dimension 1).
        eps: value added to the variance for numerical stability. Default: 1e-5
        momentum: factor for the running-statistics update; ``None`` selects a
            cumulative moving average. Default: 0.1
        affine: whether to learn per-channel ``weight`` and ``bias``. Default: ``True``
        track_running_stats: whether to keep running mean/variance buffers.
            Default: ``True``
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
        super(DSNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        # 0 selects the source statistics, 1 the target statistics, anything else the mid statistics.
        self.domain_label = 0

        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)

        if self.track_running_stats:
            self.register_buffer('running_mean_source', torch.zeros(num_features))
            self.register_buffer('running_mean_target', torch.zeros(num_features))
            self.register_buffer('running_var_source', torch.ones(num_features))
            self.register_buffer('running_var_target', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

            self.register_buffer('running_mean_mid', torch.zeros(num_features))
            self.register_buffer('running_var_mid', torch.ones(num_features))
        else:
            self.register_parameter('running_mean_source', None)
            self.register_parameter('running_mean_target', None)
            self.register_parameter('running_var_source', None)
            self.register_parameter('running_var_target', None)

            self.register_parameter('running_mean_mid', None)
            self.register_parameter('running_var_mid', None)
            # Keep the attribute defined (as None) so it can always be read safely.
            self.register_buffer('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset every domain's running mean to 0, running variance to 1 and the batch counter."""
        if self.track_running_stats:
            self.running_mean_source.zero_()
            self.running_var_source.fill_(1)
            self.running_mean_target.zero_()
            self.running_var_target.fill_(1)
            self.num_batches_tracked.zero_()

            self.running_mean_mid.zero_()
            self.running_var_mid.fill_(1)

    def set_domain_label(self, domain_label):
        """Select which domain's statistics are used: 0 source, 1 target, other values mid."""
        self.domain_label = domain_label

    def reset_parameters(self):
        """Reset running statistics and initialize ``weight`` to 1 and ``bias`` to 0."""
        self.reset_running_stats()
        if self.affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def _check_input_dim(self, input):
        # No check in the base class; DSNorm1d / DSNorm2d override this.
        return NotImplemented

    def _get_running_stats(self):
        """Return ``(running_mean, running_var)`` of the active domain (may be ``(None, None)``)."""
        if self.domain_label == 0:
            return self.running_mean_source, self.running_var_source
        elif self.domain_label == 1:
            return self.running_mean_target, self.running_var_target
        else:
            return self.running_mean_mid, self.running_var_mid

    def forward(self, input):
        self._check_input_dim(input)

        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # Normalize with (and, in training, update) the active domain's buffers only.
        running_mean, running_var = self._get_running_stats()
        return F.batch_norm(
            input, running_mean, running_var, self.weight, self.bias,
            self.training or not self.track_running_stats,
            exponential_average_factor, self.eps)

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Add a default ``num_batches_tracked`` for old checkpoints, then load via the DS loader."""
        version = metadata.get('version', None)
        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + 'num_batches_tracked'
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = torch.tensor(0, dtype=torch.long)

        self._load_from_state_dict_ds(
            state_dict, prefix, metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def _load_from_state_dict_ds(self, state_dict, prefix, local_metadata, strict,
                                 missing_keys, unexpected_keys, error_msgs):
        r"""Copy this module's own parameters and buffers from :attr:`state_dict`.

        A domain-specific buffer (``*_source``, ``*_target``, ``*_mid``) that is
        absent from :attr:`state_dict` falls back to the plain BatchNorm name
        (e.g. ``running_mean``). This allows a regular BatchNorm checkpoint to
        initialize all three domains.

        Arguments:
            state_dict (dict): a dict containing parameters and
                persistent buffers.
            prefix (str): the prefix for parameters and buffers used in this
                module
            local_metadata (dict): metadata for this module (unused here).
            strict (bool): when ``True``, keys that cannot be found are added to
                :attr:`missing_keys`
            missing_keys (list of str): receives keys that could not be found
            unexpected_keys (list of str): not populated by this loader
            error_msgs (list of str): error messages should be added to this
                list, and will be reported together in
                :meth:`~torch.nn.Module.load_state_dict`
        """
        local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
        local_state = {k: v.data for k, v in local_name_params if v is not None}
        for name, param in local_state.items():
            key = prefix + name
            if key not in state_dict:
                # Match the suffix on the local name only; matching on the full key
                # would misfire when the prefix itself contains 'source'/'target'/'mid'.
                for suffix in ('_source', '_target', '_mid'):
                    if name.endswith(suffix):
                        fallback = prefix + name[:-len(suffix)]
                        if fallback in state_dict:
                            key = fallback
                        break
            if key in state_dict:
                input_param = state_dict[key]

                # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                if len(param.shape) == 0 and len(input_param.shape) == 1:
                    input_param = input_param[0]

                if input_param.shape != param.shape:
                    # local shape should match the one in checkpoint
                    error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
                                      'the shape in current model is {}.'
                                      .format(key, input_param.shape, param.shape))
                    continue

                if isinstance(input_param, Parameter):
                    # backwards compatibility for serialized parameters
                    input_param = input_param.data
                try:
                    param.copy_(input_param)
                except Exception:
                    error_msgs.append('While copying the parameter named "{}", '
                                      'whose dimensions in the model are {} and '
                                      'whose dimensions in the checkpoint are {}.'
                                      .format(key, param.size(), input_param.size()))
            elif strict:
                missing_keys.append(key)

        # Unexpected keys are deliberately not reported: a plain BatchNorm
        # checkpoint's 'running_mean'/'running_var' have no exact local name and
        # would otherwise be flagged.

    @classmethod
    def convert_dsnorm(cls, module):
        r"""Recursively replace every ``torch.nn`` BatchNorm layer in ``module`` with a ``DSNorm``.

        Affine parameters are copied. The BatchNorm running mean/variance are
        copied into each domain's (source, target, mid) buffers as independent
        tensors.

        Args:
            module (nn.Module): containing module

        Returns:
            A new ``DSNorm`` if ``module`` itself was a BatchNorm layer, otherwise
            ``module`` with its BatchNorm descendants replaced.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100)
            >>>          )
            >>> module = DSNorm.convert_dsnorm(module)
        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = DSNorm(module.num_features,
                                   module.eps, module.momentum,
                                   module.affine,
                                   module.track_running_stats)
            if module.affine:
                module_output.weight.data = module.weight.data.clone().detach()
                module_output.bias.data = module.bias.data.clone().detach()
                # keep requires_grad unchanged
                module_output.weight.requires_grad = module.weight.requires_grad
                module_output.bias.requires_grad = module.bias.requires_grad

            running_mean = module.running_mean
            running_var = module.running_var
            for domain in ('source', 'target', 'mid'):
                # Clone per domain: running statistics are updated in place, so
                # shared storage would let one domain's updates overwrite the others.
                setattr(module_output, 'running_mean_' + domain,
                        None if running_mean is None else running_mean.clone())
                setattr(module_output, 'running_var_' + domain,
                        None if running_var is None else running_var.clone())
            module_output.num_batches_tracked = module.num_batches_tracked

        for name, child in module.named_children():
            module_output.add_module(name, cls.convert_dsnorm(child))
        return module_output


class DSNorm1dSync(DSNormSync):
    r"""Synchronized domain-specific batch normalization for 2D or 3D input.

    Accepts a mini-batch of 1D inputs of shape :math:`(N, C)` or
    :math:`(N, C, L)`; any other rank raises ``ValueError``.
    """

    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim in (2, 3):
            return
        raise ValueError('expected 2D or 3D input (got {}D input)'
                         .format(ndim))


class DSNorm2dSync(DSNormSync):
    r"""Synchronized domain-specific batch normalization for 4D input.

    Accepts a mini-batch of 2D inputs with a channel dimension,
    :math:`(N, C, H, W)`; any other rank raises ``ValueError``.
    """

    def _check_input_dim(self, input):
        ndim = input.dim()
        if ndim == 4:
            return
        raise ValueError('expected 4D input (got {}D input)'
                         .format(ndim))


class DSNorm1d(DSNorm):
    r"""Applies domain-specific Batch Normalization over a 2D or 3D input (a
    mini-batch of 1D inputs with optional additional channel dimension), based on
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size), shared by all domains.

    Unlike :class:`torch.nn.BatchNorm1d`, this layer holds three independent
    sets of running estimates, selected with :meth:`set_domain_label`:
    ``0`` = source, ``1`` = target, any other value = mid. Normalization, and the
    running-estimate update during training, use only the active domain's
    buffers. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = DSNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = DSNorm1d(100, affine=False)
        >>> m.set_domain_label(1)  # use the target-domain statistics
        >>> input = torch.randn(20, 100)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))


class DSNorm2d(DSNorm):
    r"""Applies domain-specific Batch Normalization over a 4D input (a mini-batch
    of 2D inputs with additional channel dimension), based on
    `Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`_ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size), shared by all domains.

    Unlike :class:`torch.nn.BatchNorm2d`, this layer holds three independent
    sets of running estimates, selected with :meth:`set_domain_label`:
    ``0`` = source, ``1`` = target, any other value = mid. Normalization, and the
    running-estimate update during training, use only the active domain's
    buffers. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics and always uses batch
            statistics in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = DSNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = DSNorm2d(100, affine=False)
        >>> m.set_domain_label(1)  # use the target-domain statistics
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)

    .. _`Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift`:
        https://arxiv.org/abs/1502.03167
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))


def _set_ds_domain_label(m, label):
    """Call ``m.set_domain_label(label)`` if *m*'s class name contains ``'DSNorm'``.

    Matching is done on the class name, not on ``isinstance``. Any class whose
    name contains ``'DSNorm'`` (e.g. DSNorm2d, DSNorm2dSync) is affected. All
    other modules are ignored.
    """
    if 'DSNorm' in m.__class__.__name__:
        m.set_domain_label(label)


def set_ds_source(m):
    """Select the source-domain statistics (domain label 0) on a DSNorm module.

    Takes a single module and returns ``None``, so it can be passed to
    ``nn.Module.apply`` to switch every DSNorm layer in a model at once.
    """
    _set_ds_domain_label(m, 0)


def set_ds_target(m):
    """Select the target-domain statistics (domain label 1) on a DSNorm module.

    Takes a single module and returns ``None``, so it can be passed to
    ``nn.Module.apply`` to switch every DSNorm layer in a model at once.
    """
    _set_ds_domain_label(m, 1)


def set_ds_mid(m):
    """Select the intermediate-domain statistics (domain label 2) on a DSNorm module.

    Takes a single module and returns ``None``, so it can be passed to
    ``nn.Module.apply`` to switch every DSNorm layer in a model at once.
    """
    _set_ds_domain_label(m, 2)

def _copy_affine_params(src, dst):
    """Copy ``weight``/``bias`` values and ``requires_grad`` from *src* to *dst*.

    Does nothing when *src* has no learnable affine parameters.
    """
    if not src.affine:
        return
    for name in ('weight', 'bias'):
        src_param = getattr(src, name)
        dst_param = getattr(dst, name)
        # Rebinding ``.data`` (rather than copy_) also moves the freshly built
        # parameter onto src's device/dtype.
        dst_param.data = src_param.data.clone().detach()
        dst_param.requires_grad = src_param.requires_grad


def _clone_buffer(tensor):
    """Return an independent copy of *tensor*, or ``None`` when it is ``None``."""
    return None if tensor is None else tensor.detach().clone()


def convert_model_to_dssynnorm(module, process_group=None):
    """
    Recursively convert all BatchNorm and DSNorm layers of a model into DSNormSync layers.

    Conversions performed:
        * ``nn.BatchNorm1d`` / ``DSNorm1d`` -> ``DSNorm1dSync``
        * ``nn.BatchNorm2d`` / ``DSNorm2d`` -> ``DSNorm2dSync``
    Other BatchNorm/DSNorm variants are returned unchanged. So are BatchNorm
    layers with ``num_features == 1``.

    Hyper-parameters, affine parameters, running statistics and the
    ``training`` flag are carried over. A plain BatchNorm has only one set of
    running statistics, so it seeds the source, target and mid statistics with
    independent copies of it.

    Args:
        module (nn.Module): the model (or layer) containing the layers to convert.
        process_group (optional): process group used for synchronization.

    Returns:
        nn.Module: the converted model. If *module* itself is not converted, the
        same object is returned and its children are replaced in place.
    """
    module_output = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        if module.num_features == 1:
            # Single-feature BatchNorm layers may not need synchronization; keep them.
            return module
        # BatchNorm -> DSNormSync
        if isinstance(module, torch.nn.BatchNorm1d):
            sync_cls = DSNorm1dSync
        elif isinstance(module, torch.nn.BatchNorm2d):
            sync_cls = DSNorm2dSync
        else:
            # Unsupported BatchNorm type: leave it as is.
            return module
        module_output = sync_cls(module.num_features, module.eps,
                                 module.momentum, module.affine,
                                 module.track_running_stats, process_group)
        _copy_affine_params(module, module_output)
        # Each domain must own its own tensor. Running statistics are updated in
        # place (e.g. by F.batch_norm), so aliasing one tensor to all three
        # buffers would let updates for one domain overwrite the others.
        module_output.running_mean_source = _clone_buffer(module.running_mean)
        module_output.running_mean_target = _clone_buffer(module.running_mean)
        module_output.running_mean_mid = _clone_buffer(module.running_mean)
        module_output.running_var_source = _clone_buffer(module.running_var)
        module_output.running_var_target = _clone_buffer(module.running_var)
        module_output.running_var_mid = _clone_buffer(module.running_var)
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly constructed module defaults to training mode; keep the original's.
        module_output.training = module.training
    elif isinstance(module, DSNorm):
        # DSNorm -> DSNormSync
        if isinstance(module, DSNorm1d):
            sync_cls = DSNorm1dSync
        elif isinstance(module, DSNorm2d):
            sync_cls = DSNorm2dSync
        else:
            return module
        module_output = sync_cls(module.num_features, module.eps,
                                 module.momentum, module.affine,
                                 module.track_running_stats, process_group)
        _copy_affine_params(module, module_output)
        # Per-domain statistics already live in separate tensors; hand them over.
        for buffer_name in ('running_mean_source', 'running_var_source',
                            'running_mean_target', 'running_var_target',
                            'running_mean_mid', 'running_var_mid',
                            'num_batches_tracked'):
            setattr(module_output, buffer_name, getattr(module, buffer_name))
        module_output.training = module.training

    for name, child in module.named_children():
        module_output.add_module(name, convert_model_to_dssynnorm(child, process_group))

    return module_output
