import torch.nn as nn
import torch.nn.functional as F


class ResNet(nn.Module):
    """Residual network: a 3x3 conv stem, three residual stages and a linear head.

    The stem maps 3 input channels to 16 feature maps. The three stages
    produce 16, 32 and 64 channels, with the second and third stage built
    with stride 2. The result is average-pooled with an 8x8 window,
    flattened, and fed to a linear layer with 64 input features, so the
    pooled feature map must be 1x1 (presumably the network targets 32x32
    inputs -- confirm against the caller).

    Args:
        activation: Callable invoked as ``activation(**activation_params)``
            to create every activation layer.
        activation_params: Keyword arguments forwarded to ``activation``.
        rs: Accepted for interface compatibility; not used by this class.
        layers: Indexable holding the number of residual blocks for each of
            the three stages (``layers[0]``, ``layers[1]``, ``layers[2]``).
        num_classes: Number of output features of the final linear layer.
    """

    def __init__(self, activation, activation_params, rs, layers, num_classes):
        super().__init__()

        # Stem: kernel 3 / stride 1 / padding 1 keeps the spatial size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.BatchNorm2d(num_features=16)
        self.act1 = activation(**activation_params)

        # Residual stages, created in the same order in which they are applied.
        self.layers1 = self._make_layer(
            activation, activation_params,
            layer_count=layers[0], channels=16, channels_in=16, stride=1,
        )
        self.layers2 = self._make_layer(
            activation, activation_params,
            layer_count=layers[1], channels=32, channels_in=16, stride=2,
        )
        self.layers3 = self._make_layer(
            activation, activation_params,
            layer_count=layers[2], channels=64, channels_in=32, stride=2,
        )

        # Classification head.
        self.avgpool = nn.AvgPool2d(kernel_size=8)
        self.linear = nn.Linear(in_features=64, out_features=num_classes)

    def _make_layer(self, activation, activation_params, layer_count, channels, channels_in, stride):
        """Assemble one stage as an ``nn.Sequential`` of ``ResBlock`` modules.

        The first block maps ``channels_in`` to ``channels`` using ``stride``;
        the remaining ``layer_count - 1`` blocks keep ``channels`` and use the
        block's default stride. A stage always contains at least the first
        block, even when ``layer_count`` is smaller than 1.
        """
        stage = [ResBlock(activation, activation_params, channels, channels_in, stride)]
        stage.extend(
            ResBlock(activation, activation_params, channels)
            for _ in range(layer_count - 1)
        )
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the three stages, pooling and the linear head on ``x``."""
        stem = self.act1(self.norm1(self.conv1(x)))
        features = self.layers3(self.layers2(self.layers1(stem)))
        pooled = self.avgpool(features)
        # Collapse everything except the leading dimension before the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

class ResBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers added to a shortcut.

    The main branch is conv -> batch norm -> activation -> conv -> batch norm.
    The shortcut is added to its result and a second activation is applied.

    Args:
        activation: Callable invoked as ``activation(**activation_params)``
            to create each of the two activation layers.
        activation_params: Keyword arguments forwarded to ``activation``.
        num_filters: Number of output channels of both convolutions.
        channels_in: Number of input channels. A falsy value (``None`` or 0)
            means "same as ``num_filters``".
        stride: Stride of the first convolution.
    """

    def __init__(self, activation, activation_params, num_filters, channels_in=None, stride=1):
        super(ResBlock, self).__init__()

        if not channels_in:
            channels_in = num_filters

        # The shortcut has to match the main branch in both channel count and
        # spatial size. A plain identity is only valid when neither changes;
        # otherwise the input is zero-padded along channels and subsampled by
        # ``stride`` (no 1x1 convolution is involved). Checking the stride as
        # well avoids a shape mismatch in ``forward`` for blocks that keep
        # the channel count but downsample.
        if channels_in == num_filters and stride == 1:
            self.projection = None
        else:
            self.projection = IdentityPadding(num_filters, channels_in, stride)

        self.conv1 = nn.Conv2d(channels_in, num_filters, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(num_filters)
        self.act1 = activation(**activation_params)
        self.conv2 = nn.Conv2d(num_filters, num_filters, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_filters)
        self.act2 = activation(**activation_params)

    def forward(self, x):
        """Return ``act2(main_branch(x) + shortcut(x))``."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.act1(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Compare against None explicitly instead of relying on the
        # truthiness of a module object.
        if self.projection is not None:
            residual = self.projection(x)

        out += residual
        out = self.act2(out)
        return out


# Projection options for changing the number of filters in the residual (shortcut) connection.
# Option A from the ResNet paper: zero-pad the extra channels (no learned parameters).
class IdentityPadding(nn.Module):
    """Shortcut without learned parameters: pad channels with zeros, then subsample.

    Args:
        num_filters: Channel count the shortcut output should have.
        channels_in: Channel count of the incoming tensor.
        stride: Spatial subsampling step applied to the padded tensor.
    """

    def __init__(self, num_filters, channels_in, stride):
        super().__init__()
        # A max pool over a 1x1 window just picks every ``stride``-th position,
        # i.e. it is a strided identity mapping.
        self.identity = nn.MaxPool2d(kernel_size=1, stride=stride)
        self.num_zeros = num_filters - channels_in

    def forward(self, x):
        """Append ``num_zeros`` zero channels to ``x`` and subsample spatially."""
        # F.pad takes (before, after) pairs starting from the last dimension.
        # The last two dimensions are left untouched and zeros are appended
        # after the third-from-last one (the channel axis for an
        # (N, C, H, W) input -- the layout is assumed here, confirm with callers).
        untouched = (0, 0)
        append_channels = (0, self.num_zeros)
        padded = F.pad(x, untouched + untouched + append_channels)
        return self.identity(padded)
    

    
def resnet20(activation, activation_params, rs, num_classes):
    """Build a ``ResNet`` with 3 residual blocks in each of its three stages.

    ``activation``, ``activation_params``, ``rs`` and ``num_classes`` are
    forwarded unchanged to the ``ResNet`` constructor.
    """
    blocks_per_stage = [3, 3, 3]
    return ResNet(activation, activation_params, rs, blocks_per_stage, num_classes=num_classes)
    
def resnet32(activation, activation_params, rs, num_classes):
    """Build a ``ResNet`` with 5 residual blocks in each of its three stages.

    ``activation``, ``activation_params``, ``rs`` and ``num_classes`` are
    forwarded unchanged to the ``ResNet`` constructor.
    """
    blocks_per_stage = [5, 5, 5]
    return ResNet(activation, activation_params, rs, blocks_per_stage, num_classes=num_classes)
    
def resnet44(activation, activation_params, rs, num_classes):
    """Build a ``ResNet`` with 7 residual blocks in each of its three stages.

    ``activation``, ``activation_params``, ``rs`` and ``num_classes`` are
    forwarded unchanged to the ``ResNet`` constructor.
    """
    blocks_per_stage = [7, 7, 7]
    return ResNet(activation, activation_params, rs, blocks_per_stage, num_classes=num_classes)


