"""
Pixel shuffle layer extended to support volumetric data
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network, CVPR17)
"""


from torch.autograd import Function
import torch
import torch.nn as nn

def _shuffle3d_dims(shape, upscale_factor):
    """Return ``(batch, channels, height, width, depth)`` for a 3-D pixel shuffle.

    ``shape`` is the pre-shuffle input shape: either ``(B, C * r**3, H, W, D)``
    or the unbatched ``(C * r**3, H, W, D)``. Unbatched shapes report a batch
    of 1. ``channels`` is the post-shuffle channel count ``C``.

    Raises:
        RuntimeError: if the rank is not 4 or 5, ``upscale_factor`` is not
            positive, or the channel dimension is not divisible by
            ``upscale_factor ** 3``.
    """
    if len(shape) not in (4, 5):
        raise RuntimeError(
            "PixelShuffle3d expects a 4-D (C, H, W, D) or 5-D (B, C, H, W, D) "
            "input, got {}-D".format(len(shape))
        )
    if upscale_factor < 1:
        raise RuntimeError(
            "upscale_factor must be positive, got {}".format(upscale_factor)
        )
    batch = shape[0] if len(shape) == 5 else 1
    in_channels, height, width, depth = shape[-4:]
    block = upscale_factor ** 3
    if in_channels % block != 0:
        raise RuntimeError(
            "PixelShuffle3d expects the channel dimension ({}) to be divisible "
            "by upscale_factor ** 3 ({})".format(in_channels, block)
        )
    return batch, in_channels // block, height, width, depth


class PixelShuffle3dFunction(Function):
    """Autograd function implementing the volumetric pixel shuffle.

    Forward rearranges ``(B, C * r**3, H, W, D)`` into
    ``(B, C, H * r, W * r, D * r)``. Unbatched ``(C * r**3, H, W, D)`` input
    maps to ``(C, H * r, W * r, D * r)``. Element
    ``[b, c*r**3 + i*r**2 + j*r + k, h, w, d]`` moves to
    ``[b, c, h*r + i, w*r + j, d*r + k]``. Backward applies the inverse
    rearrangement to the incoming gradient.
    """

    @staticmethod
    def forward(ctx, input, upscale_factor):
        """Shuffle ``input`` by ``upscale_factor`` (an int or a 0-d integer tensor)."""
        r = int(upscale_factor)
        batch, channels, height, width, depth = _shuffle3d_dims(input.shape, r)

        # Backward only needs the shape, so do not keep ``input`` alive.
        ctx.input_shape = input.shape
        ctx.upscale_factor = r

        # (B, C*r^3, H, W, D) -> (B, C, r, r, r, H, W, D). Unlike view(),
        # reshape() also accepts non-contiguous inputs.
        expanded = input.reshape(batch, channels, r, r, r, height, width, depth)
        # Pair each spatial axis with its sub-pixel axis: (B, C, H, r, W, r, D, r).
        shuffled = expanded.permute(0, 1, 5, 2, 6, 3, 7, 4)

        out_shape = (channels, height * r, width * r, depth * r)
        if input.dim() == 5:
            out_shape = (batch,) + out_shape
        # clone() defaults to torch.preserve_format, which would keep the
        # permuted (non-contiguous) strides and make view() fail, so request a
        # contiguous copy explicitly.
        return shuffled.clone(memory_format=torch.contiguous_format).view(out_shape)

    @staticmethod
    def backward(ctx, grad_output):
        """Route ``grad_output`` back to the input layout; ``upscale_factor`` gets no gradient."""
        r = ctx.upscale_factor
        batch, channels, height, width, depth = _shuffle3d_dims(ctx.input_shape, r)

        # (B, C, H*r, W*r, D*r) -> (B, C, H, r, W, r, D, r). grad_output is not
        # guaranteed to be contiguous, so reshape rather than view.
        expanded = grad_output.reshape(batch, channels, height, r, width, r, depth, r)
        # Inverse of the forward permutation: (B, C, r, r, r, H, W, D).
        unshuffled = expanded.permute(0, 1, 3, 5, 7, 2, 4, 6)
        grad_input = unshuffled.clone(memory_format=torch.contiguous_format).view(
            ctx.input_shape
        )
        return grad_input, None

class PixelShuffle3d(nn.Module):
    """Volumetric pixel shuffle: ``(B, C * r**3, H, W, D) -> (B, C, H * r, W * r, D * r)``.

    The 3-D counterpart of ``torch.nn.PixelShuffle``. The rearrangement itself
    is performed by ``PixelShuffle3dFunction``.

    Args:
        upscale_factor: factor ``r`` by which each spatial axis is enlarged;
            converted with ``int()`` and required to be at least 1.

    Raises:
        ValueError: if ``upscale_factor`` is less than 1.
    """

    def __init__(self, upscale_factor):
        super().__init__()
        factor = int(upscale_factor)
        if factor < 1:
            raise ValueError(
                "upscale_factor must be a positive integer, got {}".format(upscale_factor)
            )
        # Stored as a 0-d int64 tensor because PixelShuffle3dFunction may pass
        # it to ctx.save_for_backward, which accepts only tensors.
        self.upscale_factor = torch.tensor(factor, dtype=torch.int64)

    def forward(self, input):
        """Apply the 3-D pixel shuffle to ``input``."""
        return PixelShuffle3dFunction.apply(input, self.upscale_factor)

    def extra_repr(self):
        # Report the plain integer rather than ``tensor(r)``.
        return 'upscale_factor={}'.format(int(self.upscale_factor))
