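'''Functional ResNet in PyTorch.

FBasicBlock and FResNet take their weights and biases at call time, looked up
by name in a caller-supplied ``named_dict``. Does not support ``nn.DataParallel``.
'''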
from turtle import forward
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.functional_modules import *
# from wmodules import *
from torch.autograd import Variable
from collections import OrderedDict


def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``FConv2d`` with padding 1 and the given stride."""
    return FConv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

class FBasicBlock(nn.Module):
    '''Two 3x3 conv/BN layers plus a residual connection, with call-time parameters.

    ``forward`` looks up each submodule's tensors in ``named_dict`` under
    ``'<prefix>.<submodule>.weight'`` / ``'<prefix>.<submodule>.bias'``;
    a missing key is passed on as ``None``.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(FBasicBlock, self).__init__()
        out_planes = self.expansion * planes
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = FBatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = FBatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            # 1x1 strided conv + BN so the shortcut output has the main path's
            # stride and channel count.
            self.shortcut = nn.Sequential(OrderedDict([
                ('conv', FConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)),
                ('bn', FBatchNorm2d(out_planes)),
            ]))
        else:
            # An empty Sequential returns its input unchanged (identity shortcut).
            self.shortcut = nn.Sequential()

    def forward(self, x, prefix, named_dict):
        '''Apply the block to ``x`` using parameters found in ``named_dict``.

        Args:
            x: input tensor.
            prefix: name prefix of this block, e.g. ``'layer1.0'``.
            named_dict: mapping from parameter name to tensor.
        '''
        def param(name):
            # e.g. param('conv1.weight') -> named_dict.get('<prefix>.conv1.weight')
            return named_dict.get('%s.%s' % (prefix, name))

        h = self.conv1(x, param('conv1.weight'), param('conv1.bias'))
        h = F.relu(self.bn1(h, param('bn1.weight'), param('bn1.bias')))
        h = self.conv2(h, param('conv2.weight'), param('conv2.bias'))
        h = self.bn2(h, param('bn2.weight'), param('bn2.bias'))

        if len(self.shortcut):
            proj_conv, proj_bn = self.shortcut[0], self.shortcut[1]
            skip = proj_conv(x, param('shortcut.conv.weight'), param('shortcut.conv.bias'))
            skip = proj_bn(skip, param('shortcut.bn.weight'), param('shortcut.bn.bias'))
        else:
            skip = self.shortcut(x)

        h += skip
        return F.relu(h)

class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.

    ``forward`` also accepts the ``(x, prefix, named_dict)`` call used by
    ``FBasicBlock`` and ``through_layer``, so the block can be stacked inside
    ``FResNet``. When ``named_dict`` is omitted, every submodule is called with
    its input only, exactly as before.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = FBatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = FBatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(OrderedDict([
                ('conv', FConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False))
            ]))

    @staticmethod
    def _fcall(module, x, key, named_dict):
        '''Apply ``module`` to ``x``, optionally with parameters from ``named_dict``.

        With ``named_dict`` given, ``named_dict.get('<key>.weight')`` and
        ``named_dict.get('<key>.bias')`` are passed as extra positional
        arguments (``None`` for a missing key), mirroring ``FBasicBlock``.
        Without it, ``module`` is called with ``x`` alone.
        '''
        if named_dict is None:
            return module(x)
        return module(x, named_dict.get('%s.weight' % key), named_dict.get('%s.bias' % key))

    def forward(self, x, prefix=None, named_dict=None):
        '''Apply the pre-activation block to ``x``.

        Args:
            x: input tensor.
            prefix: name prefix of this block (e.g. ``'layer1.0'``); should be
                given whenever ``named_dict`` is.
            named_dict: optional mapping from parameter name to tensor.
        '''
        pf, nd = prefix, named_dict
        out = F.relu(self._fcall(self.bn1, x, '%s.bn1' % pf, nd))
        # The shortcut branches off the pre-activated input.
        if nd is None or not len(self.shortcut):
            shortcut = self.shortcut(out)
        else:
            shortcut = self._fcall(self.shortcut.conv, out, '%s.shortcut.conv' % pf, nd)
        out = self._fcall(self.conv1, out, '%s.conv1' % pf, nd)
        out = F.relu(self._fcall(self.bn2, out, '%s.bn2' % pf, nd))
        out = self._fcall(self.conv2, out, '%s.conv2' % pf, nd)
        out += shortcut
        return out


class Bottleneck(nn.Module):
    '''1x1 -> 3x3 -> 1x1 bottleneck residual block with ``4 * planes`` output channels.

    ``forward`` also accepts the ``(x, prefix, named_dict)`` call used by
    ``FBasicBlock`` and ``through_layer``, so the block can be stacked inside
    ``FResNet``. When ``named_dict`` is omitted, every submodule is called with
    its input only, exactly as before.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = FConv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = FBatchNorm2d(planes)
        self.conv2 = FConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = FBatchNorm2d(planes)
        self.conv3 = FConv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = FBatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(OrderedDict([
                ('conv',FConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)),
                ('bn',FBatchNorm2d(self.expansion*planes))
            ]))

    @staticmethod
    def _fcall(module, x, key, named_dict):
        '''Apply ``module`` to ``x``, optionally with parameters from ``named_dict``.

        With ``named_dict`` given, ``named_dict.get('<key>.weight')`` and
        ``named_dict.get('<key>.bias')`` are passed as extra positional
        arguments (``None`` for a missing key), mirroring ``FBasicBlock``.
        Without it, ``module`` is called with ``x`` alone.
        '''
        if named_dict is None:
            return module(x)
        return module(x, named_dict.get('%s.weight' % key), named_dict.get('%s.bias' % key))

    def forward(self, x, prefix=None, named_dict=None):
        '''Apply the bottleneck block to ``x``.

        Args:
            x: input tensor.
            prefix: name prefix of this block (e.g. ``'layer1.0'``); should be
                given whenever ``named_dict`` is.
            named_dict: optional mapping from parameter name to tensor.
        '''
        pf, nd = prefix, named_dict
        out = self._fcall(self.conv1, x, '%s.conv1' % pf, nd)
        out = F.relu(self._fcall(self.bn1, out, '%s.bn1' % pf, nd))
        out = self._fcall(self.conv2, out, '%s.conv2' % pf, nd)
        out = F.relu(self._fcall(self.bn2, out, '%s.bn2' % pf, nd))
        out = self._fcall(self.conv3, out, '%s.conv3' % pf, nd)
        out = self._fcall(self.bn3, out, '%s.bn3' % pf, nd)
        if nd is None or not len(self.shortcut):
            out += self.shortcut(x)
        else:
            skip = self._fcall(self.shortcut.conv, x, '%s.shortcut.conv' % pf, nd)
            out += self._fcall(self.shortcut.bn, skip, '%s.shortcut.bn' % pf, nd)
        out = F.relu(out)
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.

    ``forward`` also accepts the ``(x, prefix, named_dict)`` call used by
    ``FBasicBlock`` and ``through_layer``, so the block can be stacked inside
    ``FResNet``. When ``named_dict`` is omitted, every submodule is called with
    its input only, exactly as before.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = FBatchNorm2d(in_planes)
        self.conv1 = FConv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = FBatchNorm2d(planes)
        self.conv2 = FConv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = FBatchNorm2d(planes)
        self.conv3 = FConv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(OrderedDict([
                ('conv', FConv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False))
            ]))

    @staticmethod
    def _fcall(module, x, key, named_dict):
        '''Apply ``module`` to ``x``, optionally with parameters from ``named_dict``.

        With ``named_dict`` given, ``named_dict.get('<key>.weight')`` and
        ``named_dict.get('<key>.bias')`` are passed as extra positional
        arguments (``None`` for a missing key), mirroring ``FBasicBlock``.
        Without it, ``module`` is called with ``x`` alone.
        '''
        if named_dict is None:
            return module(x)
        return module(x, named_dict.get('%s.weight' % key), named_dict.get('%s.bias' % key))

    def forward(self, x, prefix=None, named_dict=None):
        '''Apply the pre-activation bottleneck block to ``x``.

        Args:
            x: input tensor.
            prefix: name prefix of this block (e.g. ``'layer1.0'``); should be
                given whenever ``named_dict`` is.
            named_dict: optional mapping from parameter name to tensor.
        '''
        pf, nd = prefix, named_dict
        out = F.relu(self._fcall(self.bn1, x, '%s.bn1' % pf, nd))
        # The shortcut branches off the pre-activated input.
        if nd is None or not len(self.shortcut):
            shortcut = self.shortcut(out)
        else:
            shortcut = self._fcall(self.shortcut.conv, out, '%s.shortcut.conv' % pf, nd)
        out = self._fcall(self.conv1, out, '%s.conv1' % pf, nd)
        out = F.relu(self._fcall(self.bn2, out, '%s.bn2' % pf, nd))
        out = self._fcall(self.conv2, out, '%s.conv2' % pf, nd)
        out = F.relu(self._fcall(self.bn3, out, '%s.bn3' % pf, nd))
        out = self._fcall(self.conv3, out, '%s.conv3' % pf, nd)
        out += shortcut
        return out

def through_layer(layers, input, prefix, named_dict):
    '''Feed ``input`` through each block of ``layers`` in order.

    The block at position ``i`` is called as
    ``block(x, prefix + '.<i>', named_dict)``, so each block of an
    ``nn.Sequential`` gets its own parameter prefix (e.g. ``'layer1.0'``).

    Args:
        layers: ordered iterable of callables (typically an ``nn.Sequential``).
        input: value passed to the first block.
        prefix: name prefix for this group of blocks.
        named_dict: mapping forwarded unchanged to every block.

    Returns:
        The last block's output, or ``input`` unchanged if ``layers`` is empty.
    '''
    out = input
    for idx, layer in enumerate(layers):
        out = layer(out, prefix + '.%d' % idx, named_dict)
    return out

class FResNet(nn.Module):
    '''ResNet whose forward pass reads parameters from a name -> tensor mapping.

    ``forward`` can run a contiguous range of seven stages:

    * 0 -- stem: ``conv1`` -> ``bn1`` -> ReLU
    * 1..4 -- ``layer1`` .. ``layer4`` (stacks of ``block`` modules)
    * 5 -- ``F.avg_pool2d(kernel=4)`` followed by flattening
    * 6 -- the ``linear`` classifier
    '''

    def __init__(self, block, num_blocks, num_classes=10):
        super(FResNet, self).__init__()
        # Running input-channel count; _make_layer reads and advances it.
        self.in_planes = 64

        # Registration order here determines named_parameters() order.
        self.conv1 = conv3x3(3, 64)
        self.bn1 = FBatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = WLinear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        '''Stack ``block`` modules; only the first one uses ``stride``.

        Updates ``self.in_planes`` to ``planes * block.expansion``. Note that
        ``num_blocks <= 0`` still yields one block, because the first stride
        is always included.
        '''
        blocks = []
        for s in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x, prefix, named_dict, lin=0, lout=6):
        '''Run stages ``lin`` .. ``lout`` (see the class docstring) on ``x``.

        Args:
            x: input to stage ``lin``.
            prefix: string prepended to parameter names. Stem and classifier
                keys are ``'<prefix>conv1.weight'`` etc. and residual stages
                use ``prefix + 'layerK'``, so it usually ends with ``'.'`` or
                is empty.
            named_dict: mapping from parameter name to tensor; a missing key
                is passed on as ``None``.
            lin, lout: stage bounds. Stage ``k`` (0-5) runs when
                ``lin < k + 1 and lout > k - 1``; the classifier runs
                whenever ``lout > 5``.

        Returns:
            The output of the last stage that ran (``x`` itself if none ran).
        '''
        pf, nd = prefix, named_dict
        h = x

        if lin < 1 and lout > -1:
            h = self.conv1(h, nd.get('%sconv1.weight' % pf), nd.get('%sconv1.bias' % pf))
            h = F.relu(self.bn1(h, nd.get('%sbn1.weight' % pf), nd.get('%sbn1.bias' % pf)))

        residual_stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for k, stage in enumerate(residual_stages, start=1):
            if lin < k + 1 and lout > k - 1:
                h = through_layer(stage, h, prefix + 'layer%d' % k, named_dict)

        if lin < 6 and lout > 4:
            # NOTE(review): kernel 4 presumably matches a 4x4 layer4 map
            # (32x32 input) -- confirm for other input sizes.
            h = F.avg_pool2d(h, 4)
            h = h.view(h.size(0), -1)

        if lout > 5:
            h = self.linear(h, nd.get('%slinear.weight' % pf), nd.get('%slinear.bias' % pf))
        return h

def FResNet18(num_classes=10):
    """Build an 18-layer FResNet: FBasicBlock, two blocks per stage."""
    blocks_per_stage = [2] * 4
    return FResNet(FBasicBlock, blocks_per_stage, num_classes)

# test()

if __name__ == "__main__":
    # Smoke test: run the reference and functional models on a random batch.
    from resnet import ResNet18

    batch = torch.randn(128, 3, 32, 32)
    reference = ResNet18()
    reference(batch)

    # NOTE(review): the parameters come from a *fresh* ResNet18 (not
    # `reference`), and a bare model's names presumably have no 'module.'
    # prefix, so the lookups below likely all miss -- confirm intended.
    params = dict(ResNet18().named_parameters())
    functional_model = FResNet18()
    result = functional_model(batch, 'module.', params)
    print(result.shape)