import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.cpp_extension import load_inline
from torch.cuda.amp import custom_fwd, custom_bwd
import logging
from . import tensor_cache

from torch import Tensor
from typing import Optional, Union
from torch.types import _int, _size
from torch.nn.modules.utils import _single, _pair, _triple

try:
    import cupy
except Exception as e:
    # cupy is optional: log why it is unavailable and fall back to ``cupy = None``.
    # ``Exception`` rather than ``BaseException`` so KeyboardInterrupt / SystemExit
    # raised during the import still propagate instead of being swallowed.
    logging.info(f'spikingjelly.activation_based.spike_op: {e}')
    cupy = None


try:
    logging.warning('spikingjelly.activation_based.spike_op: try to use `torch.utils.cpp_extension.load_inline` to load cudnn functions.')
    logging.warning(f'If it is hanging, please try to delete torch_extensions cache directory. (In most cases, the directory is {torch.utils.cpp_extension._get_build_directory("", False)}.)')
    # Compile a tiny extension that exposes three ATen cuDNN convolution-backward
    # functions to Python; spikeConvolution.backward calls them through `cpp_wrapper`.
    cpp_wrapper = load_inline(
            name='cpp_wrapper',
            cpp_sources='using namespace at;',
            functions=[
                'cudnn_convolution_backward',
                'cudnn_convolution_backward_input',
                'cudnn_convolution_backward_weight'
            ],
            with_cuda=True
    )
except Exception as e:
    # Best effort: on failure the wrapper is disabled (`cpp_wrapper = None`).
    # Catch ``Exception`` rather than ``BaseException`` so that interrupting a hanging
    # build with Ctrl-C aborts the import instead of silently disabling the wrapper.
    logging.info(f'spikingjelly.activation_based.spike_op: {e}')
    cpp_wrapper = None

'''
aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

at::Tensor cudnn_convolution(
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic, bool allow_tf32)

There are two overloaded C++ functions named `cudnn_convolution`, so binding one of them requires an explicit cast that selects the intended overload.
Refer to https://pybind11.readthedocs.io/en/stable/classes.html#overloaded-methods and https://github.com/pytorch/pytorch/issues/39518 for more details.
    
aten/src/ATen/native/cudnn/ConvShared.cpp

Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)

aten/src/ATen/native/cudnn/ConvPlaceholders.cpp

std::tuple<at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32, std::array<bool,2> output_mask)
  
aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
    
aten/src/ATen/native/cudnn/ConvShared.cpp

at::Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, bool allow_tf32)
'''

