#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/ops/matmul.h>
#include <c10/cuda/CUDAGuard.h>

//#include <cuda.h>
//#include <cuda_runtime.h>

#include <iostream>
#include <limits>
#include <vector>

// Number of threads per block used for every kernel launch in this file.
constexpr int MAX_THREADS_PER_BLOCK = 512;


/**
 * Forward kernel: one thread per output element.
 *
 * The flat thread index is treated as a row-major (n, c, h, w) position in an
 * output with in_channels * n_modules channels and height x width planes.
 * Output channel c is filled from input channel (c % in_channels) of the same
 * batch entry and pixel, i.e. the input channels are repeated n_modules times
 * along the channel axis.
 *
 * Launch layout: 1-D grid of 1-D blocks covering at least n_threads threads;
 * threads whose index is >= n_threads do nothing. No shared memory is used.
 *
 * @param n_threads   number of output elements to write
 * @param in_channels channel count of `in`
 * @param height      plane height
 * @param width       plane width
 * @param n_modules   number of times the input channels are repeated
 * @param in          source buffer, indexed as dense (n, in_channels, h, w)
 * @param out         destination buffer, indexed as dense
 *                    (n, in_channels * n_modules, h, w)
 */
template <typename Dtype>
__global__ void CMFeeder_Forward_cuda_kernel(const int n_threads,
                                             const int in_channels,
                                             const int height, const int width,
                                             const int n_modules,
                                             const Dtype* __restrict__ in,
                                             Dtype* __restrict__ out)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n_threads)
    {
        return;  // padding thread in the last block
    }

    const int plane = height * width;
    const int out_channels = in_channels * n_modules;

    // Split the flat output index into (batch, output channel, pixel-in-plane).
    const int pixel = idx % plane;
    const int out_c = (idx / plane) % out_channels;
    const int batch = idx / (out_channels * plane);

    // Output channels wrap around the input channels.
    const int src_c = out_c % in_channels;

    out[idx] = in[(batch * in_channels + src_c) * plane + pixel];
}

/**
 * Backward kernel: one thread per input-gradient element.
 *
 * The flat thread index is treated as a row-major (n, c, h, w) position in a
 * gradient buffer with in_channels channels. Each thread sums, in increasing
 * module order, the n_modules entries of `top_diff` located at channels
 * c, c + in_channels, c + 2 * in_channels, ... of the same batch entry and
 * pixel, and stores the sum.
 *
 * Launch layout: 1-D grid of 1-D blocks covering at least n_threads threads;
 * threads whose index is >= n_threads do nothing. No shared memory is used.
 *
 * @param n_threads   number of input-gradient elements to write
 * @param in_channels channel count of `bottom_diff`
 * @param height      plane height
 * @param width       plane width
 * @param n_modules   number of channel groups in `top_diff`
 * @param top_diff    upstream gradient, indexed as dense
 *                    (n, in_channels * n_modules, h, w)
 * @param bottom_diff result buffer, indexed as dense (n, in_channels, h, w)
 */
template <typename Dtype>
__global__ void CMFeeder_Backward_cuda_kernel(const int n_threads,
                                              const int in_channels,
                                              const int height, const int width,
                                              const int n_modules,
                                              const Dtype* __restrict__ top_diff,
                                              Dtype* __restrict__ bottom_diff)
{
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= n_threads)
    {
        return;  // padding thread in the last block
    }

    const int plane = height * width;

    // Split the flat input-gradient index into (batch, channel, pixel-in-plane).
    const int pixel = idx % plane;
    const int chan = (idx / plane) % in_channels;
    const int batch = idx / (in_channels * plane);

    // Distance in top_diff between the same channel of consecutive modules.
    const int module_stride = in_channels * plane;
    // Position of (batch, module 0, chan, pixel) inside top_diff.
    const int first = (batch * n_modules * in_channels + chan) * plane + pixel;

    Dtype acc = 0;
    for (int m = 0; m < n_modules; ++m)
    {
        acc += top_diff[first + m * module_stride];
    }

    bottom_diff[idx] = acc;
}



/**
 * Host launcher for CMFeeder_Forward_cuda_kernel.
 *
 * Builds an output tensor of shape (N, C * n_modules, H, W) in which the C
 * input channels are repeated n_modules times along the channel axis.
 *
 * @param n_modules number of channel repetitions; must be positive
 * @param input     4-D CUDA tensor (N, C, H, W) of a floating dtype handled by
 *                  AT_DISPATCH_FLOATING_TYPES; any strides are accepted
 * @return single-element vector holding the output tensor
 *
 * Raises a c10::Error (via TORCH_CHECK / AT_CUDA_CHECK) when the arguments are
 * invalid, the output is too large for the kernel's 32-bit indexing, or the
 * kernel launch fails.
 */
