import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed
from torch import Tensor
from torch.nn.modules.pooling import _single, _pair, _triple

from lbmqt.conf import config
from lbmqt.nn.ops import linear, conv1d, conv2d, conv3d, conv_transpose1d, conv_transpose2d, conv_transpose3d


class QLinear(nn.Linear):
    """`nn.Linear` variant whose forward pass goes through the project `linear` op.

    The layer does not keep the parameters created by `nn.Linear.__init__`;
    it re-binds `weight` and `bias` to the very same objects held by
    `module`, so both layers share parameters.

    `forward` reads `self.scheme` and `self.optimizer`, which are not assigned
    anywhere in this class and must be attached to the instance by the caller.

    Args:
        input_features: Forwarded to `nn.Linear` as `in_features`.
        output_features: Forwarded to `nn.Linear` as `out_features`.
        bias: Forwarded to `nn.Linear`.
        module: Layer whose `weight` and `bias` attributes are adopted.
            Despite the `None` default, a module must be supplied
            (`None` fails with `AttributeError`).
    """

    def __init__(self, input_features, output_features, bias=True, module=None):
        super().__init__(input_features, output_features, bias)
        shared_names = ('weight', 'bias')
        # Unregister the freshly created parameters first ...
        for attr in shared_names:
            delattr(self, attr)
        # ... then adopt the ones owned by `module`.
        for attr in shared_names:
            setattr(self, attr, getattr(module, attr))

    def forward(self, input):
        """Apply the `linear` op; the optimizer is only passed while `config.training` is set."""
        training = config.training
        operands = (input, self.weight, self.bias, self.scheme)
        optimizer = self.optimizer if training else None
        return linear.apply(*operands, optimizer)


class QConv1d(nn.Conv1d):
    """`nn.Conv1d` variant whose forward pass goes through the project `conv1d` op.

    After the regular `nn.Conv1d` construction, `weight` and `bias` are
    re-bound to the very same objects held by `module`, so both layers share
    parameters.

    `forward` reads `self.scheme` and `self.optimizer`, which are not assigned
    anywhere in this class and must be attached to the instance by the caller.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: Forwarded unchanged to `nn.Conv1d`.
        module: Layer whose `weight` and `bias` attributes are adopted.
            Despite the `None` default, a module must be supplied
            (`None` fails with `AttributeError`).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', module=None):
        super(QConv1d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias, padding_mode)
        # Drop the parameters created by nn.Conv1d and share the ones of `module`.
        delattr(self, 'weight')
        delattr(self, 'bias')
        setattr(self, 'weight', module.weight)
        setattr(self, 'bias', module.bias)

    def forward(self, input):
        """Run the convolution through `conv1d.apply`.

        Non-`'zeros'` padding modes are applied explicitly with `F.pad` and
        the op is then called with zero padding. This happens in training
        *and* evaluation, so both modes pad the input identically (the
        evaluation path used to ignore `padding_mode`).

        Args:
            input: Input passed to the op (padded first if required).

        Returns:
            Whatever `conv1d.apply` returns.
        """
        padding = self.padding
        if self.padding_mode != 'zeros':
            input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            # Padding has already been materialised above.
            padding = _single(0)

        # The optimizer is only handed to the op while training.
        optimizer = self.optimizer if config.training else None
        return conv1d.apply(input, self.weight, self.bias, self.stride,
                            padding, self.dilation, self.groups, self.scheme, optimizer)


class QConv2d(nn.Conv2d):
    """`nn.Conv2d` variant whose forward pass goes through the project `conv2d` op.

    After the regular `nn.Conv2d` construction, `weight` and `bias` are
    re-bound to the very same objects held by `module`, so both layers share
    parameters.

    `forward` reads `self.scheme` and `self.optimizer`, which are not assigned
    anywhere in this class and must be attached to the instance by the caller.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: Forwarded unchanged to `nn.Conv2d`.
        module: Layer whose `weight` and `bias` attributes are adopted.
            Despite the `None` default, a module must be supplied
            (`None` fails with `AttributeError`).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', module=None):
        super(QConv2d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias, padding_mode)
        # Drop the parameters created by nn.Conv2d and share the ones of `module`.
        delattr(self, 'weight')
        delattr(self, 'bias')
        setattr(self, 'weight', module.weight)
        setattr(self, 'bias', module.bias)

    def forward(self, input):
        """Run the convolution through `conv2d.apply`.

        Non-`'zeros'` padding modes are applied explicitly with `F.pad` and
        the op is then called with zero padding. This happens in training
        *and* evaluation, so both modes pad the input identically (the
        evaluation path used to ignore `padding_mode`).

        Args:
            input: Input passed to the op (padded first if required).

        Returns:
            Whatever `conv2d.apply` returns.
        """
        padding = self.padding
        if self.padding_mode != 'zeros':
            input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            # Padding has already been materialised above.
            padding = _pair(0)

        # The optimizer is only handed to the op while training.
        optimizer = self.optimizer if config.training else None
        return conv2d.apply(input, self.weight, self.bias, self.stride,
                            padding, self.dilation, self.groups, self.scheme, optimizer)


