import torch
from torch.autograd import Function

from .utils import Dtype, Stream, load_kernel


# Threads per block (x dimension) used for every kernel launch in this file;
# see the ``block=(CUDA_NUM_THREADS, 1, 1)`` arguments below.
CUDA_NUM_THREADS = 1024

# C preprocessor macro prepended to each kernel source in this file. It expands
# to a grid-stride ``for`` loop: ``i`` starts at the global thread index and
# advances by the total number of launched threads until it reaches ``n``.
# Because this is a regular (non-raw) Python string, every backslash-newline
# pair below is consumed by Python itself, so the ``#define`` ends up on a
# single line in the resulting string.
kernel_loop = '''
#define CUDA_KERNEL_LOOP(i, n)                        \
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
      i < (n);                                       \
      i += blockDim.x * gridDim.x)
'''


def GET_BLOCKS(N):
    """Return how many blocks of ``CUDA_NUM_THREADS`` threads cover ``N`` items.

    This is ``N / CUDA_NUM_THREADS`` rounded up, computed with integer
    arithmetic: ``N`` is first bumped by one block minus one so that the
    floor division rounds towards the next whole block.
    """
    rounded_up = N + CUDA_NUM_THREADS - 1
    return rounded_up // CUDA_NUM_THREADS


# CUDA source template for the forward pass.
#
# The ``${...}`` placeholders are presumably substituted by ``load_kernel``
# from the keyword arguments given at the call site (Dtype, nthreads,
# input_channels, bottom_*/top_* sizes, kernel, pad) -- confirm in ``.utils``.
#
# Each loop index is decoded into (n, c, h, w) and used directly as the flat
# offset into ``bottom1_data`` (``offset_center = index``), so ``bottom1_data``
# is read as a dense (N, C, top_height, top_width) array. For every window
# position (kh, kw) the kernel looks up ``bottom2_data`` at
# (h + kh - pad, w + kw - pad); coordinates outside the image are mirrored
# once (``-x`` below zero, ``2 * (size - 1) - x`` past the end), which only
# lands in range while ``pad`` is smaller than the corresponding image size.
# The product is written to ``top_data`` laid out as
# (N, C, kernel * kernel, top_height * top_width).
_dotproduction2_refpad_forward_kernel = kernel_loop + '''
extern "C"
__global__ void dotproduction2_refpad_forward_kernel(
const ${Dtype}* bottom1_data, const ${Dtype}* bottom2_data, ${Dtype}* top_data) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${input_channels} / ${top_height} / ${top_width};
    const int c = (index / ${top_height} / ${top_width}) % ${input_channels};
    const int h = (index / ${top_width}) % ${top_height};
    const int w = index % ${top_width};
    const int offset_center = index;
    for (int kh = 0; kh < ${kernel}; ++kh) {
      for (int kw = 0; kw < ${kernel}; ++kw) {
        int h_in = -${pad} + h + kh;
        int w_in = -${pad} + w + kw;
        const int offset_top = ((n * ${input_channels} + c) * ${kernel} * ${kernel} + (kh * ${kernel} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w;
        int offset_bottom;
        if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) {
          offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in;
        }
        else {
          if (h_in < 0) h_in = -h_in;
          if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in;
          if (w_in < 0) w_in = -w_in;
          if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in;
          offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in;
        }
        top_data[offset_top] = bottom1_data[offset_center] * bottom2_data[offset_bottom];
      }
    }
  }
}
'''


# CUDA source template for the gradient with respect to the first input.
#
# One loop index corresponds to one element (n, c, h, w) of ``bottom_diff``,
# decoded with the bottom_height/bottom_width sizes. For that element the
# kernel sums, over all kernel * kernel window positions, the incoming
# gradient ``top_diff`` at (n, c, kh * kernel + kw, h, w) multiplied by the
# ``bottom2_data`` value the forward kernel paired it with (same single-mirror
# handling of out-of-range coordinates). The sum overwrites
# ``bottom_diff[index]``, so the destination does not need to be zeroed first.
_dotproduction2_refpad_input1_backward_kernel = kernel_loop + '''
extern "C"
__global__ void dotproduction2_refpad_input1_backward_kernel(
    const ${Dtype}* const top_diff, ${Dtype}* const bottom2_data, ${Dtype}* bottom_diff) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${input_channels} / ${bottom_height} / ${bottom_width};
    const int c = (index / ${bottom_height} / ${bottom_width}) % ${input_channels};
    const int h = (index / ${bottom_width}) % ${bottom_height};
    const int w = index % ${bottom_width};
    ${Dtype} value = 0;
    for (int kh = 0; kh < ${kernel}; ++kh) {
      for (int kw = 0; kw < ${kernel}; ++kw) {
        int h_in = -${pad} + h + kh;
        int w_in = -${pad} + w + kw;
        int offset_bottom;
        if ((h_in >= 0) && (h_in < ${bottom_height}) && (w_in >= 0) && (w_in < ${bottom_width})) {
          offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in;
        }
        else {
          if (h_in < 0) h_in = -h_in;
          if (h_in >= ${bottom_height}) h_in = 2 * (${bottom_height} - 1) - h_in;
          if (w_in < 0) w_in = -w_in;
          if (w_in >= ${bottom_width}) w_in = 2 * (${bottom_width} - 1) - w_in;
          offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_in) * ${bottom_width} + w_in;
        }
        const int offset_top = ((n * ${input_channels} + c) * ${kernel} * ${kernel} + (kh * ${kernel} + kw)) * ${top_height} * ${top_width} + h * ${top_width} + w;
        value += top_diff[offset_top] * bottom2_data[offset_bottom];
      }
    }
    bottom_diff[index] = value;
  }
}
'''


