import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch
import torch.nn.functional as F
from torch.nn.modules.conv import _ConvNd
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _single, _pair, _triple, _reverse_repeat_tuple
# 2x2 / stride-2 down-sampling layer used by the ImageNet feature extractors.
# A single instance is appended several times to their ModuleLists; this is
# harmless because AvgPool2d holds no parameters or buffers.
# Alternative: POOL = nn.MaxPool2d(kernel_size=2, stride=2)
POOL = nn.AvgPool2d(kernel_size=2, stride=2)
class Linear(nn.Module):
    r"""Linear layer whose weight matrix is squashed through a sigmoid.

    Computes :math:`y = x\,\sigma(A)^T + b`, where :math:`A` is the raw
    learnable ``weight`` and :math:`\sigma` is the logistic sigmoid. Every
    effective weight therefore lies in :math:`(0, 1)`. Apart from that
    squashing, the layer mirrors :class:`torch.nn.Linear`.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(N, *, H_{in})` where :math:`*` means any number of
          additional dimensions and :math:`H_{in} = \text{in\_features}`
        - Output: :math:`(N, *, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the raw (pre-sigmoid) learnable weights of shape
            :math:`(\text{out\_features}, \text{in\_features})`. They are
            initialized with ``kaiming_uniform_(a=sqrt(5))``, i.e. from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` with
            :math:`k = \frac{1}{\text{in\_features}}`.
        bias: the learnable bias of shape :math:`(\text{out\_features})`, or
            ``None`` when ``bias=False``. It is initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` with the same :math:`k`
            (all zeros when ``in_features == 0``).

    Examples::

        >>> m = Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: torch.Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocated uninitialized; reset_parameters() fills the values below.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """(Re-)initialize ``weight`` and ``bias`` the same way ``nn.Linear`` does."""
        import math
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            # Guard against in_features == 0, which would divide by zero.
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply ``input @ sigmoid(weight).T + bias``."""
        return F.linear(input, torch.sigmoid(self.weight), self.bias)

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )

class Conv2d(_ConvNd):
    r"""2D convolution whose kernel is squashed through a sigmoid.

    Behaves like :class:`torch.nn.Conv2d`, except for one difference in
    ``forward``. The input is convolved with :math:`\sigma(\text{weight})`
    (logistic sigmoid) instead of the raw ``weight``, so every effective
    kernel element lies in :math:`(0, 1)`:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1} \sigma\big(\text{weight}(C_{\text{out}_j}, k)\big) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D cross-correlation operator.

    Note that two defaults differ from ``nn.Conv2d``: ``kernel_size`` defaults
    to ``1`` and ``bias`` defaults to ``False``.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple, optional): Size of the convolving kernel. Default: 1
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``False``
        padding_mode (string, optional): ``'zeros'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

    Shape:
        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - Output: :math:`(N, C_{out}, H_{out}, W_{out})` where

          .. math::
              H_{out} = \left\lfloor\frac{H_{in}  + 2 \times \text{padding}[0] - \text{dilation}[0]
                        \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor

          .. math::
              W_{out} = \left\lfloor\frac{W_{in}  + 2 \times \text{padding}[1] - \text{dilation}[1]
                        \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor

    Attributes:
        weight (Tensor): the raw (pre-sigmoid) learnable weights of shape
            :math:`(\text{out\_channels}, \frac{\text{in\_channels}}{\text{groups}},
            \text{kernel\_size[0]}, \text{kernel\_size[1]})`. They are
            initialized by the inherited ``_ConvNd`` initialization.
        bias (Tensor or None): learnable bias of shape (out_channels), or
            ``None`` when ``bias=False``.

    Examples:

        >>> m = Conv2d(16, 33, 3, stride=2, padding=1)
        >>> output = m(torch.randn(20, 16, 50, 100))
        >>> output.shape
        torch.Size([20, 33, 25, 50])
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t = 1,
        stride: _size_2_t = 1,
        padding: _size_2_t = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = 'zeros',
    ):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Positional order required by _ConvNd: ..., dilation, transposed,
        # output_padding, groups, bias, padding_mode.
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode)

    def _conv_forward(self, input, weight):
        """Convolve ``input`` with the given (already transformed) ``weight``."""
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad, so the
            # convolution itself then runs without implicit padding.
            return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Convolve ``input`` with ``sigmoid(weight)``."""
        return self._conv_forward(input, torch.sigmoid(self.weight))