class QConv3d(nn.Conv3d):
    """`nn.Conv3d` variant whose forward pass goes through the project `conv3d` op.

    After the regular `nn.Conv3d` construction, `weight` and `bias` are
    re-bound to the very same objects held by `module`, so both layers share
    parameters.

    `forward` reads `self.scheme` and `self.optimizer`, which are not assigned
    anywhere in this class and must be attached to the instance by the caller.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: Forwarded unchanged to `nn.Conv3d`.
        module: Layer whose `weight` and `bias` attributes are adopted.
            Despite the `None` default, a module must be supplied
            (`None` fails with `AttributeError`).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros', module=None):
        super(QConv3d, self).__init__(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias, padding_mode)
        # Drop the parameters created by nn.Conv3d and share the ones of `module`.
        delattr(self, 'weight')
        delattr(self, 'bias')
        setattr(self, 'weight', module.weight)
        setattr(self, 'bias', module.bias)

    def forward(self, input):
        """Run the convolution through `conv3d.apply`.

        Non-`'zeros'` padding modes are applied explicitly with `F.pad` and
        the op is then called with zero padding. This happens in training
        *and* evaluation, so both modes pad the input identically (the
        evaluation path used to ignore `padding_mode`).

        Args:
            input: Input passed to the op (padded first if required).

        Returns:
            Whatever `conv3d.apply` returns.
        """
        padding = self.padding
        if self.padding_mode != 'zeros':
            input = F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode)
            # Padding has already been materialised above.
            padding = _triple(0)

        # The optimizer is only handed to the op while training.
        optimizer = self.optimizer if config.training else None
        return conv3d.apply(input, self.weight, self.bias, self.stride,
                            padding, self.dilation, self.groups, self.scheme, optimizer)


class QConvTranspose1d(nn.ConvTranspose1d):
    """`nn.ConvTranspose1d` variant whose forward pass uses the project `conv_transpose1d` op.

    After the regular `nn.ConvTranspose1d` construction, `weight` and `bias`
    are re-bound to the very same objects held by `module`, so both layers
    share parameters.

    `forward` reads `self.scheme` and `self.optimizer`, which are not assigned
    anywhere in this class and must be attached to the instance by the caller.

    Args:
        in_channels, out_channels, kernel_size, stride, padding,
        output_padding, groups, bias, dilation, padding_mode: Forwarded
            unchanged to `nn.ConvTranspose1d`.
        module: Layer whose `weight` and `bias` attributes are adopted.
            Despite the `None` default, a module must be supplied
            (`None` fails with `AttributeError`).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, output_padding=0, groups=1,
                 bias=True, dilation=1, padding_mode='zeros', module=None):
        super(QConvTranspose1d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                               padding, output_padding, groups, bias, dilation, padding_mode)
        # Drop the parameters created by nn.ConvTranspose1d and share the ones of `module`.
        delattr(self, 'weight')
        delattr(self, 'bias')
        setattr(self, 'weight', module.weight)
        setattr(self, 'bias', module.bias)

    def forward(self, input, output_size=None):
        """Run the transposed convolution through `conv_transpose1d.apply`.

        The padding-mode check and the `output_padding` computation are done
        before the training/evaluation split. Previously `output_padding` was
        only assigned in the training branch, so every evaluation call failed
        with `UnboundLocalError`.

        Args:
            input: Input passed to the op.
            output_size: Optional requested output size, forwarded to
                `self._output_padding`.

        Returns:
            Whatever `conv_transpose1d.apply` returns.

        Raises:
            ValueError: If `padding_mode` is not `'zeros'`.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose1d')

        # NOTE(review): these positional arguments follow the `_output_padding`
        # signature without `num_spatial_dims`; newer PyTorch releases appear to
        # insert that argument before `dilation` - confirm against the pinned torch.
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore

        # The optimizer is only handed to the op while training.
        optimizer = self.optimizer if config.training else None
        return conv_transpose1d.apply(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation, self.scheme, optimizer)


class QConvTranspose2d(nn.ConvTranspose2d):
    """`nn.ConvTranspose2d` variant whose forward pass uses the project `conv_transpose2d` op.

    After the regular `nn.ConvTranspose2d` construction, `weight` and `bias`
    are re-bound to the very same objects held by `module`, so both layers
    share parameters.

    `forward` reads `self.scheme` and `self.optimizer`, which are not assigned
    anywhere in this class and must be attached to the instance by the caller.

    Args:
        in_channels, out_channels, kernel_size, stride, padding,
        output_padding, groups, bias, dilation, padding_mode: Forwarded
            unchanged to `nn.ConvTranspose2d`.
        module: Layer whose `weight` and `bias` attributes are adopted.
            Despite the `None` default, a module must be supplied
            (`None` fails with `AttributeError`).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, output_padding=0, groups=1,
                 bias=True, dilation=1, padding_mode='zeros', module=None):
        super(QConvTranspose2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                               padding, output_padding, groups, bias, dilation, padding_mode)
        # Drop the parameters created by nn.ConvTranspose2d and share the ones of `module`.
        delattr(self, 'weight')
        delattr(self, 'bias')
        setattr(self, 'weight', module.weight)
        setattr(self, 'bias', module.bias)

    def forward(self, input, output_size=None):
        """Run the transposed convolution through `conv_transpose2d.apply`.

        The padding-mode check and the `output_padding` computation are done
        before the training/evaluation split. Previously `output_padding` was
        only assigned in the training branch, so every evaluation call failed
        with `UnboundLocalError`.

        Args:
            input: Input passed to the op.
            output_size: Optional requested output size, forwarded to
                `self._output_padding`.

        Returns:
            Whatever `conv_transpose2d.apply` returns.

        Raises:
            ValueError: If `padding_mode` is not `'zeros'`.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose2d')

        # NOTE(review): these positional arguments follow the `_output_padding`
        # signature without `num_spatial_dims`; newer PyTorch releases appear to
        # insert that argument before `dilation` - confirm against the pinned torch.
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore

        # The optimizer is only handed to the op while training.
        optimizer = self.optimizer if config.training else None
        return conv_transpose2d.apply(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation, self.scheme, optimizer)


class QConvTranspose3d(nn.ConvTranspose3d):
    """`nn.ConvTranspose3d` variant whose forward pass uses the project `conv_transpose3d` op.

    After the regular `nn.ConvTranspose3d` construction, `weight` and `bias`
    are re-bound to the very same objects held by `module`, so both layers
    share parameters.

    `forward` reads `self.scheme` and `self.optimizer`, which are not assigned
    anywhere in this class and must be attached to the instance by the caller.

    Args:
        in_channels, out_channels, kernel_size, stride, padding,
        output_padding, groups, bias, dilation, padding_mode: Forwarded
            unchanged to `nn.ConvTranspose3d`.
        module: Layer whose `weight` and `bias` attributes are adopted.
            Despite the `None` default, a module must be supplied
            (`None` fails with `AttributeError`).
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, output_padding=0, groups=1,
                 bias=True, dilation=1, padding_mode='zeros', module=None):
        super(QConvTranspose3d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                               padding, output_padding, groups, bias, dilation, padding_mode)
        # Drop the parameters created by nn.ConvTranspose3d and share the ones of `module`.
        delattr(self, 'weight')
        delattr(self, 'bias')
        setattr(self, 'weight', module.weight)
        setattr(self, 'bias', module.bias)

    def forward(self, input, output_size=None):
        """Run the transposed convolution through `conv_transpose3d.apply`.

        The padding-mode check and the `output_padding` computation are done
        before the training/evaluation split. Previously `output_padding` was
        only assigned in the training branch, so every evaluation call failed
        with `UnboundLocalError`.

        Args:
            input: Input passed to the op.
            output_size: Optional requested output size, forwarded to
                `self._output_padding`.

        Returns:
            Whatever `conv_transpose3d.apply` returns.

        Raises:
            ValueError: If `padding_mode` is not `'zeros'`.
        """
        if self.padding_mode != 'zeros':
            raise ValueError('Only `zeros` padding mode is supported for ConvTranspose3d')

        # NOTE(review): these positional arguments follow the `_output_padding`
        # signature without `num_spatial_dims`; newer PyTorch releases appear to
        # insert that argument before `dilation` - confirm against the pinned torch.
        output_padding = self._output_padding(
            input, output_size, self.stride, self.padding, self.kernel_size, self.dilation)  # type: ignore

        # The optimizer is only handed to the op while training.
        optimizer = self.optimizer if config.training else None
        return conv_transpose3d.apply(
            input, self.weight, self.bias, self.stride, self.padding,
            output_padding, self.groups, self.dilation, self.scheme, optimizer)

