import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Indentation prefix (four tab characters) placed in front of every line of
# the code snippets assembled by the ``cuda_code`` methods.
tab4_str = '\t' * 4
# Literal braces as named constants; presumably intended for interpolation into
# f-string code templates, where a bare brace would open a replacement field.
curly_bracket_l, curly_bracket_r = '{', '}'

def heaviside(x: torch.Tensor):
    '''
    Element-wise Heaviside step function.

    :param x: the input tensor
    :return: a tensor converted to the dtype and device of ``x`` that holds ``1`` where ``x >= 0`` and ``0``
        elsewhere

    The function is defined by

    .. math::
        g(x) =
        \\begin{cases}
        1, & x \\geq 0 \\\\
        0, & x < 0 \\\\
        \\end{cases}

    For more information, see `HeavisideStepFunction <https://mathworld.wolfram.com/HeavisideStepFunction.html>`_.

    '''
    # The comparison yields a boolean mask; ``.to(x)`` converts it to x's dtype/device.
    nonnegative = x >= 0
    return nonnegative.to(x)

def check_manual_grad(primitive_function, spiking_function, eps=1e-5):
    '''
    :param primitive_function: the primitive function of the surrogate gradient function
    :type primitive_function: callable
    :param spiking_function: the surrogate gradient function
    :type spiking_function: callable
    :param eps: the maximum tolerated error
    :type eps: float

    The backward pass of a surrogate gradient function is usually written by hand; this function can be used to
    check that the hand-written gradient is correct.

    It compares the gradient obtained by back-propagating through ``spiking_function`` with the gradient autograd
    derives from ``primitive_function``. The two are considered consistent when their largest element-wise absolute
    difference does not exceed ``eps``. Both callables are invoked as ``f(x, alpha)`` with ``alpha = 1.0`` and ``x``
    evenly spaced over ``[-16, 16)`` with step ``32 / 8192``.

    An ``AssertionError`` is raised when the check fails; otherwise ``grad check pass`` is printed.

    Example:

    .. code-block:: python

        check_manual_grad(FGTATan.primitive_function, atan.apply)
    '''
    alpha = torch.tensor(1.0, dtype=torch.float)
    x = torch.arange(-16, 16, 32 / 8192)
    x.requires_grad_(True)

    collected = []
    for fn in (primitive_function, spiking_function):
        # Clear the gradient left over from the previous pass so the two results do not accumulate.
        if x.grad is not None:
            x.grad.zero_()
        fn(x, alpha).sum().backward()
        collected.append(x.grad.clone())

    x_grad_auto, x_grad_manual = collected
    max_error = (x_grad_manual - x_grad_auto).abs().max().item()
    assert max_error <= eps, 'x.grad is wrong!'
    print('grad check pass')

class FGTSurrogateFunctionBase(nn.Module):
    '''
    Base module for surrogate functions that are parameterised by a single value ``alpha``.

    :param alpha: parameter of the surrogate function, handed unchanged to ``spiking_function`` /
        ``primitive_function``
    :param spiking: if ``True``, ``forward`` dispatches to ``spiking_function``; otherwise to
        ``primitive_function``

    Subclasses override ``spiking_function`` and ``primitive_function`` (and optionally ``cuda_code``). Both
    functions take ``(x, alpha, forward_grad=False)``; the subclasses in this module return the output alone when
    ``forward_grad`` is ``False`` and an ``(output, surrogate_gradient)`` pair when it is ``True``.
    '''
    def __init__(self, alpha, spiking=True):
        super().__init__()
        self.spiking = spiking
        self.alpha = alpha

    def set_spiking_mode(self, spiking: bool):
        '''Select which function ``forward`` dispatches to (``True``: spiking, ``False``: primitive).'''
        self.spiking = spiking

    def extra_repr(self):
        return f'alpha={self.alpha}, spiking={self.spiking}'

    # The stubs accept ``forward_grad`` because ``forward`` always passes it as a keyword; without it an
    # un-overridden stub would fail with ``TypeError`` instead of the intended ``NotImplementedError``.
    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        '''Spiking version of the function; must be overridden by subclasses.'''
        raise NotImplementedError

    @staticmethod
    def primitive_function(x, alpha, forward_grad=False):
        '''Primitive (non-spiking) version of the function; must be overridden by subclasses.'''
        raise NotImplementedError

    def cuda_code(self, x: str, y: str, dtype='fp32'):
        '''
        Return source code that computes the surrogate gradient; must be overridden by subclasses that support it.

        :param x: name of the input variable in the generated code
        :param y: name of the variable the generated code declares for the result
        :param dtype: ``'fp32'`` or ``'fp16'`` in the subclasses of this module
        '''
        raise NotImplementedError

    def cuda_code_start_comments(self):
        '''Return the comment line that marks the beginning of this module's generated code.'''
        return f'// start: spikingjelly.clock_driven.surrogate.{self._get_name()}.cuda_code'

    def cuda_code_end_comments(self):
        '''Return the comment line that marks the end of this module's generated code.'''
        return f'// end: spikingjelly.clock_driven.surrogate.{self._get_name()}.cuda_code'

    def forward(self, x: torch.Tensor, forward_grad=False):
        '''
        Apply the spiking or the primitive function to ``x``, depending on ``self.spiking``.

        :param x: the input tensor
        :param forward_grad: forwarded as a keyword argument to the selected function
        :return: whatever the selected function returns
        '''
        selected = self.spiking_function if self.spiking else self.primitive_function
        return selected(x, self.alpha, forward_grad=forward_grad)