class cane_VGG(nn.Module):
    """Wrap a CANE feature extractor with an optional global classifier head.

    ``features`` must return a pair ``(feature_map, side_logits)``.

    - When ``cane`` is true, ``forward`` returns ``side_logits`` unchanged.
    - Otherwise ``feature_map`` is average-pooled to 1x1, flattened, and fed
      to a bias-free ``nn.Linear(512, num_classes)`` classifier.

    Args:
        features: module whose forward returns ``(feature_map, side_logits)``.
        num_classes: output size of the classifier head.
        init_weights: if true, re-initialize every ``nn.Conv2d`` (Kaiming
            normal, fan_out, zero bias) and ``nn.BatchNorm2d`` (weight 1,
            bias 0) in the module tree.
        cane: select the side-logit output instead of the classifier head.
        alpha: stored as ``self.alpha``; not read inside this class.
    """

    def __init__(self, features, num_classes=1000, init_weights=True, cane=False, alpha=1):
        super().__init__()
        self.cane = cane
        self.alpha = alpha
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # The head assumes the feature map carries exactly 512 channels.
        self.classifier = nn.Linear(512, num_classes, bias=False)
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        feature_map, side_logits = self.features(x)
        if self.cane:
            return side_logits
        pooled = self.avgpool(feature_map)
        flat = pooled.view(pooled.size(0), -1)
        return self.classifier(flat)

    def _initialize_weights(self):
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)
                continue
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.zeros_(module.bias)



class cane_conv(nn.Module):
    """Conv -> BatchNorm -> ReLU block that also emits side logits.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: convolution kernel size.
        padding: zero-padding on each spatial side. The default of 1 keeps the
            spatial size for the default 3x3 kernel at stride 1.
        num_classes: width of the side logits.

    ``forward`` returns ``(feature_map, side_logits)``. ``side_logits`` is
    this file's ``Linear`` (sigmoid-squashed weights, no bias) applied to the
    globally average-pooled feature map.
    """
    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, num_classes=10):
        super().__init__()
        # `padding` is passed by keyword: the 4th positional parameter of
        # nn.Conv2d is `stride`, so passing it positionally would use it as
        # the stride and leave the convolution unpadded.
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_dim)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.a = Linear(out_dim, num_classes, bias=False)

    def forward(self, x):
        x = self.relu(self.bn(self.conv(x)))
        # Global average pool -> (batch, out_dim) for the side classifier.
        h = self.pool(x).view(x.size(0), -1)
        y = self.a(h)
        return x, y
class cane_conv_no_bn(nn.Module):
    """Conv -> ReLU block (no BatchNorm) that also emits side logits.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        kernel_size: convolution kernel size.
        padding: zero-padding on each spatial side. The default of 1 keeps the
            spatial size for the default 3x3 kernel at stride 1.
        num_classes: width of the side logits.
        bias: whether the convolution learns an additive bias.

    ``forward`` returns ``(feature_map, side_logits)``. ``side_logits`` is
    this file's ``Linear`` (sigmoid-squashed weights, no bias) applied to the
    globally average-pooled feature map.
    """
    def __init__(self, in_dim, out_dim, kernel_size=3, padding=1, num_classes=10, bias=False):
        super().__init__()
        # `padding` is passed by keyword: the 4th positional parameter of
        # nn.Conv2d is `stride`, so passing it positionally would use it as
        # the stride and leave the convolution unpadded.
        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size, padding=padding, bias=bias)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.relu = nn.ReLU(inplace=True)
        self.a = Linear(out_dim, num_classes, bias=False)

    def forward(self, x):
        x = self.relu(self.conv(x))
        # Global average pool -> (batch, out_dim) for the side classifier.
        h = self.pool(x).view(x.size(0), -1)
        y = self.a(h)
        return x, y
