#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/ops/matmul.h>

//#include <cuda.h>
//#include <cuda_runtime.h>

#include <iostream>
#include <limits>
#include <vector>

// Block size (threads per block) passed to the kernel launches below; the
// grid size is derived from it by ceiling division over the element count.
const int MAX_THREADS_PER_BLOCK = 512;



/**
 * Forward kernel: gather the selected input channels and scale them.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per output element;
 * threads whose global index is >= n_threads do nothing. No shared memory.
 *
 * For the output element (n, c, h, w), decoded from the flat thread index over
 * a dense (N, out_channels, height, width) layout, this computes
 *
 *   top_data[n, c, h, w] = topk_score[n * in_channels + c]
 *                        * bottom_data[n, indices[n * in_channels + c], h, w]
 *
 * where bottom_data is addressed as a dense (N, in_channels, height, width)
 * buffer.
 *
 * @param n_threads    number of output elements (N * out_channels * height * width)
 * @param out_channels number of channels in top_data
 * @param in_channels  number of channels in bottom_data; also the per-sample
 *                     stride of topk_score and indices
 * @param height,width spatial extent shared by input and output
 * @param bottom_data  input feature map
 * @param topk_score   per-(n, c) scale factor
 * @param indices      per-(n, c) source channel, stored as Dtype and truncated
 *                     to an integer here; values must lie in [0, in_channels)
 *                     (not checked on the device)
 * @param top_data     output buffer, n_threads elements
 */
template <typename Dtype>
__global__ void facs_fbs_Forward_cuda_kernel(const int n_threads,
                                   int out_channels, int in_channels,
                                   int height, int width,
                                   const Dtype* __restrict__ bottom_data,
                                   const Dtype* __restrict__ topk_score,
                                   const Dtype* __restrict__ indices,
                                   Dtype* __restrict__ top_data)
{
    // Widen before multiplying so the global index cannot wrap to a negative
    // int and slip past the bounds check in the last blocks of a large grid.
    const long long global_idx =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (global_idx >= n_threads)
    {
        return;
    }

    // Past the guard the index fits in int (n_threads is an int), so the
    // decomposition can use cheaper 32-bit division.
    const int idx = static_cast<int>(global_idx);
    const int plane = height * width;
    const int n = idx / (out_channels * plane);
    const int c = (idx / plane) % out_channels;
    const int h = (idx / width) % height;
    const int w = idx % width;

    const long long slot = static_cast<long long>(n) * in_channels + c;
    const Dtype topk_score_val = topk_score[slot];

    // Convert the stored channel index to an integer *before* using it in
    // address arithmetic: doing the arithmetic in Dtype rounds the offset once
    // it exceeds the mantissa (2^24 for float) and reads the wrong element.
    const long long src_c = static_cast<long long>(indices[slot]);

    const long long bottom_index =
        ((static_cast<long long>(n) * in_channels + src_c) * height + h) * width + w;

    top_data[idx] = topk_score_val * bottom_data[bottom_index];
}


/**
 * Backward kernel for the gather-and-scale forward pass.
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per element of top_diff;
 * threads whose global index is >= n_threads do nothing. No shared memory.
 *
 * For the element (n, c, h, w) of top_diff, decoded from the flat thread index
 * over a dense (N, out_channels, height, width) layout, and with
 * src = indices[n * in_channels + c], this writes
 *
 *   bottom_diff[n, src, h, w]   = topk_score[n * in_channels + c] * top_diff[n, c, h, w]
 *   topk_score_diff[n, c, h, w] = bottom_data[n, src, h, w]       * top_diff[n, c, h, w]
 *
 * Both destination buffers are addressed as dense
 * (N, in_channels, height, width) arrays. Elements not touched by any thread
 * are left as the caller initialised them.
 *
 * NOTE(review): bottom_diff is written with a plain store, so two output
 * channels of the same sample that select the same source channel would race.
 * Presumably the indices come from a top-k and are unique per sample — verify
 * against the caller.
 *
 * @param n_threads       number of elements in top_diff
 * @param out_channels    number of channels in top_diff
 * @param in_channels     number of channels in bottom_data / bottom_diff; also
 *                        the per-sample stride of topk_score and indices
 * @param height,width    spatial extent
 * @param bottom_data     forward input
 * @param topk_score      per-(n, c) scale factor used in the forward pass
 * @param indices         per-(n, c) source channel, stored as Dtype and
 *                        truncated to an integer here; must lie in
 *                        [0, in_channels) (not checked on the device)
 * @param top_diff        gradient w.r.t. the forward output
 * @param bottom_diff     gradient w.r.t. bottom_data (output)
 * @param topk_score_diff per-pixel contribution to the topk_score gradient
 *                        (output); the caller reduces it over h and w
 */