# CUDA source template for the gradient with respect to the second input,
# computed on the *padded* grid.
#
# Here the loop index is decoded against an image of size
# (bottom_height + 2 * pad, bottom_width + 2 * pad), i.e. ``bottom_diff`` is
# expected to have that padded spatial size. For a padded position (h, w) the
# kernel visits every window offset (kh, kw), derives the output location
# (h - kh, w - kw) that read this position in the forward pass, and, when that
# location is inside the output, accumulates ``top_diff`` times the matching
# ``bottom1_data`` value. No reflection is applied inside this kernel; the
# border rows/columns are folded back onto the interior afterwards in
# ``Dotproduction2Refpad.backward``.
_dotproduction2_refpad_input2_backward_kernel = kernel_loop + '''
extern "C"
__global__ void dotproduction2_refpad_input2_backward_kernel(
    const ${Dtype}* const top_diff, ${Dtype}* const bottom1_data, ${Dtype}* bottom_diff) {
  CUDA_KERNEL_LOOP(index, ${nthreads}) {
    const int n = index / ${input_channels} / (${bottom_height} + 2 * ${pad}) / (${bottom_width} + 2 * ${pad});
    const int c = (index / (${bottom_height} + 2 * ${pad}) / (${bottom_width} + 2 * ${pad})) % ${input_channels};
    const int h = (index / (${bottom_width} + 2 * ${pad})) % (${bottom_height} + 2 * ${pad});
    const int w = index % (${bottom_width} + 2 * ${pad});
    ${Dtype} value = 0;
    for (int kh = 0; kh < ${kernel}; ++kh) {
      for (int kw = 0; kw < ${kernel}; ++kw) {
        const int h_out = h - kh;
        const int w_out = w - kw;
        if ((h_out >= 0) && (h_out < ${top_height}) && (w_out >= 0) && (w_out < ${top_width})) {
          const int offset_top = ((n * ${input_channels} + c) * ${kernel} * ${kernel} + (kh * ${kernel} + kw)) * ${top_height} * ${top_width} + h_out * ${top_width} + w_out;
          const int offset_bottom = ((n * ${input_channels} + c) * ${bottom_height} + h_out) * ${bottom_width} + w_out;
          value += top_diff[offset_top] * bottom1_data[offset_bottom];
        }
      }
    }
    bottom_diff[index] = value;
  }
}
'''