class cane_feature_imagenet(nn.Module):
    """VGG-11 style stack of ``cane_conv`` blocks for ImageNet-sized inputs.

    Layer plan, with ``w = width`` and ``M`` = the module-level ``POOL``:
    ``w, M, 2w, M, 4w, 4w, M, 8w, 8w, M, 8w, 8w, M``.

    ``forward`` returns the final feature map and the sum of the side logits
    emitted by every ``cane_conv`` block.
    """
    def __init__(self, width=64, num_classes=1000):
        super().__init__()
        # Ints are channel multipliers of `width`; 'M' inserts the pooling layer.
        plan = [1, 'M', 2, 'M', 4, 4, 'M', 8, 8, 'M', 8, 8, 'M']
        layers = []
        channels = 3
        for entry in plan:
            if isinstance(entry, str):
                layers.append(POOL)
                continue
            out_channels = entry * width
            layers.append(cane_conv(channels, out_channels, num_classes=num_classes))
            channels = out_channels
        self.main = nn.ModuleList(layers)

    def forward(self, x):
        logits_sum = 0
        for layer in self.main:
            result = layer(x)
            if not isinstance(result, tuple):
                x = result
                continue
            x = result[0]
            logits_sum += result[1]
        return x, logits_sum

class cane_feature_imagenet_no_bn(nn.Module):
    """VGG-11 style stack of ``cane_conv_no_bn`` blocks for ImageNet-sized inputs.

    The layer plan matches ``cane_feature_imagenet``, with ``w = width`` and
    ``M`` = the module-level ``POOL``:
    ``w, M, 2w, M, 4w, 4w, M, 8w, 8w, M, 8w, 8w, M``.

    Args:
        width: base channel count ``w``.
        num_classes: width of every block's side logits.
        bias: forwarded to every ``cane_conv_no_bn`` so the convolutions
            learn a bias when true. This matches ``cane_feature_cifar_no_bn``;
            previously the argument was accepted but ignored.

    ``forward`` returns the final feature map and the sum of all side logits.
    """
    def __init__(self, width=64, num_classes=1000, bias=False):
        super().__init__()

        def block(c_in, c_out):
            return cane_conv_no_bn(c_in, c_out, num_classes=num_classes, bias=bias)

        self.main = nn.ModuleList([
            block(3, width), POOL,
            block(width, 2 * width), POOL,
            block(2 * width, 4 * width),
            block(4 * width, 4 * width), POOL,
            block(4 * width, 8 * width),
            block(8 * width, 8 * width), POOL,
            block(8 * width, 8 * width),
            block(8 * width, 8 * width), POOL,
        ])

    def forward(self, x):
        y = 0
        for layer in self.main:
            out = layer(x)
            if isinstance(out, tuple):
                # CANE blocks return (feature_map, side_logits).
                x = out[0]
                y += out[1]
            else:
                x = out
        return x, y
class cane_feature_cifar(nn.Module):
    """Stack of six ``cane_conv`` blocks with two 2x2 average pools, for CIFAR.

    Layer plan, with ``w = width`` and ``P`` = a fresh ``nn.AvgPool2d(2, 2)``:
    ``w, 2w, P, 4w, 4w, P, 8w, 8w``.

    ``forward`` returns the final feature map and the sum of the side logits
    emitted by every ``cane_conv`` block.
    """
    def __init__(self, width=64, num_classes=10):
        super().__init__()
        # Ints are channel multipliers of `width`; 'P' inserts an average pool.
        plan = [1, 2, 'P', 4, 4, 'P', 8, 8]
        layers = []
        channels = 3
        for entry in plan:
            if isinstance(entry, str):
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2))
                continue
            out_channels = entry * width
            layers.append(cane_conv(channels, out_channels, num_classes=num_classes))
            channels = out_channels
        self.main = nn.ModuleList(layers)

    def forward(self, x):
        accumulated = 0
        for module in self.main:
            produced = module(x)
            if not isinstance(produced, tuple):
                x = produced
                continue
            x = produced[0]
            accumulated += produced[1]
        return x, accumulated
