import torch.nn as nn
from models.utils.utils import make_conv_block, get_activation_function


class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet "basic" block).

    Main path: ``conv1`` (3x3, carries ``stride``) followed by ``conv2``
    (3x3). Its output is added to the skip term -- ``x`` itself, or
    ``downsample(x)`` when a downsample module is supplied -- and the sum is
    passed through ``last``.

    Args:
        inplanes: channels of the incoming tensor.
        planes: channels produced by ``conv1``; ``conv2`` and the block output
            have ``planes * expansion`` channels.
        activation_generator: activation factory forwarded to every
            ``make_conv_block`` call.
        oper_order: mapping providing the ``'full'``, ``'front2'`` and
            ``'end1'`` operation-order specs for ``make_conv_block``
            (presumably strings such as ``'cb'`` / ``'a'`` -- confirm in
            ``models.utils.utils``).
        stride: stride of ``conv1``.
        downsample: optional module applied to ``x`` so the skip term matches
            the main path's shape; ``None`` means the skip is ``x``.
        depth: not used by this block; accepted for signature compatibility.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, activation_generator, oper_order, stride=1, downsample=None, depth=20):
        super(BasicBlock, self).__init__()
        # NOTE(review): not referenced by forward(); retained so any external
        # code reading this attribute keeps working.
        self.relu_generator = get_activation_function('relu')

        out_planes = planes * BasicBlock.expansion

        self.conv1 = make_conv_block(inplanes, planes, activation_generator,
                                     kernel_size=3, stride=stride, padding=1,
                                     oper_order=oper_order['full'])
        self.conv2 = make_conv_block(planes, out_planes, activation_generator,
                                     kernel_size=3, stride=1, padding=1,
                                     oper_order=oper_order['front2'])
        # ``last`` runs on the residual sum, whose width is conv2's output
        # width (``planes * expansion``), not ``planes``.
        self.last = make_conv_block(out_planes, out_planes, activation_generator,
                                    oper_order=oper_order['end1'])

        self.downsample = downsample

    def forward(self, x):
        """Return ``last(conv2(conv1(x)) + skip)``.

        ``skip`` is ``x`` when no downsample module was given, otherwise
        ``downsample(x)``.
        """
        residual = x

        out = self.conv1(x)
        out = self.conv2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out = out + residual
        out = self.last(out)

        return out