template <typename Dtype>
__global__ void facs_fbs_Backward_cuda_kernel(const int n_threads,
                                    int out_channels, int in_channels,
                                    int height, int width,
                                    const Dtype* __restrict__ bottom_data,
                                    const Dtype* __restrict__ topk_score,
                                    const Dtype* __restrict__ indices,
                                    const Dtype* __restrict__ top_diff,
                                    Dtype* __restrict__ bottom_diff,
                                    Dtype* __restrict__ topk_score_diff
                                    )
{
    // Widen before multiplying so the global index cannot wrap to a negative
    // int and slip past the bounds check in the last blocks of a large grid.
    const long long global_idx =
        static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (global_idx >= n_threads)
    {
        return;
    }

    // Past the guard the index fits in int (n_threads is an int).
    const int idx = static_cast<int>(global_idx);
    const int plane = height * width;
    const int n = idx / (out_channels * plane);
    const int c = (idx / plane) % out_channels;
    const int h = (idx / width) % height;
    const int w = idx % width;

    const long long slot = static_cast<long long>(n) * in_channels + c;
    const Dtype topk_score_val = topk_score[slot];

    // Convert the stored channel index to an integer *before* using it in
    // address arithmetic: doing the arithmetic in Dtype rounds the offset once
    // it exceeds the mantissa (2^24 for float) and hits the wrong element.
    const long long src_c = static_cast<long long>(indices[slot]);

    const long long bottom_index =
        ((static_cast<long long>(n) * in_channels + src_c) * height + h) * width + w;
    const long long score_index = (slot * height + h) * width + w;

    const Dtype grad_out = top_diff[idx];
    const Dtype bottom_val = bottom_data[bottom_index];

    bottom_diff[bottom_index] = topk_score_val * grad_out;
    topk_score_diff[score_index] = bottom_val * grad_out;
}




/**
 * Host launcher for facs_fbs_Forward_cuda_kernel.
 *
 * Produces output[n, c, h, w] = topk_score[n, c] * input[n, indices[n, c], h, w]
 * for c in [0, n_topk), on the current CUDA stream.
 *
 * @param n_topk     number of output channels; must satisfy 0 <= n_topk <= C_in
 * @param input      4-D CUDA tensor (N, C_in, H, W), float or double
 * @param topk_score CUDA tensor with N * C_in elements and the dtype of input
 * @param indices    CUDA tensor with N * C_in elements holding source channel
 *                   ids; any dtype is accepted and converted to input's dtype,
 *                   which is what the kernel reads
 * @return one-element vector holding the (N, n_topk, H, W) output tensor
 */
