"""PyTorch implementation of ShiftResNet

ShiftResNet modifications written by Bichen Wu and Alvin Wan. Efficient CUDA
implementation of shift written by Peter Jin.

Reference:
[1] Bichen Wu, Alvin Wan, Xiangyu Yue, Peter Jin, Sicheng Zhao, Noah Golmant,
    Amir Gholaminejad, Joseph Gonzalez, Kurt Keutzer
    Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions.
    arXiv:1711.08141
"""

import torch.nn as nn
import torch.nn.functional as F
import torch

from models.resnet import ResNet

class Shift3x3(nn.Module):
    """Pure-PyTorch reimplementation of the 3x3 shift module.

    The channels of the input are split into nine consecutive groups, one per
    position of a 3x3 neighbourhood. Each group is translated by one pixel in
    its own direction (the center group is copied unchanged); positions that
    receive no source pixel stay zero.
    """

    # (row offset, column offset) of every channel group, in channel order.
    # A positive offset moves the content down / right by one pixel.
    _OFFSETS = (
        (1, 1),    # right-down
        (1, 0),    # down
        (1, -1),   # left-down
        (0, 1),    # right
        (0, 0),    # center
        (0, -1),   # left
        (-1, 1),   # right-up
        (-1, 0),   # up
        (-1, -1),  # left-up
    )

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` with the 3x3 shift applied (see :meth:`shift`)."""
        return self.shift(x)

    @staticmethod
    def _axis_slices(offset, size):
        """Return ``(destination, source)`` slices for a one-pixel shift.

        ``offset`` is -1, 0 or 1 and ``size`` is the length of the axis.
        """
        if offset > 0:
            return slice(1, None), slice(None, size - 1)
        if offset < 0:
            return slice(None, size - 1), slice(1, None)
        return slice(None), slice(None)

    @staticmethod
    def shift(x):
        """Shift the nine channel groups of ``x`` in their own directions.

        Args:
            x: 4-D tensor; it is unpacked as ``(batch, channels, height,
                width)``.

        Returns:
            A new tensor with the same shape, dtype and device as ``x``.
            With ``g, r = divmod(channels, 9)`` every group holds ``g``
            channels, except the center group which also takes the ``r``
            remainder channels. Inputs with fewer than nine channels are
            therefore returned as an unshifted copy.
        """
        _, c, h, w = x.size()
        out = torch.zeros_like(x)
        g, r = divmod(c, 9)  # channels per group, remainder

        start = 0
        for dy, dx in Shift3x3._OFFSETS:
            # The center group absorbs the remainder channels. Tracking a
            # running ``start`` keeps every later boundary offset by ``r``,
            # so the last group can no longer overlap its predecessors.
            is_center = dy == 0 and dx == 0
            stop = start + g + (r if is_center else 0)
            dst_rows, src_rows = Shift3x3._axis_slices(dy, h)
            dst_cols, src_cols = Shift3x3._axis_slices(dx, w)
            out[:, start:stop, dst_rows, dst_cols] = \
                x[:, start:stop, src_rows, src_cols]
            start = stop
        return out


class ShiftConv(nn.Module):
    """Residual block made of 1x1 conv -> 3x3 shift -> 1x1 conv.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        stride: stride of the second 1x1 convolution and of the projection
            shortcut.
        expansion: multiplier applied to ``out_planes`` to obtain the number
            of intermediate channels (truncated to ``int``).
    """

    def __init__(self, in_planes, out_planes, stride=1, expansion=1):
        super().__init__()
        self.expansion = expansion
        self.in_planes = in_planes
        self.out_planes = out_planes
        mid_planes = int(out_planes * self.expansion)
        self.mid_planes = mid_planes

        # Main branch: pointwise expand, shift, pointwise project.
        self.conv1 = nn.Conv2d(
            in_planes, mid_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_planes)
        self.shift2 = Shift3x3()
        self.conv2 = nn.Conv2d(
            mid_planes, out_planes, kernel_size=1, bias=False, stride=stride)
        self.bn2 = nn.BatchNorm2d(out_planes)

        # Shortcut branch: identity unless the shape changes, in which case
        # a strided 1x1 convolution + batch norm matches the main branch.
        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Run the block on ``x`` and add the shortcut branch.

        As a side effect the sizes of the intermediate and of the pre-sum
        output activations are stored in ``self.int_nchw`` and
        ``self.out_nchw`` on every call.
        """
        residual = self.shortcut(x)

        hidden = F.relu(self.bn1(self.conv1(x)))
        self.int_nchw = hidden.size()

        shifted = self.shift2(hidden)
        result = F.relu(self.bn2(self.conv2(shifted)))
        self.out_nchw = result.size()

        return result + residual


def ShiftResNet20(expansion=1, num_classes=10):
    """Return a ``ResNet`` assembled from ``ShiftConv`` blocks.

    ``ResNet`` receives the layout ``[3, 3, 3]`` (presumably the number of
    blocks per stage -- confirm against ``models.resnet.ResNet``).

    Args:
        expansion: forwarded to every ``ShiftConv`` block.
        num_classes: forwarded to ``ResNet``.
    """
    def make_block(in_planes, out_planes, stride):
        # Bind ``expansion`` so the factory matches the three-argument
        # call made with (in_planes, out_planes, stride).
        return ShiftConv(in_planes, out_planes, stride, expansion=expansion)

    return ResNet(make_block, [3, 3, 3], num_classes=num_classes)


def ShiftResNet56(expansion=1, num_classes=10):
    """Return a ``ResNet`` assembled from ``ShiftConv`` blocks.

    ``ResNet`` receives the layout ``[9, 9, 9]`` (presumably the number of
    blocks per stage -- confirm against ``models.resnet.ResNet``).

    Args:
        expansion: forwarded to every ``ShiftConv`` block.
        num_classes: forwarded to ``ResNet``.
    """
    def make_block(in_planes, out_planes, stride):
        # Bind ``expansion`` so the factory matches the three-argument
        # call made with (in_planes, out_planes, stride).
        return ShiftConv(in_planes, out_planes, stride, expansion=expansion)

    return ResNet(make_block, [9, 9, 9], num_classes=num_classes)


def ShiftResNet110(expansion=1, num_classes=10):
    """Return a ``ResNet`` assembled from ``ShiftConv`` blocks.

    ``ResNet`` receives the layout ``[18, 18, 18]`` (presumably the number
    of blocks per stage -- confirm against ``models.resnet.ResNet``).

    Args:
        expansion: forwarded to every ``ShiftConv`` block.
        num_classes: forwarded to ``ResNet``.
    """
    def make_block(in_planes, out_planes, stride):
        # Bind ``expansion`` so the factory matches the three-argument
        # call made with (in_planes, out_planes, stride).
        return ShiftConv(in_planes, out_planes, stride, expansion=expansion)

    return ResNet(make_block, [18, 18, 18], num_classes=num_classes)


if __name__ == "__main__":
    # Smoke test: push one random 3x32x32 input through ShiftResNet20 and
    # print the shape of the result.
    net = ShiftResNet20(expansion=9)
    dummy_batch = torch.randn(1, 3, 32, 32)
    print(net(dummy_batch).shape)