class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a skip.

    The main path maps ``inplanes`` to ``planes`` channels with a 1x1 stage,
    applies a 3x3 stage carrying the block's ``stride``, and widens to
    ``planes * expansion`` channels with a final 1x1 stage. That result is
    summed with the skip term -- ``x`` itself, or ``downsample(x)`` when a
    downsample module is supplied -- and the sum goes through ``last``.

    Args:
        inplanes: channels of the incoming tensor.
        planes: bottleneck width; the block emits ``planes * expansion``
            channels.
        activation_generator: activation factory handed to every
            ``make_conv_block`` call.
        oper_order: mapping providing the ``'full'``, ``'front2'`` and
            ``'end1'`` operation-order specs for ``make_conv_block``.
        stride: stride of the 3x3 stage.
        downsample: optional module applied to ``x`` for the skip term;
            ``None`` means the skip is ``x``.
    """

    expansion = 4  # output width multiplier relative to ``planes``

    def __init__(self, inplanes, planes, activation_generator, oper_order, stride=1, downsample=None):
        super().__init__()
        wide = planes * Bottleneck.expansion
        full_order = oper_order['full']

        self.conv1 = make_conv_block(channels_in=inplanes, channels_out=planes,
                                     activation_generator=activation_generator,
                                     oper_order=full_order,
                                     kernel_size=1, stride=1, padding=0, bias=False)
        self.conv2 = make_conv_block(channels_in=planes, channels_out=planes,
                                     activation_generator=activation_generator,
                                     oper_order=full_order,
                                     kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv3 = make_conv_block(channels_in=planes, channels_out=wide,
                                     activation_generator=activation_generator,
                                     oper_order=oper_order['front2'],
                                     kernel_size=1, stride=1, padding=0, bias=False)
        # Post-addition stage: no geometry kwargs are passed, so whatever
        # defaults make_conv_block defines for them apply here.
        self.last = make_conv_block(channels_in=wide, channels_out=wide,
                                    activation_generator=activation_generator,
                                    oper_order=oper_order['end1'])

        self.downsample = downsample

    def forward(self, x):
        """Run the three conv stages, add the skip term, then apply ``last``."""
        branch = x
        for stage in (self.conv1, self.conv2, self.conv3):
            branch = stage(branch)

        skip = x if self.downsample is None else self.downsample(x)
        return self.last(branch + skip)


class PreActBlock(nn.Module):
    """Pre-activation two-convolution residual block.

    ``prelayer`` is applied to the input first. Its output feeds the main path
    (``block1`` 3x3 carrying ``stride``, then ``block2`` 3x3) and, when a
    ``shortcut`` module is supplied, the shortcut. Without a shortcut the raw
    input ``x`` is the skip term. Nothing is applied after the addition.

    Args:
        inplanes: channels of the incoming tensor (``prelayer`` keeps this
            width).
        planes: channels produced by ``block1`` and ``block2``.
        activation_generator: activation factory forwarded to every
            ``make_conv_block`` call.
        oper_order: mapping providing the ``'end2'``, ``'full'`` and
            ``'front1'`` operation-order specs for ``make_conv_block``.
        stride: stride of ``block1``.
        shortcut: optional module applied to ``prelayer(x)`` to form the skip
            term; ``None`` means the skip is ``x``.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 1

    def __init__(self, inplanes, planes, activation_generator, oper_order, stride=1, shortcut=None):
        super(PreActBlock, self).__init__()

        self.shortcut = shortcut

        self.prelayer = make_conv_block(channels_in=inplanes, channels_out=inplanes,
                                        activation_generator=activation_generator,
                                        oper_order=oper_order['end2'])
        self.block1 = make_conv_block(channels_in=inplanes, channels_out=planes,
                                      activation_generator=activation_generator,
                                      kernel_size=3, stride=stride, padding=1,
                                      oper_order=oper_order['full'])
        # Final conv before the residual addition. It receives the activation
        # factory like every other make_conv_block call, and its order comes
        # from the 'front1' entry (as PreActBottleneck's last conv does)
        # instead of a hard-coded 'c' that ignored the configured oper_order.
        self.block2 = make_conv_block(channels_in=planes, channels_out=planes,
                                      activation_generator=activation_generator,
                                      kernel_size=3, stride=1, padding=1,
                                      oper_order=oper_order['front1'])

    def forward(self, x):
        """Return ``block2(block1(prelayer(x))) + skip``.

        ``skip`` is ``x`` when no shortcut module was given, otherwise
        ``shortcut(prelayer(x))``.
        """
        preact = self.prelayer(x)

        out = self.block1(preact)
        out = self.block2(out)

        shortcut = x if self.shortcut is None else self.shortcut(preact)

        return out + shortcut


class PreActBottleneck(nn.Module):
    """Pre-activation bottleneck: shared ``prelayer`` then 1x1 -> 3x3 -> 1x1.

    ``prelayer`` is applied to the input once. Its output feeds the main path
    and, if a ``shortcut`` module is provided, the shortcut. Otherwise the raw
    input ``x`` is the skip term. The block returns main path + skip, with
    nothing applied after the addition.

    Args:
        inplanes: channels of the incoming tensor (``prelayer`` keeps this
            width).
        planes: bottleneck width; output has ``planes * expansion`` channels.
        activation_generator: activation factory forwarded to every
            ``make_conv_block`` call.
        oper_order: mapping providing the ``'end2'``, ``'full'`` and
            ``'front1'`` operation-order specs for ``make_conv_block``.
        stride: stride of the 3x3 stage (``block2``).
        shortcut: optional module applied to ``prelayer(x)`` to form the skip
            term; ``None`` means the skip is ``x``.
    """

    expansion = 4  # output width multiplier relative to ``planes``

    def __init__(self, inplanes, planes, activation_generator, oper_order, stride=1, shortcut=None):
        super().__init__()

        def conv(c_in, c_out, order, **geometry):
            # All stages share the same activation factory; only widths,
            # order spec and geometry differ.
            return make_conv_block(channels_in=c_in, channels_out=c_out,
                                   activation_generator=activation_generator,
                                   oper_order=order, **geometry)

        self.shortcut = shortcut
        self.prelayer = conv(inplanes, inplanes, oper_order['end2'])
        self.block1 = conv(inplanes, planes, oper_order['full'],
                           kernel_size=1, stride=1, padding=0)
        self.block2 = conv(planes, planes, oper_order['full'],
                           kernel_size=3, stride=stride, padding=1)
        self.block3 = conv(planes, planes * self.expansion, oper_order['front1'],
                           kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return ``block3(block2(block1(prelayer(x)))) + skip``."""
        preact = self.prelayer(x)

        # The skip term is formed before the main path runs.
        if self.shortcut is None:
            skip = x
        else:
            skip = self.shortcut(preact)

        main = self.block3(self.block2(self.block1(preact)))
        return main + skip