std::vector<torch::Tensor> facs_fbs_cuda_forward(
    const int n_topk,
    torch::Tensor input,
    torch::Tensor topk_score,
    torch::Tensor indices)
{
  TORCH_CHECK(input.is_cuda() && topk_score.is_cuda() && indices.is_cuda(),
              "facs_fbs_cuda_forward: all tensors must be CUDA tensors");
  TORCH_CHECK(input.dim() == 4,
              "facs_fbs_cuda_forward: input must be 4-D (N, C, H, W), got ",
              input.dim(), " dimensions");

  const int64_t n = input.size(0);
  const int64_t in_c = input.size(1);
  const int64_t h = input.size(2);
  const int64_t w = input.size(3);
  const int64_t out_c = n_topk;

  // The kernel reads topk_score/indices at n * in_c + c with c < out_c, so
  // out_c > in_c would read another sample's entries or run off the buffer.
  TORCH_CHECK(out_c >= 0 && out_c <= in_c,
              "facs_fbs_cuda_forward: n_topk must be in [0, ", in_c, "], got ", out_c);
  TORCH_CHECK(topk_score.numel() == n * in_c && indices.numel() == n * in_c,
              "facs_fbs_cuda_forward: topk_score and indices must each hold N * C_in = ",
              n * in_c, " elements");
  TORCH_CHECK(topk_score.scalar_type() == input.scalar_type(),
              "facs_fbs_cuda_forward: topk_score must have the same dtype as input");

  // The kernel assumes dense NCHW storage and reads indices as scalar_t.
  input = input.contiguous();
  topk_score = topk_score.contiguous();
  indices = indices.to(input.scalar_type()).contiguous();

  torch::Tensor output = torch::zeros({n, out_c, h, w}, input.options());

  const int64_t total = output.numel();
  TORCH_CHECK(total <= std::numeric_limits<int>::max() &&
                  in_c <= std::numeric_limits<int>::max(),
              "facs_fbs_cuda_forward: problem size exceeds the kernel's 32-bit thread index");

  if (total > 0)
  {
    const int n_threads = static_cast<int>(total);
    const dim3 blocks((n_threads - 1) / MAX_THREADS_PER_BLOCK + 1, 1, 1);

    c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "facsfbs_cuda_forward", ([&] {
      facs_fbs_Forward_cuda_kernel<scalar_t><<<blocks, MAX_THREADS_PER_BLOCK, 0, stream>>>(
          n_threads,
          static_cast<int>(out_c), static_cast<int>(in_c),
          static_cast<int>(h), static_cast<int>(w),
          input.data_ptr<scalar_t>(),
          topk_score.data_ptr<scalar_t>(),
          indices.data_ptr<scalar_t>(),
          // Results go into `output`; the previous code passed the indices
          // buffer here, leaving `output` all zeros and clobbering `indices`.
          output.data_ptr<scalar_t>());
      // Surface bad launch configurations immediately instead of at some
      // later, unrelated synchronising call.
      AT_CUDA_CHECK(cudaGetLastError());
    }));
  }

  return {output};
}

/**
 * Host launcher for facs_fbs_Backward_cuda_kernel.
 *
 * Runs the backward kernel on the current CUDA stream, then reduces the
 * per-pixel score gradient over the spatial dimensions.
 *
 * @param n_topk      number of channels in output_grad; 0 <= n_topk <= C_in
 * @param input       4-D CUDA tensor (N, C_in, H, W) given to the forward pass
 * @param topk_score  CUDA tensor with N * C_in elements and the dtype of input
 * @param indices     CUDA tensor with N * C_in elements holding source channel
 *                    ids; converted to input's dtype, which is what the kernel
 *                    reads
 * @param output_grad CUDA tensor with N * n_topk * H * W elements and the
 *                    dtype of input: gradient w.r.t. the forward output
 * @return {input_grad, topk_score_grad} where input_grad has the shape of
 *         input and topk_score_grad has shape (N, C_in, 1, 1); entries not
 *         reached by the kernel are zero
 */
