import torch
from torch import nn


class View(nn.Module):
    """Reshape the input to a fixed target shape using ``Tensor.view``.

    Args:
        shape: sequence of target dimensions; it is unpacked into
            ``x.view(...)`` on every call, so one entry may be ``-1`` to let
            PyTorch infer that dimension.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape  # kept exactly as given; unpacked lazily in forward

    def forward(self, x):
        # ``view`` never copies: it raises if x's memory layout cannot be
        # reinterpreted with the requested shape.
        dims = tuple(self.shape)
        return x.view(*dims)


class SepConvReLUBN(nn.Sequential):
    """ReLU -> depthwise conv -> pointwise 1x1 conv -> BatchNorm, as a Sequential.

    The depthwise stage uses ``groups=C_in`` (one filter per input channel) and
    carries the kernel size, padding, dilation and stride. The pointwise stage
    mixes channels ``C_in -> C_out``. Neither convolution has a bias, since the
    BatchNorm that follows would make it redundant.

    Args:
        C_in: input channel count.
        C_out: output channel count.
        kernel_size: depthwise kernel size.
        padding: depthwise padding.
        dilation: depthwise dilation.
        affine: forwarded to ``nn.BatchNorm2d``.
        stride: depthwise stride.
    """

    def __init__(
        self, C_in, C_out, kernel_size=3, padding=1, dilation=1, affine=True, stride=1
    ):
        # Build in the same order as the Sequential indices (0..3) so parameter
        # initialisation consumes the RNG in a fixed order.
        activation = nn.ReLU(inplace=False)
        depthwise = nn.Conv2d(
            C_in,
            C_in,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=C_in,
            bias=False,
        )
        pointwise = nn.Conv2d(C_in, C_out, 1, bias=False)
        norm = nn.BatchNorm2d(C_out, affine=affine)
        super().__init__(activation, depthwise, pointwise, norm)


class ConvReLUBN(nn.Sequential):
    """ReLU -> Conv2d (no bias) -> BatchNorm2d, as a Sequential.

    Args:
        C_in: input channel count.
        C_out: output channel count.
        kernel_size: convolution kernel size.
        padding: convolution padding.
        dilation: convolution dilation.
        affine: forwarded to ``nn.BatchNorm2d``.
        stride: convolution stride.
        track_running_stats: forwarded to ``nn.BatchNorm2d``.
    """

    def __init__(
        self,
        C_in,
        C_out,
        kernel_size=3,
        padding=1,
        dilation=1,
        affine=True,
        stride=1,
        track_running_stats=True,
    ):
        # The convolution bias is omitted because BatchNorm immediately follows.
        conv = nn.Conv2d(
            C_in,
            C_out,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        norm = nn.BatchNorm2d(
            C_out, affine=affine, track_running_stats=track_running_stats
        )
        super().__init__(nn.ReLU(inplace=False), conv, norm)


class FactorizedReduce(nn.Module):
    """Halve spatial resolution with two offset, strided 1x1 convolutions.

    The input goes through ReLU. Two stride-2 1x1 convolutions then sample it on
    grids offset by one pixel in each spatial dimension. Their outputs are
    concatenated along dim 1 and batch-normalised. Both branches produce
    ``ceil(H / 2) x ceil(W / 2)`` maps, so odd spatial sizes are supported.

    Args:
        C_in: input channel count.
        C_out: output channel count. May be odd: ``conv_1`` produces
            ``C_out // 2`` channels and ``conv_2`` the remaining
            ``C_out - C_out // 2``.
        affine: forwarded to ``nn.BatchNorm2d``.
    """

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU(inplace=False)
        c_first = C_out // 2
        # NOTE: the convs keep their default bias (redundant before BatchNorm)
        # so existing state dicts still load.
        self.conv_1 = nn.Conv2d(C_in, c_first, 1, stride=2, padding=0)
        # The remainder goes to the second branch so the concat has exactly
        # C_out channels even when C_out is odd.
        self.conv_2 = nn.Conv2d(C_in, C_out - c_first, 1, stride=2, padding=0)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        # Zero-pad one row/column at the bottom/right before dropping the first
        # row/column. The shifted branch then has the same output size as
        # conv_1 when H or W is odd. For even sizes the padded row/column is
        # never sampled by the stride-2 conv, so the result is unchanged.
        shifted = nn.functional.pad(x, (0, 1, 0, 1))[:, :, 1:, 1:]
        x = torch.cat([self.conv_1(x), self.conv_2(shifted)], dim=1)
        return self.bn(x)


class Zero(nn.Module):
    """Produce an all-zero tensor, optionally spatially subsampled.

    With ``stride == 1`` the output has the same shape as the input. Otherwise
    dims 2 and 3 are subsampled with step ``stride`` before zeroing.

    Args:
        stride: subsampling step for dims 2 and 3.
    """

    def __init__(self, stride=1):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        step = self.stride
        picked = x if step == 1 else x[:, :, ::step, ::step]
        # Multiplying by 0.0 keeps dtype/device/autograd linkage; note that
        # NaN/inf entries become NaN rather than 0.
        return picked.mul(0.0)


class Identity(nn.Module):
    """Pass-through module whose ``forward`` returns its argument unchanged."""

    def forward(self, x):
        # No copy is made: the very same object is handed back.
        passthrough = x
        return passthrough


class ResNetBasicblock(nn.Module):
    """Two-convolution residual block built from ``ConvReLUBN`` layers.

    The main path is ``conv_a`` (3x3, given stride, ``inplanes -> planes``)
    followed by ``conv_b`` (3x3, stride 1, ``planes -> planes``). The shortcut
    is chosen so its output matches the main path's shape:

    * ``stride == 2``: 2x2 average pooling (``ceil_mode=True``, so odd inputs
      give ``ceil(H / 2)`` like ``conv_a``) then a bias-free 1x1 conv
      ``inplanes -> planes``.
    * ``stride == 1`` and ``inplanes != planes``: a bias-free 1x1 conv.
    * ``stride == 1`` and ``inplanes == planes``: identity
      (``self.downsample is None``).

    Args:
        inplanes: input channel count.
        planes: output channel count.
        stride: 1 or 2; stride of ``conv_a``.
        affine: forwarded to the BatchNorm layers of both convs.
        track_running_stats: forwarded to the BatchNorm layers of both convs.
    """

    def __init__(self, inplanes, planes, stride, affine=True, track_running_stats=True):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, "invalid stride {:}".format(stride)
        self.conv_a = ConvReLUBN(
            inplanes,
            planes,
            kernel_size=3,
            padding=1,
            dilation=1,
            affine=affine,
            stride=stride,
            track_running_stats=track_running_stats,
        )
        self.conv_b = ConvReLUBN(
            planes,
            planes,
            kernel_size=3,
            padding=1,
            dilation=1,
            affine=affine,
            stride=1,
            track_running_stats=track_running_stats,
        )
        if stride == 2:
            # Same modules/indices as before, so stride-2 state dicts still
            # load. ceil_mode only differs from floor mode for odd inputs.
            self.downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False),
            )
        elif inplanes != planes:
            # Channel projection only; no spatial change at stride 1.
            self.downsample = nn.Conv2d(
                inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False
            )
        else:
            self.downsample = None
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def forward(self, inputs):
        basicblock = self.conv_b(self.conv_a(inputs))
        residual = inputs if self.downsample is None else self.downsample(inputs)
        return residual + basicblock