class cane_feature_cifar_no_bn(nn.Module):
    """Stack of six ``cane_conv_no_bn`` blocks with two 2x2 average pools, for CIFAR.

    Layer plan, with ``w = width`` and ``P`` = a fresh ``nn.AvgPool2d(2, 2)``:
    ``w, 2w, P, 4w, 4w, P, 8w, 8w``. ``bias`` is forwarded to every block.

    ``forward`` returns the final feature map and the sum of all side logits.
    """
    def __init__(self, width=64, num_classes=10, bias=False):
        super().__init__()

        def conv(c_in, c_out):
            return cane_conv_no_bn(c_in, c_out, num_classes=num_classes, bias=bias)

        self.main = nn.ModuleList([
            conv(3, width),
            conv(width, 2 * width),
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv(2 * width, 4 * width),
            conv(4 * width, 4 * width),
            nn.AvgPool2d(kernel_size=2, stride=2),
            conv(4 * width, 8 * width),
            conv(8 * width, 8 * width),
        ])

    def forward(self, x):
        total = 0
        for stage in self.main:
            outcome = stage(x)
            if not isinstance(outcome, tuple):
                x = outcome
                continue
            x = outcome[0]
            total += outcome[1]
        return x, total
def canet11_imagenet(pretrained=False, in_channels=3, num_classes=1000, **kwargs):
    """Build a CANE network on the VGG-11 (configuration "A") conv layout, with BatchNorm.

    Args:
        pretrained (bool): If True, only sets ``init_weights=False``. No
            pretrained weights are loaded by this function.
        in_channels (int): Accepted for API compatibility but ignored; the
            first convolution always expects 3 input channels.
        num_classes (int): Number of classes for both the side logits and the
            classifier head.
        **kwargs: Forwarded to ``cane_VGG`` (e.g. ``cane``, ``alpha``,
            ``init_weights``).

    Returns:
        cane_VGG: model wrapping ``cane_feature_imagenet``.
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = cane_VGG(cane_feature_imagenet(num_classes=num_classes), num_classes=num_classes, **kwargs)
    return model
def canet11_imagenet_no_bn(pretrained=False, in_channels=3, num_classes=1000, **kwargs):
    """Build a CANE network on the VGG-11 (configuration "A") conv layout, without BatchNorm.

    Args:
        pretrained (bool): If True, only sets ``init_weights=False``. No
            pretrained weights are loaded by this function.
        in_channels (int): Accepted for API compatibility but ignored; the
            first convolution always expects 3 input channels.
        num_classes (int): Number of classes for both the side logits and the
            classifier head.
        **kwargs: Forwarded to ``cane_VGG`` (e.g. ``cane``, ``alpha``,
            ``init_weights``).

    Returns:
        cane_VGG: model wrapping ``cane_feature_imagenet_no_bn``.
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = cane_VGG(cane_feature_imagenet_no_bn(num_classes=num_classes), num_classes=num_classes, **kwargs)
    return model
def canet9_cifar(pretrained=False, in_channels=3, num_classes=10, **kwargs):
    """Build the CIFAR-sized CANE network (six conv+BatchNorm blocks, two 2x2 pools).

    Args:
        pretrained (bool): If True, only sets ``init_weights=False``. No
            pretrained weights are loaded by this function.
        in_channels (int): Accepted for API compatibility but ignored; the
            first convolution always expects 3 input channels.
        num_classes (int): Number of classes for both the side logits and the
            classifier head.
        **kwargs: Forwarded to ``cane_VGG`` (e.g. ``cane``, ``alpha``,
            ``init_weights``).

    Returns:
        cane_VGG: model wrapping ``cane_feature_cifar``.
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = cane_VGG(cane_feature_cifar(num_classes=num_classes), num_classes=num_classes, **kwargs)
    return model

def canet9_cifar_no_bn(pretrained=False, in_channels=3, num_classes=10, bias=False, **kwargs):
    """Build the CIFAR-sized CANE network without BatchNorm (six conv blocks, two 2x2 pools).

    Args:
        pretrained (bool): If True, only sets ``init_weights=False``. No
            pretrained weights are loaded by this function.
        in_channels (int): Accepted for API compatibility but ignored; the
            first convolution always expects 3 input channels.
        num_classes (int): Number of classes for both the side logits and the
            classifier head.
        bias (bool): Forwarded to ``cane_feature_cifar_no_bn``, which passes
            it to every convolution block.
        **kwargs: Forwarded to ``cane_VGG`` (e.g. ``cane``, ``alpha``,
            ``init_weights``).

    Returns:
        cane_VGG: model wrapping ``cane_feature_cifar_no_bn``.
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = cane_VGG(cane_feature_cifar_no_bn(num_classes=num_classes, bias=bias), num_classes=num_classes, **kwargs)
    return model