std::vector<torch::Tensor> facs_fbs_cuda_backward(
    const int n_topk,
    torch::Tensor input,
    torch::Tensor topk_score,
    torch::Tensor indices,
    torch::Tensor output_grad)
{
  TORCH_CHECK(input.is_cuda() && topk_score.is_cuda() &&
                  indices.is_cuda() && output_grad.is_cuda(),
              "facs_fbs_cuda_backward: all tensors must be CUDA tensors");
  TORCH_CHECK(input.dim() == 4,
              "facs_fbs_cuda_backward: input must be 4-D (N, C, H, W), got ",
              input.dim(), " dimensions");

  const int64_t n = input.size(0);
  const int64_t in_c = input.size(1);
  const int64_t h = input.size(2);
  const int64_t w = input.size(3);
  const int64_t out_c = n_topk;

  // The kernel reads topk_score/indices at n * in_c + c with c < out_c, so
  // out_c > in_c would read another sample's entries or run off the buffer.
  TORCH_CHECK(out_c >= 0 && out_c <= in_c,
              "facs_fbs_cuda_backward: n_topk must be in [0, ", in_c, "], got ", out_c);
  TORCH_CHECK(topk_score.numel() == n * in_c && indices.numel() == n * in_c,
              "facs_fbs_cuda_backward: topk_score and indices must each hold N * C_in = ",
              n * in_c, " elements");

  const int64_t total = n * out_c * h * w;
  TORCH_CHECK(output_grad.numel() == total,
              "facs_fbs_cuda_backward: output_grad must hold N * n_topk * H * W = ",
              total, " elements, got ", output_grad.numel());
  TORCH_CHECK(topk_score.scalar_type() == input.scalar_type() &&
                  output_grad.scalar_type() == input.scalar_type(),
              "facs_fbs_cuda_backward: topk_score and output_grad must have the same dtype as input");
  TORCH_CHECK(total <= std::numeric_limits<int>::max() &&
                  in_c <= std::numeric_limits<int>::max(),
              "facs_fbs_cuda_backward: problem size exceeds the kernel's 32-bit thread index");

  // The kernel assumes dense NCHW storage everywhere. Making `input`
  // contiguous first also guarantees the zeros_like buffers below are dense
  // (zeros_like would otherwise inherit a non-contiguous input's strides).
  input = input.contiguous();
  topk_score = topk_score.contiguous();
  indices = indices.to(input.scalar_type()).contiguous();
  output_grad = output_grad.contiguous();

  torch::Tensor input_grad = torch::zeros_like(input);
  torch::Tensor topk_score_mid_grad = torch::zeros_like(input);

  if (total > 0)
  {
    const int n_threads = static_cast<int>(total);
    const dim3 blocks((n_threads - 1) / MAX_THREADS_PER_BLOCK + 1, 1, 1);

    c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream();

    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "facsfbs_cuda_backward", ([&] {
      facs_fbs_Backward_cuda_kernel<scalar_t><<<blocks, MAX_THREADS_PER_BLOCK, 0, stream>>>(
          n_threads,
          static_cast<int>(out_c), static_cast<int>(in_c),
          static_cast<int>(h), static_cast<int>(w),
          input.data_ptr<scalar_t>(),
          topk_score.data_ptr<scalar_t>(),
          indices.data_ptr<scalar_t>(),
          output_grad.data_ptr<scalar_t>(),
          input_grad.data_ptr<scalar_t>(),
          topk_score_mid_grad.data_ptr<scalar_t>());
      // Surface bad launch configurations immediately instead of at some
      // later, unrelated synchronising call.
      AT_CUDA_CHECK(cudaGetLastError());
    }));
  }

  // Sum the per-pixel contributions over H * W by multiplying with a column
  // of ones: (N * C_in, H * W) x (H * W, 1) -> (N * C_in, 1).
  torch::Tensor ones = torch::ones({h * w, 1}, input.options());

  topk_score_mid_grad = topk_score_mid_grad.reshape({n * in_c, h * w});

  torch::Tensor topk_score_grad = at::matmul(topk_score_mid_grad, ones);
  topk_score_grad = topk_score_grad.reshape({n, in_c, 1, 1});

  return {input_grad, topk_score_grad};
}
