from torch.autograd import Function, Variable
from torch.nn.modules.module import Module
import channelnorm_cuda

class ChannelNormFunction(Function):
    """Autograd bridge to the ``channelnorm_cuda`` extension kernels.

    The forward pass reduces a 4-D input of shape ``(b, c, h, w)`` to an
    output of shape ``(b, 1, h, w)``. The arithmetic itself lives in the
    compiled extension and is not visible here; presumably it is the
    ``norm_deg``-norm taken over the channel axis -- confirm against the
    kernel source.
    """

    @staticmethod
    def forward(ctx, input1, norm_deg=2):
        """Run the extension's forward kernel.

        Args:
            ctx: autograd context used to stash state for ``backward``.
            input1: 4-D tensor ``(b, c, h, w)``. Presumably must live on a
                CUDA device, given the extension's name -- confirm.
            norm_deg: degree handed to the kernel unchanged (default 2).

        Returns:
            A new tensor of shape ``(b, 1, h, w)`` created from ``input1``
            (same dtype/device) and filled in by the kernel.

        Raises:
            ValueError: from the size unpacking if ``input1`` is not 4-D.
        """
        # Contiguity used to be enforced with `assert`, which disappears
        # under `python -O` and would let a strided tensor reach the kernel
        # unchecked. `.contiguous()` returns the same tensor when the layout
        # is already dense, so previously valid callers pay nothing.
        input1 = input1.contiguous()
        b, _, h, w = input1.size()
        output = input1.new_zeros((b, 1, h, w))

        channelnorm_cuda.forward(input1, output, norm_deg)

        # Both tensors are handed back to the extension in backward.
        ctx.save_for_backward(input1, output)
        ctx.norm_deg = norm_deg

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Run the extension's backward kernel.

        Args:
            ctx: context populated by ``forward``.
            grad_output: gradient with respect to the forward output.

        Returns:
            ``(grad_input1, None)`` -- the gradient for ``input1`` (same
            shape as ``input1``) and ``None`` for the non-tensor
            ``norm_deg`` argument.
        """
        input1, output = ctx.saved_tensors

        # Zero-filled buffer the kernel writes the input gradient into.
        grad_input1 = input1.new_zeros(input1.size())

        # Autograd may deliver a strided gradient (e.g. produced by an
        # expand or a transpose upstream); forward required a dense layout
        # for its tensor argument, so give the backward kernel the same.
        channelnorm_cuda.backward(input1, output, grad_output.contiguous(),
                                  grad_input1, ctx.norm_deg)

        return grad_input1, None


class ChannelNorm(Module):
    """``Module`` front-end for ``ChannelNormFunction``.

    Holds a fixed ``norm_deg`` and forwards it, together with the input,
    to ``ChannelNormFunction.apply`` on every call.
    """

    def __init__(self, norm_deg=2):
        """Remember the degree to pass along on each forward call.

        Args:
            norm_deg: value stored as ``self.norm_deg`` (default 2).
        """
        super(ChannelNorm, self).__init__()
        self.norm_deg = norm_deg

    def forward(self, input1):
        """Apply ``ChannelNormFunction`` to ``input1`` with the stored degree.

        Args:
            input1: tensor handed to ``ChannelNormFunction.apply`` as-is.

        Returns:
            Whatever ``ChannelNormFunction.apply`` returns for this input.
        """
        degree = self.norm_deg
        result = ChannelNormFunction.apply(input1, degree)
        return result