class Dotproduction2Refpad(Function):
    """Autograd function for the reflection-padded local product of two maps.

    For two CUDA tensors ``input1`` and ``input2`` of identical shape
    ``(N, C, H, W)`` the forward pass returns a tensor of shape
    ``(N, C, kernel_size ** 2, H * W)`` whose entry
    ``[n, c, kh * kernel_size + kw, h * W + w]`` equals
    ``input1[n, c, h, w] * input2[n, c, h + kh - padding, w + kw - padding]``,
    where out-of-range rows/columns of ``input2`` are mirrored back into the
    image (reflection padding).
    """

    @staticmethod
    def forward(ctx, input1, input2, kernel_size, padding):
        """Launch the forward kernel.

        Args:
            ctx: autograd context; stores ``kernel_size``, ``padding`` and the
                two (contiguous) inputs for ``backward``.
            input1: 4-D CUDA tensor ``(N, C, H, W)``.
            input2: CUDA tensor with the same shape, dtype and device as
                ``input1``.
            kernel_size: side length of the square neighbourhood.
            padding: reflection padding; must equal ``(kernel_size - 1) / 2``
                and be smaller than both ``H`` and ``W``.

        Returns:
            Tensor of shape ``(N, C, kernel_size ** 2, H * W)``.

        Raises:
            AssertionError: if any of the preconditions above is violated.
        """
        assert padding == (kernel_size - 1) / 2

        ctx.kernel_size, ctx.padding = kernel_size, padding
        assert input1.dim() == 4 and input1.is_cuda
        # The kernels address input2 through a raw pointer using input1's
        # geometry and element type, so any mismatch would read garbage or
        # out-of-bounds memory instead of raising.
        assert input2.device == input1.device
        assert input2.dtype == input1.dtype
        assert input2.size() == input1.size()
        batch_size, input_channels, input_height, input_width = input1.size()
        # The kernels mirror an out-of-range coordinate only once, which stays
        # inside the image only while the padding is smaller than the image.
        assert padding < input_height and padding < input_width
        # Raw-pointer indexing assumes a dense NCHW layout; permuted or sliced
        # inputs must be compacted first (a no-op for contiguous tensors).
        input1 = input1.contiguous()
        input2 = input2.contiguous()
        output_height = input_height
        output_width = input_width
        output = input1.new(batch_size, input_channels,
                            kernel_size ** 2, output_height * output_width)
        # One loop index per (n, c, h, w); each writes kernel_size ** 2 values.
        n = output.numel() // output.shape[2]
        with torch.cuda.device_of(input1):
            f = load_kernel('dotproduction2_refpad_forward_kernel',
                            _dotproduction2_refpad_forward_kernel,
                            Dtype=Dtype(input1),
                            nthreads=n,
                            num=batch_size, input_channels=input_channels,
                            bottom_height=input_height, bottom_width=input_width,
                            top_height=output_height, top_width=output_width,
                            kernel=kernel_size,
                            pad=padding)
            f(block=(CUDA_NUM_THREADS, 1, 1),
              grid=(GET_BLOCKS(n), 1, 1),
              args=[input1.data_ptr(), input2.data_ptr(), output.data_ptr()],
              stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        # Save the contiguous versions: backward hands their pointers to the
        # kernels as well.
        ctx.save_for_backward(input1, input2)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for ``input1`` and ``input2``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: CUDA gradient of the forward output.

        Returns:
            ``(grad_input1, grad_input2, None, None)`` -- one entry per
            argument of ``forward``; a gradient is ``None`` when autograd did
            not request it, and ``kernel_size``/``padding`` never get one.
        """
        kernel_size, padding = ctx.kernel_size, ctx.padding
        input1, input2 = ctx.saved_tensors
        assert grad_output.is_cuda
        if not grad_output.is_contiguous():
            grad_output = grad_output.contiguous()
        batch_size, input_channels, input_height, input_width = input1.size()
        output_height = input_height
        output_width = input_width
        grad_input1, grad_input2 = None, None
        opt = dict(Dtype=Dtype(grad_output),
                   num=batch_size, input_channels=input_channels,
                   bottom_height=input_height, bottom_width=input_width,
                   top_height=output_height, top_width=output_width,
                   kernel=kernel_size,
                   pad=padding)
        with torch.cuda.device_of(input1):
            if ctx.needs_input_grad[0]:
                grad_input1 = input1.new(input1.size())
                n = grad_input1.numel()
                opt['nthreads'] = n
                f = load_kernel('dotproduction2_refpad_input1_backward_kernel',
                                _dotproduction2_refpad_input1_backward_kernel, **opt)
                f(block=(CUDA_NUM_THREADS, 1, 1),
                  grid=(GET_BLOCKS(n), 1, 1),
                  args=[grad_output.data_ptr(), input2.data_ptr(), grad_input1.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))
        with torch.cuda.device_of(input2):
            if ctx.needs_input_grad[1]:
                # The kernel produces the gradient on the padded grid
                # (H + 2 * padding, W + 2 * padding).
                grad_input2 = input2.new(
                    batch_size, input_channels,
                    input_height + 2 * padding, input_width + 2 * padding)
                n = grad_input2.numel()
                opt['nthreads'] = n
                f = load_kernel('dotproduction2_refpad_input2_backward_kernel',
                                _dotproduction2_refpad_input2_backward_kernel, **opt)
                f(block=(CUDA_NUM_THREADS, 1, 1),
                  grid=(GET_BLOCKS(n), 1, 1),
                  args=[grad_output.data_ptr(), input1.data_ptr(), grad_input2.data_ptr()],
                  stream=Stream(ptr=torch.cuda.current_stream().cuda_stream))

                # Fold the border gradients back onto the interior pixels they
                # mirror. Padded row i < padding mirrors padded row
                # 2 * padding - i, hence the flip; likewise for the bottom
                # border. Rows are folded first so that the corner blocks are
                # carried along when the columns are folded.
                grad_input2[:, :, padding + 1:2 * padding + 1, :] += torch.flip(
                    grad_input2[:, :, :padding, :], dims=[2])
                grad_input2[:, :, input_height - 1:input_height + padding - 1, :] += torch.flip(
                    grad_input2[:, :, input_height + padding:, :], dims=[2])
                grad_input2[:, :, :, padding + 1:2 * padding + 1] += torch.flip(
                    grad_input2[:, :, :, :padding], dims=[3])
                grad_input2[:, :, :, input_width - 1:input_width + padding - 1] += torch.flip(
                    grad_input2[:, :, :, input_width + padding:], dims=[3])

                # Drop the padding border to get back to (N, C, H, W).
                grad_input2 = grad_input2[:, :,
                                          padding:padding + input_height,
                                          padding:padding + input_width]

        # forward() takes four arguments after ctx, so return four gradients.
        return grad_input1, grad_input2, None, None


def dotproduction2_refpad(input1, input2, kernel_size=3, padding=0):
    """Apply ``Dotproduction2Refpad`` to a pair of CUDA tensors.

    Args:
        input1: 4-D tensor; must live on a CUDA device.
        input2: second operand, forwarded unchanged to the autograd function.
        kernel_size: neighbourhood size forwarded to the autograd function.
        padding: padding forwarded to the autograd function. Note that
            ``Dotproduction2Refpad.forward`` asserts
            ``padding == (kernel_size - 1) / 2``, which the default pair
            (3, 0) does not satisfy, so callers need to pass matching values.

    Returns:
        Whatever ``Dotproduction2Refpad.apply`` returns.

    Raises:
        AssertionError: if ``input1`` is not 4-dimensional.
        NotImplementedError: if ``input1`` is not a CUDA tensor.
    """
    assert input1.dim() == 4
    # Only a GPU implementation exists; bail out early for anything else.
    if not input1.is_cuda:
        raise NotImplementedError
    return Dotproduction2Refpad.apply(input1, input2, kernel_size, padding)


def test_dotproduction2_refpad():
    """Check the CUDA op against a pure-PyTorch reference on random inputs.

    The reference multiplies every pixel of ``x1`` with the 7x7 neighbourhood
    of the reflection-padded ``x2`` (built from ``Unfold`` and
    ``ReflectionPad2d``). Outputs and both input gradients must agree within
    1e-9, and ``torch.autograd.gradcheck`` must accept the op. Requires a GPU.
    """
    import os
    from functools import partial

    os.environ["CUDA_VISIBLE_DEVICES"] = '0'

    kernel_size, padding = 7, 3
    n, c, in_height, in_width = 2, 8, 9, 9
    out_height, out_width = 9, 9
    shape = (n, c, in_height, in_width)
    positions = out_height * out_width
    tol = 1e-9

    x1 = torch.randn(*shape, requires_grad=True).double().cuda()
    x2 = torch.randn(*shape, requires_grad=True).double().cuda()

    # Output of the op under test.
    y1 = dotproduction2_refpad(x1, x2, kernel_size=kernel_size, padding=padding)

    # Reference output assembled from stock PyTorch modules.
    take_center = torch.nn.Unfold(kernel_size=1, padding=0)
    take_window = torch.nn.Unfold(kernel_size=kernel_size, padding=0)
    reflect = torch.nn.ReflectionPad2d(padding)
    centers = take_center(x1).view(n, c, 1, positions)
    windows = take_window(reflect(x2)).view(n, c, kernel_size ** 2, positions)
    y2 = centers * windows
    assert (y1 - y2).abs().max() < tol

    def grad_of_mean(out, wrt):
        # Gradient of the scalar mean of `out`; the graph is kept for reuse.
        return torch.autograd.grad(out.mean(), wrt, retain_graph=True)[0]

    gx11 = grad_of_mean(y1, x1)
    gx12 = grad_of_mean(y1, x2)
    gx21 = grad_of_mean(y2, x1)
    gx22 = grad_of_mean(y2, x2)
    assert (gx11 - gx21).abs().max() < tol
    assert (gx12 - gx22).abs().max() < tol

    op = partial(dotproduction2_refpad, kernel_size=kernel_size, padding=padding)
    assert torch.autograd.gradcheck(op, (x1, x2))
    print('test case passed')

# Run the GPU self-test when this file is executed as the main program.
# NOTE(review): the module uses a relative import (``from .utils import ...``),
# so this presumably has to be launched as part of its package
# (``python -m <package>.<module>``) rather than as a bare script -- confirm.
if __name__ == '__main__':
    test_dotproduction2_refpad()
