#include <THC/THC.h>
#include<torch/extension.h>
#include <stdbool.h>
#include <stdio.h>

// NOTE(review): `real` is not referenced anywhere else in this file; it looks
// like a leftover from TH's generic-type macros. Confirm nothing else relies on
// it before removing.
#define real float

// Forward-pass launcher. Only declared here; the definition lives outside this
// file (presumably the .cu kernel source). Per its parameter names it receives
// the output sizes in (oc, ow, oh, ob) order and the input sizes in
// (ic, ih, iw, ib) order, then, for each of inputImages / grids / output, a raw
// float data pointer followed by four strides, and finally the CUDA stream to
// run on. The wrappers in this file treat a 0 return as failure.
int BilinearSamplerBHWD_updateOutput_cuda_kernel(int oc, int ow, int oh, int ob, int ic, int ih, int iw, int ib,
                                                 float *inputImages, int isb, int isc, int ish, int isw,
                                                 float *grids, int gsb, int gsc, int gsh, int gsw,
                                                 float *output, int osb, int osc, int osh, int osw,
                                                 cudaStream_t stream);

// Backward-pass launcher, likewise defined outside this file. Takes the
// gradOutput sizes (goc, gow, goh, gob) and the input sizes (ic, ih, iw, ib),
// then a raw float data pointer plus four strides for each of inputImages,
// grids, gradInputImages, gradGrids and gradOutput, and the CUDA stream.
// The wrappers in this file treat a 0 return as failure.
int BilinearSamplerBHWD_updateGradInput_cuda_kernel(int goc, int gow, int goh, int gob, int ic, int ih, int iw, int ib, 
                                                    float *inputImages, int isb, int isc, int ish, int isw,
                                                    float *grids, int gsb, int gsc, int gsh, int gsw,
                                                    float *gradInputImages, int gisb, int gisc, int gish, int gisw,
                                                    float *gradGrids, int ggsb, int ggsc, int ggsh, int ggsw,
                                                    float *gradOutput, int gosb, int gosc, int gosh, int gosw,
                                                    cudaStream_t stream);

// Global THC state handle passed to every THCudaTensor_* / THCState_* call in
// this file. It is only declared here; the definition is resolved
// automatically from the PyTorch libraries this extension links against.
extern THCState *state;

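// Bilinear sampling is done in BHWD layout (coalescing is not obvious in BDHW).
// We assume BHWD format in inputImages.
// NOTE(review): the wrappers below nevertheless pass dim 1 of inputImages and of
// the output tensor as the channel count (`ic` / `oc`) and dims 2 / 3 as
// height / width, i.e. those tensors are indexed as (B, C, H, W). Confirm which
// layout callers actually use.
// We assume BHW(YX) format on grids: (batch, height, width) with the two
// sampling coordinates, Y then X, in the last dimension.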

// Forward pass of the bilinear sampler: fills `output` with values of
// `inputImages` sampled at the coordinates stored in `grids`.
//
// Tensor dimension -> launcher parameter mapping used below (names taken from
// the BilinearSamplerBHWD_updateOutput_cuda_kernel prototype):
//   output      : size 0 -> ob, 1 -> oc, 2 -> oh, 3 -> ow
//   inputImages : size 0 -> ib, 1 -> ic, 2 -> ih, 3 -> iw
//   grids       : stride 0 -> gsb, 1 -> gsh, 2 -> gsw, 3 -> gsc
// So grids is addressed as (batch, h, w, coordinate). The launcher receives no
// grid sizes, so it can only walk the grid over the output's (batch, h, w)
// extent. A smaller grid would be read out of bounds on the device, which is
// why the shapes are validated here before launching.
//
// Returns 1 on success. Raises via THError on inconsistent shapes, or when the
// launcher reports failure (a zero return).
int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output){
  // Size of dimension `d` of `t`, widened so it can be printed with %lld.
  auto dim = [](THCudaTensor *t, int d) -> long long {
    return (long long)THCudaTensor_size(state, t, d);
  };

  // The grid must provide (at least) a coordinate pair for every output
  // (batch, h, w) cell.
  if (dim(grids, 0) != dim(output, 0) ||
      dim(grids, 1) != dim(output, 2) ||
      dim(grids, 2) != dim(output, 3) ||
      dim(grids, 3) < 2) {
    THError("BilinearSamplerBHWD_updateOutput_cuda: grids of shape (%lld, %lld, %lld, %lld) "
            "do not match output of shape (%lld, %lld, %lld, %lld); expected (%lld, %lld, %lld, 2)",
            dim(grids, 0), dim(grids, 1), dim(grids, 2), dim(grids, 3),
            dim(output, 0), dim(output, 1), dim(output, 2), dim(output, 3),
            dim(output, 0), dim(output, 2), dim(output, 3));
  }
  // Each output channel is presumably sampled from the matching input channel,
  // so the channel counts handed to the launcher (ic / oc) must agree.
  if (dim(inputImages, 1) != dim(output, 1)) {
    THError("BilinearSamplerBHWD_updateOutput_cuda: inputImages has %lld channels but output has %lld",
            dim(inputImages, 1), dim(output, 1));
  }

  const int success = BilinearSamplerBHWD_updateOutput_cuda_kernel(
      THCudaTensor_size(state, output, 1),
      THCudaTensor_size(state, output, 3),
      THCudaTensor_size(state, output, 2),
      THCudaTensor_size(state, output, 0),
      THCudaTensor_size(state, inputImages, 1),
      THCudaTensor_size(state, inputImages, 2),
      THCudaTensor_size(state, inputImages, 3),
      THCudaTensor_size(state, inputImages, 0),
      THCudaTensor_data(state, inputImages),
      THCudaTensor_stride(state, inputImages, 0),
      THCudaTensor_stride(state, inputImages, 1),
      THCudaTensor_stride(state, inputImages, 2),
      THCudaTensor_stride(state, inputImages, 3),
      THCudaTensor_data(state, grids),
      THCudaTensor_stride(state, grids, 0),
      THCudaTensor_stride(state, grids, 3),
      THCudaTensor_stride(state, grids, 1),
      THCudaTensor_stride(state, grids, 2),
      THCudaTensor_data(state, output),
      THCudaTensor_stride(state, output, 0),
      THCudaTensor_stride(state, output, 1),
      THCudaTensor_stride(state, output, 2),
      THCudaTensor_stride(state, output, 3),
      THCState_getCurrentStream(state));

  // The launcher signals failure with a zero return.
  if (!success) {
    THError("aborting");
  }
  return 1;
}

// Backward pass of the bilinear sampler: given `gradOutput`, writes gradients
// with respect to the sampled images into `gradInputImages` and with respect
// to the sampling coordinates into `gradGrids`.
//
// Tensor dimension -> launcher parameter mapping used below (names taken from
// the BilinearSamplerBHWD_updateGradInput_cuda_kernel prototype):
//   gradOutput  : size 0 -> gob, 1 -> goc, 2 -> goh, 3 -> gow
//   inputImages : size 0 -> ib,  1 -> ic,  2 -> ih,  3 -> iw
//   grids / gradGrids : strides passed in (batch, coordinate, h, w) parameter
//                       order, i.e. dims 0-2 are (batch, h, w), dim 3 the pair
// Only one set of input sizes and one set of grid extents (those of gradOutput)
// reach the launcher, while gradInputImages and gradGrids are written through
// their own pointers. Their shapes are therefore validated here so the kernel
// cannot read or write past the end of a smaller buffer.
//
// Returns 1 on success. Raises via THError on inconsistent shapes, or when the
// launcher reports failure (a zero return).
int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages,
                                        THCudaTensor *gradGrids, THCudaTensor *gradOutput)
{
  // Size of dimension `d` of `t`, widened so it can be printed with %lld.
  auto dim = [](THCudaTensor *t, int d) -> long long {
    return (long long)THCudaTensor_size(state, t, d);
  };
  // True when `a` and `b` agree on all four dimensions.
  auto sameShape = [dim](THCudaTensor *a, THCudaTensor *b) -> bool {
    for (int d = 0; d < 4; ++d) {
      if (dim(a, d) != dim(b, d)) {
        return false;
      }
    }
    return true;
  };

  // The grid must provide (at least) a coordinate pair for every gradOutput
  // (batch, h, w) cell.
  if (dim(grids, 0) != dim(gradOutput, 0) ||
      dim(grids, 1) != dim(gradOutput, 2) ||
      dim(grids, 2) != dim(gradOutput, 3) ||
      dim(grids, 3) < 2) {
    THError("BilinearSamplerBHWD_updateGradInput_cuda: grids of shape (%lld, %lld, %lld, %lld) "
            "do not match gradOutput of shape (%lld, %lld, %lld, %lld); expected (%lld, %lld, %lld, 2)",
            dim(grids, 0), dim(grids, 1), dim(grids, 2), dim(grids, 3),
            dim(gradOutput, 0), dim(gradOutput, 1), dim(gradOutput, 2), dim(gradOutput, 3),
            dim(gradOutput, 0), dim(gradOutput, 2), dim(gradOutput, 3));
  }
  // Channel counts handed to the launcher (ic / goc) must agree.
  if (dim(inputImages, 1) != dim(gradOutput, 1)) {
    THError("BilinearSamplerBHWD_updateGradInput_cuda: inputImages has %lld channels but gradOutput has %lld",
            dim(inputImages, 1), dim(gradOutput, 1));
  }
  // gradInputImages is written using inputImages' sizes (ic, ih, iw, ib).
  if (!sameShape(gradInputImages, inputImages)) {
    THError("BilinearSamplerBHWD_updateGradInput_cuda: gradInputImages of shape (%lld, %lld, %lld, %lld) "
            "must match inputImages of shape (%lld, %lld, %lld, %lld)",
            dim(gradInputImages, 0), dim(gradInputImages, 1), dim(gradInputImages, 2), dim(gradInputImages, 3),
            dim(inputImages, 0), dim(inputImages, 1), dim(inputImages, 2), dim(inputImages, 3));
  }
  // gradGrids is written over the same (batch, h, w, coordinate) cells as grids.
  if (!sameShape(gradGrids, grids)) {
    THError("BilinearSamplerBHWD_updateGradInput_cuda: gradGrids of shape (%lld, %lld, %lld, %lld) "
            "must match grids of shape (%lld, %lld, %lld, %lld)",
            dim(gradGrids, 0), dim(gradGrids, 1), dim(gradGrids, 2), dim(gradGrids, 3),
            dim(grids, 0), dim(grids, 1), dim(grids, 2), dim(grids, 3));
  }

  const int success = BilinearSamplerBHWD_updateGradInput_cuda_kernel(
      THCudaTensor_size(state, gradOutput, 1),
      THCudaTensor_size(state, gradOutput, 3),
      THCudaTensor_size(state, gradOutput, 2),
      THCudaTensor_size(state, gradOutput, 0),
      THCudaTensor_size(state, inputImages, 1),
      THCudaTensor_size(state, inputImages, 2),
      THCudaTensor_size(state, inputImages, 3),
      THCudaTensor_size(state, inputImages, 0),
      THCudaTensor_data(state, inputImages),
      THCudaTensor_stride(state, inputImages, 0),
      THCudaTensor_stride(state, inputImages, 1),
      THCudaTensor_stride(state, inputImages, 2),
      THCudaTensor_stride(state, inputImages, 3),
      THCudaTensor_data(state, grids),
      THCudaTensor_stride(state, grids, 0),
      THCudaTensor_stride(state, grids, 3),
      THCudaTensor_stride(state, grids, 1),
      THCudaTensor_stride(state, grids, 2),
      THCudaTensor_data(state, gradInputImages),
      THCudaTensor_stride(state, gradInputImages, 0),
      THCudaTensor_stride(state, gradInputImages, 1),
      THCudaTensor_stride(state, gradInputImages, 2),
      THCudaTensor_stride(state, gradInputImages, 3),
      THCudaTensor_data(state, gradGrids),
      THCudaTensor_stride(state, gradGrids, 0),
      THCudaTensor_stride(state, gradGrids, 3),
      THCudaTensor_stride(state, gradGrids, 1),
      THCudaTensor_stride(state, gradGrids, 2),
      THCudaTensor_data(state, gradOutput),
      THCudaTensor_stride(state, gradOutput, 0),
      THCudaTensor_stride(state, gradOutput, 1),
      THCudaTensor_stride(state, gradOutput, 2),
      THCudaTensor_stride(state, gradOutput, 3),
      THCState_getCurrentStream(state));

  // The launcher signals failure with a zero return.
  if (!success) {
    THError("aborting");
  }
  return 1;
}

// Python entry points for the extension: `forward` binds the output wrapper and
// `backward` binds the gradient wrapper defined in this file.
// NOTE(review): both bound functions take raw THCudaTensor* parameters. Unless
// a pybind11 type caster for that type is registered somewhere, Python callers
// passing torch tensors will fail argument conversion. Confirm, and consider
// at::Tensor-taking shims.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, mod) {
  mod.def("forward",
          &BilinearSamplerBHWD_updateOutput_cuda,
          "roi crop forward (CUDA)")
     .def("backward",
          &BilinearSamplerBHWD_updateGradInput_cuda,
          "roi crop backward (CUDA)");
}