std::vector<torch::Tensor> cm_feeder_cuda_forward(
    const int n_modules,
    torch::Tensor input)
{
  TORCH_CHECK(input.is_cuda(),
              "cm_feeder_cuda_forward: input must be a CUDA tensor");
  TORCH_CHECK(input.dim() == 4,
              "cm_feeder_cuda_forward: expected a 4-D (N, C, H, W) input, got ",
              input.dim(), " dimension(s)");
  TORCH_CHECK(n_modules > 0,
              "cm_feeder_cuda_forward: n_modules must be positive, got ",
              n_modules);

  // Allocate and launch on the device that owns `input`, not whichever device
  // happens to be current.
  const c10::cuda::CUDAGuard device_guard(input.device());

  // The kernel indexes the buffer as dense NCHW.
  input = input.contiguous();

  const int64_t n = input.size(0);
  const int64_t in_c = input.size(1);
  const int64_t h = input.size(2);
  const int64_t w = input.size(3);
  const int64_t out_c = in_c * n_modules;

  // Every element is written by the kernel, so no zero-fill is needed.
  torch::Tensor output = torch::empty({n, out_c, h, w}, input.options());

  const int64_t total = output.numel();
  TORCH_CHECK(total <= std::numeric_limits<int>::max(),
              "cm_feeder_cuda_forward: output has ", total,
              " elements, which exceeds the kernel's 32-bit index range");
  if (total == 0)
  {
    return {output};  // nothing to copy; also avoids launching an idle kernel
  }

  const int n_threads = static_cast<int>(total);
  const dim3 blocks((n_threads - 1) / MAX_THREADS_PER_BLOCK + 1, 1, 1);

  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "cm_feeder_cuda_forward", ([&] {
    CMFeeder_Forward_cuda_kernel<scalar_t><<<blocks, MAX_THREADS_PER_BLOCK, 0, stream>>>(
        n_threads,
        static_cast<int>(in_c),
        static_cast<int>(h), static_cast<int>(w),
        n_modules,
        input.data_ptr<scalar_t>(),
        output.data_ptr<scalar_t>());
  }));

  // Surface launch-configuration errors immediately instead of at a later,
  // unrelated CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return {output};
}

/**
 * Host launcher for CMFeeder_Backward_cuda_kernel.
 *
 * Reduces an upstream gradient of shape (N, C * n_modules, H, W) to an input
 * gradient of shape (N, C, H, W) by summing, for every input channel, the
 * n_modules output channels that were copied from it.
 *
 * @param n_modules   number of channel repetitions used in the forward pass;
 *                    must be positive and divide the channel count of
 *                    `output_grad`
 * @param output_grad 4-D CUDA tensor (N, C * n_modules, H, W) of a floating
 *                    dtype handled by AT_DISPATCH_FLOATING_TYPES; any strides
 *                    are accepted
 * @return single-element vector holding the input gradient
 *
 * Raises a c10::Error (via TORCH_CHECK / AT_CUDA_CHECK) when the arguments are
 * invalid, the gradient is too large for the kernel's 32-bit indexing, or the
 * kernel launch fails.
 */
std::vector<torch::Tensor> cm_feeder_cuda_backward(
    const int n_modules,
    torch::Tensor output_grad)
{
  TORCH_CHECK(output_grad.is_cuda(),
              "cm_feeder_cuda_backward: output_grad must be a CUDA tensor");
  TORCH_CHECK(output_grad.dim() == 4,
              "cm_feeder_cuda_backward: expected a 4-D (N, C, H, W) gradient, got ",
              output_grad.dim(), " dimension(s)");
  // Checked before the division below to avoid a host-side divide by zero.
  TORCH_CHECK(n_modules > 0,
              "cm_feeder_cuda_backward: n_modules must be positive, got ",
              n_modules);

  const int64_t n = output_grad.size(0);
  const int64_t out_c = output_grad.size(1);
  const int64_t h = output_grad.size(2);
  const int64_t w = output_grad.size(3);

  // A remainder would mean trailing channels are silently dropped.
  TORCH_CHECK(out_c % n_modules == 0,
              "cm_feeder_cuda_backward: channel count ", out_c,
              " is not a multiple of n_modules ", n_modules);
  const int64_t in_c = out_c / n_modules;

  // The kernel computes offsets into output_grad with 32-bit ints.
  TORCH_CHECK(output_grad.numel() <= std::numeric_limits<int>::max(),
              "cm_feeder_cuda_backward: output_grad has ", output_grad.numel(),
              " elements, which exceeds the kernel's 32-bit index range");

  // Allocate and launch on the device that owns `output_grad`, not whichever
  // device happens to be current.
  const c10::cuda::CUDAGuard device_guard(output_grad.device());

  // The kernel indexes the buffer as dense NCHW.
  output_grad = output_grad.contiguous();

  // Every element is written by the kernel, so no zero-fill is needed.
  torch::Tensor input_grad = torch::empty({n, in_c, h, w}, output_grad.options());

  const int64_t total = input_grad.numel();
  if (total == 0)
  {
    return {input_grad};  // nothing to reduce; also avoids launching an idle kernel
  }

  const int n_threads = static_cast<int>(total);
  const dim3 blocks((n_threads - 1) / MAX_THREADS_PER_BLOCK + 1, 1, 1);

  c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES(output_grad.scalar_type(), "cm_feeder_cuda_backward", ([&] {
    CMFeeder_Backward_cuda_kernel<scalar_t><<<blocks, MAX_THREADS_PER_BLOCK, 0, stream>>>(
        n_threads,
        static_cast<int>(in_c),
        static_cast<int>(h), static_cast<int>(w),
        n_modules,
        output_grad.data_ptr<scalar_t>(),
        input_grad.data_ptr<scalar_t>());
  }));

  // Surface launch-configuration errors immediately instead of at a later,
  // unrelated CUDA call.
  AT_CUDA_CHECK(cudaGetLastError());

  return {input_grad};
}