#class MultiArgsSurrogateFunctionBase(nn.Module):
#    def __init__(self, spiking: bool, *args, **kwargs):
#        super().__init__()
#        self.spiking = spiking
#
#    def set_spiking_mode(self, spiking: bool):
#        self.spiking = spiking
#
#    def cuda_code(self, x: str, y: str, dtype='fp32'):
#        raise NotImplementedError
#
#    def cuda_code_start_comments(self):
#        return f'// start: spikingjelly.clock_driven.surrogate.{self._get_name()}.cuda_code'
#
#    def cuda_code_end_comments(self):
#        return f'// end: spikingjelly.clock_driven.surrogate.{self._get_name()}.cuda_code'


class piecewise_quadratic(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with a piecewise-linear surrogate gradient in the backward pass.

    The gradient passed back to ``x`` is ``grad_output * (alpha - alpha ** 2 * |x|)`` where ``|x| <= 1 / alpha``
    and ``0`` elsewhere; ``alpha`` receives no gradient.
    '''
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            # ``save_for_backward`` is meant for tensors only, and ``backward`` reads ``ctx.alpha``,
            # so ``alpha`` (which may be a plain float) is stored as a ctx attribute instead.
            ctx.save_for_backward(x)
            ctx.alpha = alpha
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        grad_x = None
        if ctx.needs_input_grad[0]:
            x_abs = ctx.saved_tensors[0].abs()
            # Outside the window |x| <= 1 / alpha the surrogate gradient is zero.
            mask = (x_abs > (1 / ctx.alpha))
            grad_x = (grad_output * (- (ctx.alpha ** 2) * x_abs + ctx.alpha)).masked_fill_(mask, 0)
        return grad_x, None


class FGTPiecewiseQuadratic(FGTSurrogateFunctionBase):
    '''
    Piecewise-quadratic surrogate function.

    :param alpha: parameter of the surrogate function; the surrogate gradient ``alpha - alpha ** 2 * |x|`` is used
        where ``|x| <= 1 / alpha`` and ``0`` elsewhere
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is the surrogate gradient
    evaluated at ``x``.
    '''
    def __init__(self, alpha=1.0, spiking=True):
        super().__init__(alpha, spiking)

    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        if forward_grad:
            magnitude = x.abs()
            outside = magnitude > (1. / alpha)
            sg = (-(alpha ** 2) * magnitude + alpha).masked_fill_(outside, 0)
            return heaviside(x), sg
        return piecewise_quadratic.apply(x, alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        # Indicator of x > 1 / alpha (output saturates at 1) and of the quadratic window |x| <= 1 / alpha.
        above = (x > (1.0 / alpha)).to(x)
        inside = (x.abs() <= (1.0 / alpha)).to(x)
        quadratic = -(alpha ** 2) / 2 * x.square() * x.sign() + alpha * x + 0.5

        out = above + inside * quadratic
        if forward_grad:
            magnitude = x.abs()
            outside = magnitude > (1. / alpha)
            sg = (-(alpha ** 2) * magnitude + alpha).masked_fill_(outside, 0)
            return out, sg
        return out


class piecewise_exp(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with an exponential surrogate gradient in the backward pass.

    The gradient passed back to ``x`` is ``alpha / 2 * exp(-alpha * |x|) * grad_output``; ``alpha`` receives no
    gradient.
    '''
    @staticmethod
    def forward(ctx, x, alpha):
        # Backward state is only stored when x can receive a gradient.
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None

        x = ctx.saved_tensors[0]
        decay = (- ctx.alpha * x.abs()).exp_()
        grad_x = ctx.alpha / 2 * decay * grad_output
        return grad_x, None


class FGTPiecewiseExp(FGTSurrogateFunctionBase):
    '''
    Piecewise-exponential surrogate function.

    :param alpha: parameter of the surrogate function; the surrogate gradient is ``alpha / 2 * exp(-alpha * |x|)``
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is the surrogate gradient
    evaluated at ``x``.
    '''
    def __init__(self, alpha=1.0, spiking=True):
        super().__init__(alpha, spiking)

    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        if forward_grad:
            sg = alpha / 2 * (-alpha * x.abs()).exp_()
            return heaviside(x), sg
        return piecewise_exp.apply(x, alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        # step is 1 for x >= 0 and 0 otherwise; sign is the matching +1 / -1 factor.
        step = heaviside(x)
        sign = step * 2 - 1
        half_decay = (sign * x * -alpha).exp_() / 2

        out = step - half_decay * sign
        if forward_grad:
            sg = alpha / 2 * (-alpha * x.abs()).exp_()
            return out, sg
        return out


class sigmoid(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with the derivative of ``sigmoid(alpha * x)`` as surrogate gradient.

    When ``noise`` is given, the forward pass thresholds ``x + noise`` instead of ``x``; the backward pass always
    uses the saved ``x``. ``alpha`` and ``noise`` receive no gradient.
    '''
    @staticmethod
    def forward(ctx, x, alpha, noise=None):
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        shifted = x if noise is None else x + noise
        return heaviside(shifted)

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None, None

        sgax = (ctx.saved_tensors[0] * ctx.alpha).sigmoid_()
        grad_x = grad_output * (1. - sgax) * sgax * ctx.alpha
        return grad_x, None, None


class FGTSigmoid(FGTSurrogateFunctionBase):
    '''
    Sigmoid surrogate function with optional stochastic firing.

    :param alpha: parameter of the surrogate function; the primitive function is ``sigmoid(alpha * x)``
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)
    :param stochastic: if ``True``, ``spiking_function`` adds random noise to ``x`` before thresholding

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is
    ``alpha * sigmoid(alpha * x) * (1 - sigmoid(alpha * x))``.
    '''
    def __init__(self, alpha=4.0, spiking=True, stochastic=False):
        super().__init__(alpha, spiking)
        self.stochastic = stochastic

    def spiking_function(self, x, alpha, forward_grad=False):
        # An instance method (not static like in the sibling classes) because it reads ``self.stochastic``.
        noise = None
        if self.stochastic:
            # Logit of a uniform sample, scaled by 1 / alpha; the 1e-6 terms keep the division and the
            # logarithm away from zero.
            noise = torch.rand(size=x.shape, device=x.device)
            noise = torch.log(noise / (1 - noise + 1e-6) + 1e-6) / alpha

        if not forward_grad:
            return sigmoid.apply(x, alpha, noise)

        sgax = (x * alpha).sigmoid_()
        sg = (1. - sgax) * sgax * alpha
        spike = heaviside(x) if noise is None else heaviside(x + noise)
        return spike, sg

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        out = (x * alpha).sigmoid()
        if not forward_grad:
            return out

        sgax = (x * alpha).sigmoid_()
        sg = (1. - sgax) * sgax * alpha
        return out, sg

    def cuda_code(self, x: str, y: str, dtype='fp32'):
        '''
        Return a CUDA source snippet that declares ``y`` as the surrogate gradient evaluated at ``x``.

        :param x: name of the input variable in the generated code
        :param y: name of the variable declared for the result
        :param dtype: ``'fp32'`` (``float``) or ``'fp16'`` (``half2``)
        :raises NotImplementedError: for any other ``dtype``
        '''
        sg_name = 'sg_' + self._get_name()
        # Convert to float first: an integer alpha would otherwise be rendered as e.g. ``4f``, which is not
        # a valid floating-point literal. Float values are rendered exactly as before.
        alpha = str(float(self.alpha)) + 'f'
        code = f'''
            {tab4_str}{self.cuda_code_start_comments()}
        '''

        if dtype == 'fp32':
            code += f'''
            {tab4_str}const float {sg_name}_sigmoid_ax = 1.0f / (1.0f + expf(- {alpha} * {x}));
            {tab4_str}const float {y} = (1.0f - {sg_name}_sigmoid_ax) * {sg_name}_sigmoid_ax * {alpha};
            '''
        elif dtype == 'fp16':
            code += f'''
            {tab4_str}const half2 {sg_name}_alpha = __float2half2_rn({alpha});
            {tab4_str}const half2 {sg_name}_sigmoid_ax = __h2div(__float2half2_rn(1.0f), __hadd2(h2exp(__hneg2(__hmul2({sg_name}_alpha, {x}))), __float2half2_rn(1.0f)));
            {tab4_str}const half2 {y} = __hmul2(__hmul2(__hsub2(__float2half2_rn(1.0f), {sg_name}_sigmoid_ax), {sg_name}_sigmoid_ax), {sg_name}_alpha);
            '''
        else:
            raise NotImplementedError
        code += f'''
            {tab4_str}{self.cuda_code_end_comments()}
        '''
        return code


class soft_sign(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with a soft-sign surrogate gradient in the backward pass.

    The gradient passed back to ``x`` is ``grad_output / (2 * alpha * (1 / alpha + |x|) ** 2)``; ``alpha`` receives
    no gradient.
    '''
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None

        x = ctx.saved_tensors[0]
        denominator = 2 * ctx.alpha * (1 / ctx.alpha + x.abs()).pow_(2)
        return grad_output / denominator, None


class FGTSoftSign(FGTSurrogateFunctionBase):
    '''
    Soft-sign surrogate function.

    :param alpha: parameter of the surrogate function, must be positive; the primitive function is
        ``(softsign(alpha * x) + 1) / 2``
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is
    ``1 / (2 * alpha * (1 / alpha + |x|) ** 2)``.
    '''
    def __init__(self, alpha=2.0, spiking=True):
        super().__init__(alpha, spiking)
        assert alpha > 0, 'alpha must be lager than 0'

    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        if forward_grad:
            sg = 1. / (2 * alpha * (1 / alpha + x.abs()).pow_(2))
            return heaviside(x), sg
        return soft_sign.apply(x, alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        out = (F.softsign(x * alpha) + 1) / 2
        if forward_grad:
            sg = 1. / (2 * alpha * (1 / alpha + x.abs()).pow_(2))
            return out, sg
        return out


class atan(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with an arc-tangent surrogate gradient in the backward pass.

    The gradient passed back to ``x`` is ``alpha / 2 / (1 + (pi / 2 * alpha * x) ** 2) * grad_output``; ``alpha``
    receives no gradient.
    '''
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None

        x = ctx.saved_tensors[0]
        scaled_sq = (math.pi / 2 * ctx.alpha * x).pow_(2)
        grad_x = ctx.alpha / 2 / (1 + scaled_sq) * grad_output
        return grad_x, None

class FGTATan(FGTSurrogateFunctionBase):
    '''
    Arc-tangent surrogate function.

    :param alpha: parameter of the surrogate function; the primitive function is
        ``atan(pi / 2 * alpha * x) / pi + 0.5``
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is
    ``alpha / 2 / (1 + (pi / 2 * alpha * x) ** 2)``.
    '''
    def __init__(self, alpha=2.0, spiking=True):
        super().__init__(alpha, spiking)

    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        if not forward_grad:
            return atan.apply(x, alpha)

        sg = alpha / 2 / (1 + (math.pi / 2 * alpha * x).pow_(2))
        return heaviside(x), sg

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        out = (math.pi / 2 * alpha * x).atan_() / math.pi + 0.5
        if not forward_grad:
            return out

        sg = alpha / 2 / (1 + (math.pi / 2 * alpha * x).pow_(2))
        return out, sg

    def cuda_code(self, x: str, y: str, dtype='fp32'):
        '''
        Return a CUDA source snippet that declares ``y`` as the surrogate gradient evaluated at ``x``.

        :param x: name of the input variable in the generated code
        :param y: name of the variable declared for the result
        :param dtype: ``'fp32'`` (``float``) or ``'fp16'`` (``half2``)
        :raises NotImplementedError: for any other ``dtype``
        '''
        sg_name = 'sg_' + self._get_name()
        # Convert to float first: an integer alpha would otherwise be rendered as e.g. ``2f``, which is not
        # a valid floating-point literal. Float values are rendered exactly as before.
        alpha = str(float(self.alpha)) + 'f'
        code = f'''
            {tab4_str}{self.cuda_code_start_comments()}
        '''
        if dtype == 'fp32':
            code += f'''
            {tab4_str}const float {sg_name}_M_PI_2__alpha__x = ((float) 1.57079632679489661923) * {alpha} * {x};
            {tab4_str}const float {y} = {alpha} / 2.0f / (1.0f + {sg_name}_M_PI_2__alpha__x * {sg_name}_M_PI_2__alpha__x);
            '''
        elif dtype == 'fp16':
            code += f'''
            {tab4_str}const half2 {sg_name}_alpha =  __float2half2_rn({alpha});
            {tab4_str}const half2 {sg_name}_M_PI_2__alpha__x = __hmul2(__hmul2(__float2half2_rn((float) 1.57079632679489661923), {sg_name}_alpha), {x});
            {tab4_str}const half2 {y} = __h2div(__h2div({sg_name}_alpha, __float2half2_rn(2.0f)), __hfma2({sg_name}_M_PI_2__alpha__x, {sg_name}_M_PI_2__alpha__x, __float2half2_rn(1.0f)));
            '''
        else:
            raise NotImplementedError
        code += f'''
            {tab4_str}{self.cuda_code_end_comments()}
        '''
        return code


class nonzero_sign_log_abs(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with a reciprocal surrogate gradient in the backward pass.

    The gradient passed back to ``x`` is ``grad_output / (1 / alpha + |x|)``; ``alpha`` receives no gradient.
    '''
    @staticmethod
    def forward(ctx, x, alpha):
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None

        x = ctx.saved_tensors[0]
        denominator = 1 / ctx.alpha + x.abs()
        return grad_output / denominator, None


class FGTNonzeroSignLogAbs(FGTSurrogateFunctionBase):
    '''
    Surrogate function whose primitive is ``s * log(alpha * |x| + 1)`` with ``s = 1`` for ``x >= 0`` and
    ``s = -1`` otherwise.

    :param alpha: parameter of the surrogate function
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is ``1 / (1 / alpha + |x|)``.
    '''
    def __init__(self, alpha=1.0, spiking=True):
        super().__init__(alpha, spiking)

    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        if forward_grad:
            sg = 1. / (1 / alpha + x.abs())
            return heaviside(x), sg
        return nonzero_sign_log_abs.apply(x, alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        # ``|x|`` is written as ``sign * x`` on purpose: the autograd gradient of
        # ``(heaviside(x) * 2 - 1) * (alpha * x.abs() + 1).log()`` is wrong at ``x == 0``.
        sign = heaviside(x) * 2 - 1
        out = sign * (alpha * sign * x + 1).log()
        if forward_grad:
            sg = 1. / (1 / alpha + x.abs())
            return out, sg
        return out


class erf(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with a Gaussian surrogate gradient in the backward pass.

    The gradient passed back to ``x`` is ``grad_output * exp(-(alpha * x) ** 2) * alpha / sqrt(pi)``. When
    ``noise`` is given, the forward pass thresholds ``x + noise`` instead of ``x``; the backward pass always uses
    the saved ``x``. ``alpha`` and ``noise`` receive no gradient.
    '''
    @staticmethod
    def forward(ctx, x, alpha, noise=None):
        if x.requires_grad:
            ctx.alpha = alpha
            ctx.save_for_backward(x)
        shifted = x if noise is None else x + noise
        return heaviside(shifted)

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None, None

        x = ctx.saved_tensors[0]
        gaussian = (- (x * ctx.alpha).pow_(2)).exp_()
        grad_x = grad_output * gaussian * (ctx.alpha / math.sqrt(math.pi))
        return grad_x, None, None


class FGTErf(FGTSurrogateFunctionBase):
    '''
    Error-function surrogate function with optional stochastic firing.

    :param alpha: parameter of the surrogate function; the primitive function is ``erfc(-alpha * x) / 2``
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)
    :param stochastic: if ``True``, ``spiking_function`` adds zero-mean normal noise with standard deviation
        ``1 / (alpha * sqrt(2))`` to ``x`` before thresholding

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is
    ``exp(-(alpha * x) ** 2) * alpha / sqrt(pi)``.
    '''
    def __init__(self, alpha=2.0, spiking=True, stochastic=False):
        super().__init__(alpha, spiking)
        self.stochastic = stochastic

    def spiking_function(self, x, alpha, forward_grad=False):
        # An instance method (not static like in the sibling classes) because it reads ``self.stochastic``.
        noise = None
        if self.stochastic:
            noise = torch.normal(0, 1. / (alpha * math.sqrt(2)), size=x.shape, device=x.device)

        if not forward_grad:
            return erf.apply(x, alpha, noise)

        sg = (- (x * alpha).pow_(2)).exp_() * (alpha / math.sqrt(math.pi))
        spike = heaviside(x) if noise is None else heaviside(x + noise)
        return spike, sg

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        out = torch.erfc_(-alpha * x) / 2
        if forward_grad:
            sg = (- (x * alpha).pow_(2)).exp_() * (alpha / math.sqrt(math.pi))
            return out, sg
        return out


class FGTReLU(FGTSurrogateFunctionBase):
    '''
    ReLU used through the surrogate-function interface.

    :param alpha: stored by the base class but not used by ``spiking_function``
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)

    ``spiking_function`` returns ``relu(x)``; with ``forward_grad=True`` it additionally returns a float mask that
    is ``1`` where ``x >= 0`` and ``0`` elsewhere. No ``primitive_function`` is defined in this class.
    '''
    def __init__(self, alpha=1.0, spiking=True):
        super().__init__(alpha, spiking)

    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        if forward_grad:
            sg = (x >= 0).float()
            return torch.relu(x), sg
        return torch.relu(x)


class quasi_clamp(torch.autograd.Function):
    '''
    Heaviside step in the forward pass with a rectangular surrogate gradient in the backward pass.

    The gradient passed back to ``x`` is ``grad_output`` where ``|x| <= vth`` and ``0`` where ``|x| > vth``;
    ``vth`` receives no gradient.
    '''
    @staticmethod
    def forward(ctx, x, vth):
        if x.requires_grad:
            ctx.vth = vth
            ctx.save_for_backward(x)
        return heaviside(x)

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.needs_input_grad[0]:
            return None, None

        x = ctx.saved_tensors[0]
        outside = x.abs() > ctx.vth
        inside = outside.logical_not()
        # Every element is overwritten: 1 inside the window, 0 outside it.
        window = x.masked_fill(inside, 1.).masked_fill(outside, 0.)
        return grad_output * window, None


class FGTQuasiClamp(FGTSurrogateFunctionBase):
    '''
    Clamp-based surrogate function.

    :param alpha: half-width of the window; the primitive function is ``clamp(x, -alpha, alpha)`` and ``alpha`` is
        passed to ``quasi_clamp`` as its ``vth`` argument
    :param spiking: whether ``forward`` uses ``spiking_function`` (``True``) or ``primitive_function`` (``False``)

    With ``forward_grad=True`` both functions return ``(output, sg)``, where ``sg`` is ``1`` where
    ``|x| <= alpha`` and ``0`` where ``|x| > alpha``.
    '''
    def __init__(self, alpha=1.0, spiking=True):
        super().__init__(alpha, spiking)

    @staticmethod
    def spiking_function(x, alpha, forward_grad=False):
        if forward_grad:
            outside = x.abs() > alpha
            inside = outside.logical_not()
            sg = x.masked_fill(inside, 1.).masked_fill(outside, 0.)
            return heaviside(x), sg
        return quasi_clamp.apply(x, alpha)

    @staticmethod
    def primitive_function(x: torch.Tensor, alpha, forward_grad=False):
        out = torch.clamp(x, - alpha, alpha)
        if forward_grad:
            outside = x.abs() > alpha
            inside = outside.logical_not()
            sg = x.masked_fill(inside, 1.).masked_fill(outside, 0.)
            return out, sg
        return out


#class piecewise_leaky_relu(torch.autograd.Function):
#    @staticmethod
#    def forward(ctx, x: torch.Tensor, w=1, c=0.01):
#        if x.requires_grad:
#            ctx.save_for_backward(x)
#            ctx.w = w
#            ctx.c = c
#        return heaviside(x)
#
#    @staticmethod
#    def backward(ctx, grad_output):
#        grad_x = None
#        if ctx.needs_input_grad[0]:
#            mask_width = (ctx.saved_tensors[0].abs() < ctx.w)
#            mask_c = mask_width.logical_not()
#            grad_x = grad_output * ctx.saved_tensors[0].masked_fill(mask_width, 1 / ctx.w).masked_fill(mask_c, ctx.c)
#        return grad_x, None, None
#
#
#class PiecewiseLeakyReLU(MultiArgsSurrogateFunctionBase):
#    def __init__(self, w=1., c=0.01, spiking=True):
#        '''
#        * :ref:`API in English <PiecewiseLeakyReLU.__init__-en>`
#        .. _PiecewiseLeakyReLU.__init__-cn:
#
#        :param w: ``-w <= x <= w`` 时反向传播的梯度为 ``1 / 2w``
#        :param c: ``x > w`` 或 ``x < -w`` 时反向传播的梯度为 ``c``
#        :param spiking: 是否输出脉冲，默认为 ``True``，在前向传播时使用 ``heaviside`` 而在反向传播使用替代梯度。若为 ``False``
#            则不使用替代梯度，前向传播时，使用反向传播时的梯度替代函数对应的原函数
#
#        分段线性的近似脉冲发放函数。梯度为
#
#        .. math::
#            g'(x) =
#            \\begin{cases}
#            \\frac{1}{w}, & -w \\leq x \\leq w \\\\
#            c, & x < -w ~or~ x > w
#            \\end{cases}
#
#        对应的原函数为
#
#        .. math::
#            g(x) =
#            \\begin{cases}
#            cx + cw, & x < -w \\\\
#            \\frac{1}{2w}x + \\frac{1}{2}, & -w \\leq x \\leq w \\\\
#            cx - cw + 1, & x > w \\\\
#            \\end{cases}
#
#        .. image:: ./_static/API/clock_driven/surrogate/PiecewiseLeakyReLU.*
#            :width: 100%
#
#        该函数在文章 [#yin2017algorithm]_ [#STBP]_ [#huh2018gradient]_ [#wu2019direct]_ [#STCA]_ [#roy2019scaling]_ [#LISNN]_ [#DECOLLE]_ 中使用。
#
#        * :ref:`中文API <PiecewiseLeakyReLU.__init__-cn>`
#        .. _PiecewiseLeakyReLU.__init__-en:
#
#        :param w: when ``-w <= x <= w`` the gradient is ``1 / 2w``
#        :param c: when ``x > w`` or ``x < -w`` the gradient is ``c``
#        :param spiking: whether output spikes. The default is ``True`` which means that using ``heaviside`` in forward
#            propagation and using surrogate gradient in backward propagation. If ``False``, in forward propagation,
#            using the primitive function of the surrogate gradient function used in backward propagation
#
#        The piecewise surrogate spiking function. The gradient is defined by
#
#        .. math::
#            g'(x) =
#            \\begin{cases}
#            \\frac{1}{w}, & -w \\leq x \\leq w \\\\
#            c, & x < -w ~or~ x > w
#            \\end{cases}
#
#        The primitive function is defined by
#
#        .. math::
#            g(x) =
#            \\begin{cases}
#            cx + cw, & x < -w \\\\
#            \\frac{1}{2w}x + \\frac{1}{2}, & -w \\leq x \\leq w \\\\
#            cx - cw + 1, & x > w
#            \\end{cases}
#
#        .. image:: ./_static/API/clock_driven/surrogate/PiecewiseLeakyReLU.*
#            :width: 100%
#
#        The function is used in [#yin2017algorithm]_ [#STBP]_ [#huh2018gradient]_ [#wu2019direct]_ [#STCA]_ [#roy2019scaling]_ [#LISNN]_ [#DECOLLE]_.
#        '''
#        super().__init__(spiking)
#        assert w > 0.
#        self.w = w
#        self.c = c
#        self.spiking = spiking
#        if spiking:
#            self.f = self.spiking_function
#        else:
#            self.f = self.primitive_function
#
#    def forward(self, x):
#        return self.f(x, self.w, self.c)
#
#    @staticmethod
#    def spiking_function(x: torch.Tensor, w, c):
#        return piecewise_leaky_relu.apply(x, w, c)
#
#    @staticmethod
#    def primitive_function(x: torch.Tensor, w, c):
#        mask0 = (x < -w).to(x)
#        mask1 = (x > w).to(x)
#        mask2 = torch.ones_like(x.data) - mask0 - mask1
#        if c == 0:
#            return mask2 * (x / (2 * w) + 1 / 2) + mask1
#        else:
#            cw = c * w
#            return mask0 * (c * x + cw) + mask1 * (c * x + (- cw + 1)) \
#                   + mask2 * (x / (2 * w) + 1 / 2)
#
#    def cuda_code(self, x: str, y: str, dtype='fp32'):
#        sg_name = 'sg_' + self._get_name()
#        w = str(self.w) + 'f'
#        w_inv = str(1. / self.w) + 'f'
#        c = str(self.c) + 'f'
#        code = f'''
#            {tab4_str}{self.cuda_code_start_comments()}
#        '''
#
#        if dtype == 'fp32':
#            code += f'''
#            {tab4_str}const float {sg_name}_x_abs = fabsf({x});
#            float {y};
#            if ({sg_name}_x_abs > {w})
#            {curly_bracket_l}
#                {y} = {c};
#            {curly_bracket_r}
#            else
#            {curly_bracket_l}
#                {y} = {w_inv};
#            {curly_bracket_r}
#            '''
#        elif dtype == 'fp16':
#            code += f'''
#            {tab4_str}const half2 {sg_name}_x_abs = __habs2({x});
#            {tab4_str}const half2 {sg_name}_x_abs_ge_w = __hge2({sg_name}_x_abs, __float2half2_rn({w}));
#            {tab4_str}half2 {y} = __hadd2(__hmul2(__float2half2_rn({c}),  {sg_name}_x_abs_ge_w), __hmul2(__hsub2(__float2half2_rn(1.0f), {sg_name}_x_abs_ge_w), __float2half2_rn({w_inv})));
#            '''
#        else:
#            raise NotImplementedError
#        code += f'''
#            {tab4_str}{self.cuda_code_end_comments()}
#        '''
#        return code
#
#    # plt.style.use(['science', 'muted', 'grid'])
#    # fig = plt.figure(dpi=200)
#    # x = torch.arange(-2.5, 2.5, 0.001)
#    # plt.plot(x.data, surrogate.heaviside(x), label='Heaviside', linestyle='-.')
#    # surrogate_function = surrogate.PiecewiseLeakyReLU(w=1, c=0.1, spiking=False)
#    # y = surrogate_function(x)
#    # plt.plot(x.data, y.data, label='Primitive, $w=1, c=0.1$')
#
#    # surrogate_function = surrogate.PiecewiseLeakyReLU(w=1, c=0.1, spiking=True)
#    # x.requires_grad_(True)
#    # y = surrogate_function(x)
#    # z = y.sum()
#    # z.backward()
#    # plt.plot(x.data, x.grad, label='Gradient, $w=1, c=0.1$')
#    # plt.xlim(-2, 2)
#    # plt.legend()
#    # plt.title('PiecewiseLeakyReLU surrogate function')
#    # plt.xlabel('Input')
#    # plt.ylabel('Output')
#    # plt.grid(linestyle='--')
#    # plt.show()
#
#
#class squarewave_fourier_series(torch.autograd.Function):
#    @staticmethod
#    def forward(ctx, x: torch.Tensor, n: int, T_period: float):
#        if x.requires_grad:
#            ctx.save_for_backward(x)
#            ctx.n = n
#            ctx.T_period = T_period
#        return heaviside(x)
#
#    @staticmethod
#    def backward(ctx, grad_output):
#        grad_x = 0.
#        x = ctx.saved_tensors[0]
#        w = math.pi * 2. / ctx.T_period
#        for i in range(1, ctx.n):
#            grad_x += torch.cos_((2 * i - 1.) * w * x)
#
#        grad_x *= 4. / ctx.T_period
#        grad_x *= grad_output
#
#        return grad_x, None, None
#
#
#class SquarewaveFourierSeries(MultiArgsSurrogateFunctionBase):
#    def __init__(self, n: int = 2, T_period: float = 8, spiking=True):
#        super().__init__(spiking)
#        assert isinstance(n, int) and T_period > 0.
#        self.n = n
#        self.T_period = T_period
#        self.spiking = spiking
#        if spiking:
#            self.f = self.spiking_function
#        else:
#            self.f = self.primitive_function
#
#    def forward(self, x):
#        return self.f(x, self.n, self.T_period)
#
#    @staticmethod
#    def spiking_function(x: torch.Tensor, w, c):
#        return squarewave_fourier_series.apply(x, w, c)
#
#    @staticmethod
#    def primitive_function(x: torch.Tensor, n: int, T_period: float):
#        w = math.pi * 2. / T_period
#        ret = torch.zeros_like(x.data)
#        for i in range(1, n):
#            c = (2 * i - 1.)
#            ret += torch.sin(c * w * x) / c
#
#        return 0.5 + 2. / math.pi * ret
#
#    def cuda_code(self, x: str, y: str, dtype='fp32'):
#        sg_name = 'sg_' + self._get_name()
#        w = str(self.w) + 'f'
#        w_inv = str(1. / self.w) + 'f'
#        c = str(self.c) + 'f'
#        code = f'''
#            {tab4_str}{self.cuda_code_start_comments()}
#        '''
#
#        if dtype == 'fp32':
#            raise NotImplementedError
#        elif dtype == 'fp16':
#            raise NotImplementedError
#        else:
#            raise NotImplementedError
#
#        code += f'''
#            {tab4_str}{self.cuda_code_end_comments()}
#        '''
#        return code
#
#    # import torch
#    # from spikingjelly.clock_driven import surrogate
#    # from matplotlib import pyplot as plt
#    # plt.style.use(['science', 'muted', 'grid'])
#    # fig = plt.figure(dpi=200, figsize=(6, 4))
#    # x = torch.arange(-2.5, 2.5, 0.001)
#    # plt.plot(x.data, surrogate.heaviside(x), label='Heaviside', linestyle='-.')
#    #
#    # c_list = []
#    # for n in [2, 4, 8]:
#    #     surrogate_function = surrogate.SquarewaveFourierSeries(n=n, T_period=8, spiking=False)
#    #     y = surrogate_function(x)
#    #     plt.plot(x.data, y.data, label=f'Primitive, $n={n}$')
#    #     c_list.append(plt.gca().lines[-1].get_color())
#    #
#    # plt.xlim(-2, 2)
#    # plt.legend()
#    # plt.title(f'SquarewaveFourierSeries surrogate function')
#    # plt.xlabel('Input')
#    # plt.ylabel('Output')
#    # # plt.grid(linestyle='--')
#    # plt.savefig('./docs/source/_static/API/clock_driven/surrogate/SquarewaveFourierSeries1.pdf')
#    # plt.savefig('./docs/source/_static/API/clock_driven/surrogate/SquarewaveFourierSeries1.svg')
#    # plt.clf()
#    # for i, n in enumerate([2, 4, 8]):
#    #     surrogate_function = surrogate.SquarewaveFourierSeries(n=n, T_period=8, spiking=True)
#    #     x = x.detach()
#    #     x.requires_grad_(True)
#    #     y = surrogate_function(x)
#    #     z = y.sum()
#    #     z.backward()
#    #     plt.plot(x.data, x.grad, label=f'Gradient, $n={n}$', c=c_list[i])
#    #     x.grad.zero_()
#    #
#    # plt.xlim(-2, 2)
#    # plt.legend()
#    # plt.title(f'SquarewaveFourierSeries surrogate function')
#    # plt.xlabel('Input')
#    # plt.ylabel('Output')
#    # # plt.grid(linestyle='--')
#    # plt.savefig('./docs/source/_static/API/clock_driven/surrogate/SquarewaveFourierSeries2.pdf')
#    # plt.savefig('./docs/source/_static/API/clock_driven/surrogate/SquarewaveFourierSeries2.svg')
#
#
#
#class quasi_clamp(torch.autograd.Function):
#    @staticmethod
#    def forward(ctx, x, vth):
#        if x.requires_grad:
#            ctx.save_for_backward(x)
#            ctx.vth = vth
#        return heaviside(x)
#
#    @staticmethod
#    def backward(ctx, grad_output):
#        grad_x = None
#        if ctx.needs_input_grad[0]:
#            x = ctx.saved_tensors[0]
#            mask1 = (x.abs() > ctx.vth)
#            mask_ = mask1.logical_not()
#            grad_x = grad_output * x.masked_fill(mask_, 1.).masked_fill(mask1, 0.)
#        return grad_x, None
#
#class QuasiClamp(SurrogateFunctionBase):
#    def __init__(self, alpha=1.0, spiking=True):
#        super().__init__(alpha, spiking)
#
#    @staticmethod
#    def spiking_function(x, alpha):
#        return quasi_clamp.apply(x, alpha)
#
#    @staticmethod
#    def primitive_function(x: torch.Tensor, alpha):
#        return torch.clamp(x, - alpha, alpha)