class spikeConvolution(torch.autograd.Function):
    """Autograd function for a 1-D/2-D/3-D convolution whose input is a spike tensor.

    The forward pass is ``F.conv1d`` / ``F.conv2d`` / ``F.conv3d``, chosen by
    ``spike.dim()`` (3, 4 or 5); any other rank raises ``ValueError``. When the weight
    gradient is required, ``spike`` is kept via
    ``tensor_cache.BOOL_TENSOR_CACHE.store_bool`` (restored with ``get_float``) instead
    of being saved as a float tensor. Gradients w.r.t. ``spike`` and ``weight`` are
    computed by the cuDNN backward functions bound in ``cpp_wrapper``.

    ``stride``, ``padding`` and ``dilation`` may be ints or sequences. They are
    expanded to tuples because the cuDNN bindings take ``IntArrayRef`` arguments.
    ``padding='valid'`` becomes zero padding. ``padding='same'`` becomes explicit
    symmetric padding when possible (all strides are 1 and every
    ``dilation * (kernel - 1)`` is even). Any other string padding is passed to
    ``F.conv*`` unchanged and cannot be used with the cuDNN backward.
    """
    # Pytorch only provides cudnn_convolution without bias.
    # Refer to https://github.com/pytorch/pytorch/issues/3823 for more details.

    @staticmethod
    def _expand(value, n):
        """Return ``value`` as a tuple; a non-sequence value is repeated ``n`` times."""
        if isinstance(value, (tuple, list)):
            return tuple(value)
        return (value,) * n

    @staticmethod
    @custom_fwd
    def forward(ctx, spike, weight, bias, stride, padding, dilation, groups):
        n = spike.dim() - 2  # number of spatial dimensions
        if n not in (1, 2, 3):
            raise ValueError(
                f'spikeConvolution expects a 3-D, 4-D or 5-D spike tensor, but got a {spike.dim()}-D tensor.')

        stride = spikeConvolution._expand(stride, n)
        dilation = spikeConvolution._expand(dilation, n)
        if not isinstance(padding, str):
            padding = spikeConvolution._expand(padding, n)
        elif padding == 'valid':
            padding = (0,) * n
        elif padding == 'same' and all(s == 1 for s in stride):
            # 'same' pads dilation * (kernel - 1) in total per spatial dim; it can only be
            # expressed as explicit (symmetric) padding when that total is even.
            totals = [d * (k - 1) for d, k in zip(dilation, weight.shape[2:])]
            if all(t % 2 == 0 for t in totals):
                padding = tuple(t // 2 for t in totals)

        if any(ctx.needs_input_grad[:3]):
            # The input shape is needed by both cudnn_convolution_backward_input and
            # the bool cache, so record it whenever any gradient is required.
            ctx.s_shape = spike.shape
            if ctx.needs_input_grad[1]:
                # grad_weight needs the input spikes.
                ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
            if ctx.needs_input_grad[0]:
                # grad_spike needs the weight.
                ctx.save_for_backward(weight)

            ctx.padding = padding
            ctx.stride = stride
            ctx.dilation = dilation
            ctx.groups = groups
            ctx.weight_shape = weight.shape

        conv = (F.conv1d, F.conv2d, F.conv3d)[n - 1]
        return conv(input=spike, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_spike = grad_weight = grad_bias = None
        need_spike_grad = ctx.needs_input_grad[0]
        need_weight_grad = ctx.needs_input_grad[1]
        # Trailing flags shared by all three cuDNN backward bindings.
        cudnn_flags = (torch.backends.cudnn.benchmark,
                       torch.backends.cudnn.deterministic,
                       torch.backends.cudnn.allow_tf32)

        # The raw cuDNN calls take tensors as-is, so cast both operands to
        # grad_output's dtype (presumably it may differ under AMP — confirm).
        if need_spike_grad:
            weight = ctx.saved_tensors[0].to(grad_output.dtype)
        if need_weight_grad:
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape).to(grad_output.dtype)

        if need_spike_grad and need_weight_grad:
            grad_spike, grad_weight = cpp_wrapper.cudnn_convolution_backward(
                spike, grad_output, weight, ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                *cudnn_flags, (True, True))
        elif need_weight_grad:
            grad_weight = cpp_wrapper.cudnn_convolution_backward_weight(
                ctx.weight_shape, grad_output, spike, ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                *cudnn_flags)
        elif need_spike_grad:
            grad_spike = cpp_wrapper.cudnn_convolution_backward_input(
                ctx.s_shape, grad_output, weight, ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                *cudnn_flags)

        if ctx.needs_input_grad[2]:
            # grad_output.shape = [N, C, *]: sum over every axis except the channel axis.
            out_channels = grad_output.shape[1]
            grad_bias = grad_output.transpose(0, 1).reshape(out_channels, -1).sum(1)
        return grad_spike, grad_weight, grad_bias, None, None, None, None

class spikeLinear(torch.autograd.Function):
    """Autograd function for ``F.linear`` whose input is a spike tensor.

    When the weight gradient is required, ``spike`` is kept via
    ``tensor_cache.BOOL_TENSOR_CACHE.store_bool`` (restored with ``get_float``)
    instead of being saved as a float tensor. ``weight`` is saved only when the
    input gradient is required.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, spike, weight, bias=None):
        # spike.shape = [N, *, in_features]
        # weight.shape = [out_features, in_features]
        # bias.shape = [out_features]
        if ctx.needs_input_grad[1]:
            # grad_weight = grad_output^T @ spike, so the spikes are needed.
            ctx.s_shape = spike.shape
            ctx.s_tk = tensor_cache.BOOL_TENSOR_CACHE.store_bool(spike)
        if ctx.needs_input_grad[0]:
            # grad_spike = grad_output @ weight, so the weight is needed.
            ctx.save_for_backward(weight)
        return F.linear(spike, weight, bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        # grad_output.shape = [N, *, out_features]
        grad_spike = grad_weight = grad_bias = None
        out_features = grad_output.shape[-1]

        # Each gradient reads exactly what forward stored under the same
        # needs_input_grad gate (weight for index 0, spikes for index 1).
        if ctx.needs_input_grad[0]:
            weight = ctx.saved_tensors[0]
            grad_spike = F.linear(grad_output, weight.t(), bias=None)
        if ctx.needs_input_grad[1]:
            spike = tensor_cache.BOOL_TENSOR_CACHE.get_float(ctx.s_tk, ctx.s_shape)
            in_features = spike.shape[-1]
            # grad_output.reshape(-1, out_features).t().shape = [out_features, N*]
            # spike.reshape(-1, in_features).shape = [N*, in_features]
            grad_weight = torch.mm(grad_output.reshape(-1, out_features).t(), spike.reshape(-1, in_features).to(grad_output.dtype))
        if ctx.needs_input_grad[2]:
            grad_bias = grad_output.reshape(-1, out_features).sum(0)
        return grad_spike, grad_weight, grad_bias

def spike_linear(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """
    .. _spike_linear-cn:

    .. _spike_linear-en:

    A special case of :class:`torch.nn.functional.linear` for inputs that are spikes.

    CPU tensors (``spike.get_device() < 0``) are passed straight to
    :class:`torch.nn.functional.linear`; tensors on other devices go through
    ``spikeLinear``.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.linear` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    on_cpu = spike.get_device() < 0
    if not on_cpu:
        return spikeLinear.apply(spike, weight, bias)
    return F.linear(spike, weight, bias)

def spike_conv1d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: Union[str, _int, _size] = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    .. _spike_conv1d-cn:

    .. _spike_conv1d-en:

    A special case of :class:`torch.nn.functional.conv1d` for inputs that are spikes.

    CPU tensors (``spike.get_device() < 0``) are passed straight to
    :class:`torch.nn.functional.conv1d`; tensors on other devices go through
    ``spikeConvolution``. ``padding`` may be a string (``"valid"``/``"same"``), an
    int, or a sequence of ints, as passed by ``SpikeConv1d``.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.conv1d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv1d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)

def spike_conv2d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: Union[str, _int, _size] = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    .. _spike_conv2d-cn:

    .. _spike_conv2d-en:

    A special case of :class:`torch.nn.functional.conv2d` for inputs that are spikes.

    CPU tensors (``spike.get_device() < 0``) are passed straight to
    :class:`torch.nn.functional.conv2d`; tensors on other devices go through
    ``spikeConvolution``. ``padding`` may be a string (``"valid"``/``"same"``), an
    int, or a sequence of ints, as passed by ``SpikeConv2d``.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.conv2d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv2d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)

def spike_conv3d(spike: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[_int, _size] = 1, padding: Union[str, _int, _size] = "valid", dilation: Union[_int, _size] = 1, groups: _int = 1) -> Tensor:
    """
    .. _spike_conv3d-cn:

    .. _spike_conv3d-en:

    A special case of :class:`torch.nn.functional.conv3d` for inputs that are spikes.

    CPU tensors (``spike.get_device() < 0``) are passed straight to
    :class:`torch.nn.functional.conv3d`; tensors on other devices go through
    ``spikeConvolution``. ``padding`` may be a string (``"valid"``/``"same"``), an
    int, or a sequence of ints, as passed by ``SpikeConv3d``.

    .. admonition:: Note
        :class: note

        This function has less memory consumption than :class:`torch.nn.functional.conv3d` when training on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """
    if spike.get_device() < 0:
        return F.conv3d(spike, weight, bias, stride, padding, dilation, groups)
    else:
        return spikeConvolution.apply(spike, weight, bias, stride, padding, dilation, groups)


class SpikeLinear(nn.Linear):
    """
    .. _SpikeLinear-cn:

    .. _SpikeLinear-en:

    A special case of :class:`torch.nn.Linear` whose inputs are spikes.

    Construction and parameters are inherited unchanged from :class:`torch.nn.Linear`;
    only :meth:`forward` differs, delegating to ``spike_linear``.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Linear` when running on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def forward(self, spike: Tensor) -> Tensor:
        weight, bias = self.weight, self.bias
        return spike_linear(spike, weight, bias)


class SpikeConv1d(nn.Conv1d):
    """
    .. _SpikeConv1d-cn:

    .. _SpikeConv1d-en:

    A special case of :class:`torch.nn.Conv1d` whose inputs are spikes.

    Construction and parameters are inherited unchanged from :class:`torch.nn.Conv1d`;
    only ``_conv_forward`` is overridden to call ``spike_conv1d``.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv1d` when running on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode == 'zeros':
            return spike_conv1d(spike, weight, bias, self.stride,
                                self.padding, self.dilation, self.groups)
        # Any padding_mode other than 'zeros' is applied explicitly with F.pad,
        # so the convolution itself then runs with zero padding.
        padded = F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode)
        return spike_conv1d(padded, weight, bias, self.stride,
                            _single(0), self.dilation, self.groups)


class SpikeConv2d(nn.Conv2d):
    """
    .. _SpikeConv2d-cn:

    .. _SpikeConv2d-en:

    A special case of :class:`torch.nn.Conv2d` whose inputs are spikes.

    Construction and parameters are inherited unchanged from :class:`torch.nn.Conv2d`;
    only ``_conv_forward`` is overridden to call ``spike_conv2d``.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv2d` when running on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode == 'zeros':
            return spike_conv2d(spike, weight, bias, self.stride,
                                self.padding, self.dilation, self.groups)
        # Any padding_mode other than 'zeros' is applied explicitly with F.pad,
        # so the convolution itself then runs with zero padding.
        padded = F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode)
        return spike_conv2d(padded, weight, bias, self.stride,
                            _pair(0), self.dilation, self.groups)


class SpikeConv3d(nn.Conv3d):
    """
    .. _SpikeConv3d-cn:

    .. _SpikeConv3d-en:

    A special case of :class:`torch.nn.Conv3d` whose inputs are spikes.

    Construction and parameters are inherited unchanged from :class:`torch.nn.Conv3d`;
    only ``_conv_forward`` is overridden to call ``spike_conv3d``.

    .. admonition:: Note
        :class: note

        This module has less memory consumption than :class:`torch.nn.Conv3d` when running on CUDA devices.

    .. admonition:: Warning
        :class: warning

        Any element in `spike` must be 0 or 1.
    """

    def _conv_forward(self, spike: Tensor, weight: Tensor, bias: Optional[Tensor]):
        if self.padding_mode == "zeros":
            return spike_conv3d(spike, weight, bias, self.stride,
                                self.padding, self.dilation, self.groups)
        # Any padding_mode other than "zeros" is applied explicitly with F.pad,
        # so the convolution itself then runs with zero padding.
        padded = F.pad(spike, self._reversed_padding_repeated_twice, mode=self.padding_mode)
        return spike_conv3d(padded, weight, bias, self.stride,
                            _triple(0), self.dilation, self.groups)
