#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <iostream>
#include <vector>
#include <stdio.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include "butterfly_cuda.h"
using namespace nvcuda;

// *************** FOR ERROR CHECKING *******************
#ifndef CUDA_RT_CALL
// Evaluate a CUDA runtime call exactly once and, if it did not return
// cudaSuccess, print the call text, source line, file and error description to
// stderr. Failures are reported only: the macro neither aborts nor throws.
//
// The body is wrapped in do { ... } while (0) so `CUDA_RT_CALL(x);` is a single
// statement (safe inside an unbraced if/else). The local uses a macro-private
// name so call text mentioning e.g. `status` is not captured by the macro.
#define CUDA_RT_CALL( call )                                                          \
    do                                                                                \
    {                                                                                 \
        const cudaError_t cuda_rt_call_status_ = static_cast<cudaError_t>( call );    \
        if ( cuda_rt_call_status_ != cudaSuccess )                                    \
            fprintf( stderr,                                                          \
                     "ERROR: CUDA RT call \"%s\" in line %d of file %s failed "       \
                     "with "                                                          \
                     "%s (%d).\n",                                                    \
                     #call,                                                           \
                     __LINE__,                                                        \
                     __FILE__,                                                        \
                     cudaGetErrorString( cuda_rt_call_status_ ),                      \
                     static_cast<int>( cuda_rt_call_status_ ) );                      \
    } while ( 0 )
#endif  // CUDA_RT_CALL
// *************** FOR ERROR CHECKING *******************

// Guarded on the macro this block actually defines, so a definition that is
// already visible (e.g. from an included header) is not redefined here.
#ifndef CHECK_CUDA_ERROR
// Wrap a CUDA runtime call: report it, with its source text and call site,
// if it does not return cudaSuccess.
#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)

// Print a diagnostic to std::cerr when `err` is not cudaSuccess.
//   err  - status returned by a CUDA runtime call
//   func - stringified call expression (supplied by CHECK_CUDA_ERROR)
//   file - source file of the call site
//   line - source line of the call site
// Errors are reported but deliberately not fatal; execution continues.
template <typename T>
void check(T err, const char* const func, const char* const file,
           const int line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        // Intentionally non-fatal; re-enable to abort on the first error.
        // std::exit(EXIT_FAILURE);
    }
}
#endif  // CHECK_CUDA_ERROR

#ifndef CHECK_LAST_CUDA_ERROR
// Report (without aborting) the CUDA runtime's last recorded error, tagged with
// the call site. Used in this file right after kernel launches.
#define CHECK_LAST_CUDA_ERROR() checkLast(__FILE__, __LINE__)

// Fetch and clear the last runtime error via cudaGetLastError(); if it is not
// cudaSuccess, print the call site and the error description to std::cerr.
// `inline` so this header-style definition may appear in several translation
// units without causing multiple-definition link errors.
inline void checkLast(const char* const file, const int line)
{
    const cudaError_t err{cudaGetLastError()};
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << std::endl;
        // Intentionally non-fatal; re-enable to abort on the first error.
        // std::exit(EXIT_FAILURE);
    }
}
#endif  // CHECK_LAST_CUDA_ERROR


// Host wrapper for butterfly_copy_cuda_kernel.
//
// x and twiddle_factors must be 6-D CUDA tensors on the same device. All six
// strides of each are forwarded to the kernel, so strided inputs are passed
// through as-is. Returns a freshly allocated tensor of shape {B, H, m, m, l_m}
// with x's options, where B = x.size(0), H = x.size(1), m = x.size(3) and
// l_m = x.size(-1). The data pointers are reinterpreted as complex_half_t*;
// presumably the inputs hold complex-half values -- TODO confirm with the caller.
//
// Launch layout: grid (H, m), 1024 threads per block, on PyTorch's current
// CUDA stream for x's device. B is passed as a kernel argument, not as a grid
// dimension. Dynamic shared memory is (2*m*l_m + 2*l_m) half values.
// Invalid arguments raise via TORCH_CHECK. CUDA runtime errors are printed to
// stderr and not raised, consistent with the rest of this file.
torch::Tensor butterfly_copy_cuda(
  torch::Tensor x,
  torch::Tensor twiddle_factors
){
  TORCH_CHECK(x.is_cuda(), "butterfly_copy_cuda: x must be a CUDA tensor");
  TORCH_CHECK(twiddle_factors.is_cuda(),
              "butterfly_copy_cuda: twiddle_factors must be a CUDA tensor");
  TORCH_CHECK(twiddle_factors.device() == x.device(),
              "butterfly_copy_cuda: x and twiddle_factors must be on the same device");
  // The kernel receives exactly six strides for each input.
  TORCH_CHECK(x.dim() == 6,
              "butterfly_copy_cuda: x must be 6-D, got ", x.dim(), " dims");
  TORCH_CHECK(twiddle_factors.dim() == 6,
              "butterfly_copy_cuda: twiddle_factors must be 6-D, got ",
              twiddle_factors.dim(), " dims");

  // Make x's device current so the attribute call, the stream lookup and the
  // launch all target the device that owns the data.
  const c10::cuda::CUDAGuard device_guard(x.device());

  const uint B = x.size(0);
  const uint H = x.size(1);
  const uint m = x.size(3);
  const uint l_m = x.size(-1);

  torch::Tensor out = torch::empty({B, H, m, m, l_m}, x.options());

  // Nothing to write; also avoids an invalid zero-sized grid when H or m is 0.
  if (out.numel() == 0) {
    return out;
  }

  const dim3 grid(H, m);
  const dim3 block(1024, 1);

  // (2*m*l_m + 2*l_m) half values, computed in size_t.
  const size_t smem_bytes =
      (2 * static_cast<size_t>(m) * l_m + 2 * static_cast<size_t>(l_m)) * sizeof(half);

  // Opt in to exactly the dynamic shared memory this launch needs (required
  // above 48 KB). A fixed cap rejects larger launches and exceeds the limit of
  // GPUs with less shared memory.
  CUDA_RT_CALL(cudaFuncSetAttribute(&butterfly_copy_cuda_kernel,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    static_cast<int>(smem_bytes)));

  // Use PyTorch's current stream so the kernel is ordered with surrounding torch ops.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  butterfly_copy_cuda_kernel<<<grid, block, smem_bytes, stream>>>(
      static_cast<complex_half_t *>(x.data_ptr()),
      static_cast<complex_half_t *>(twiddle_factors.data_ptr()),
      static_cast<complex_half_t *>(out.data_ptr()),
      B,
      H,
      m,
      l_m,
      x.stride(0),
      x.stride(1),
      x.stride(2),
      x.stride(3),
      x.stride(4),
      x.stride(5),
      twiddle_factors.stride(0),
      twiddle_factors.stride(1),
      twiddle_factors.stride(2),
      twiddle_factors.stride(3),
      twiddle_factors.stride(4),
      twiddle_factors.stride(5),
      out.stride(0),
      out.stride(1),
      out.stride(2),
      out.stride(3),
      out.stride(4)
      );

  CHECK_LAST_CUDA_ERROR();
  return out;
}


// Host wrapper for butterfly_transcendental_cuda_kernel.
//
// x must be a 6-D CUDA tensor. All six of its strides are forwarded to the
// kernel, so strided inputs are passed through as-is. Returns a freshly
// allocated tensor of shape {B, H, m, m, l_m} with x's options, where
// B = x.size(0), H = x.size(1), m = x.size(3) and l_m = x.size(-1). The data
// pointers are reinterpreted as complex_half_t*; presumably x holds
// complex-half values -- TODO confirm with the caller.
//
// Launch layout: grid (H, m), 1024 threads per block, on PyTorch's current
// CUDA stream for x's device. B is passed as a kernel argument, not as a grid
// dimension. Dynamic shared memory is (2*m*l_m + 2*l_m) half values.
// Invalid arguments raise via TORCH_CHECK. CUDA runtime errors are printed to
// stderr and not raised, consistent with the rest of this file.
torch::Tensor butterfly_transcendental_cuda(
  torch::Tensor x
){
  TORCH_CHECK(x.is_cuda(), "butterfly_transcendental_cuda: x must be a CUDA tensor");
  // The kernel receives exactly six strides for x.
  TORCH_CHECK(x.dim() == 6,
              "butterfly_transcendental_cuda: x must be 6-D, got ", x.dim(), " dims");

  // Make x's device current so the attribute call, the stream lookup and the
  // launch all target the device that owns the data.
  const c10::cuda::CUDAGuard device_guard(x.device());

  const uint B = x.size(0);
  const uint H = x.size(1);
  const uint m = x.size(3);
  const uint l_m = x.size(-1);

  torch::Tensor out = torch::empty({B, H, m, m, l_m}, x.options());

  // Nothing to write; also avoids an invalid zero-sized grid when H or m is 0.
  if (out.numel() == 0) {
    return out;
  }

  const dim3 grid(H, m);
  const dim3 block(1024, 1);

  // (2*m*l_m + 2*l_m) half values, computed in size_t.
  const size_t smem_bytes =
      (2 * static_cast<size_t>(m) * l_m + 2 * static_cast<size_t>(l_m)) * sizeof(half);

  // Opt in to exactly the dynamic shared memory this launch needs (required
  // above 48 KB). A fixed cap rejects larger launches and exceeds the limit of
  // GPUs with less shared memory.
  CUDA_RT_CALL(cudaFuncSetAttribute(&butterfly_transcendental_cuda_kernel,
                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
                                    static_cast<int>(smem_bytes)));

  // Use PyTorch's current stream so the kernel is ordered with surrounding torch ops.
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  butterfly_transcendental_cuda_kernel<<<grid, block, smem_bytes, stream>>>(
      static_cast<complex_half_t *>(x.data_ptr()),
      static_cast<complex_half_t *>(out.data_ptr()),
      B,
      H,
      m,
      l_m,
      x.stride(0),
      x.stride(1),
      x.stride(2),
      x.stride(3),
      x.stride(4),
      x.stride(5),
      out.stride(0),
      out.stride(1),
      out.stride(2),
      out.stride(3),
      out.stride(4)
      );

  CHECK_LAST_CUDA_ERROR();
  return out;
}


// Host wrapper intended to run const_compute_kernel; that launch is currently
// commented out below. Allocates `out` with shape {B, H, m, m, l_m} and x's
// options, where B = x.size(0), H = x.size(1), m = x.size(3) and
// l_m = x.size(-1). Only x's sizes and options are used; its data is not read.
//
// NOTE(review): with the launch disabled, `out` comes straight from
// torch::empty, so its contents are uninitialized. The trailing
// CHECK_LAST_CUDA_ERROR() can only report (and clear) an error left over from
// some earlier CUDA call.
torch::Tensor const_compute_cuda(
  torch::Tensor x
){
  const uint B   = x.size(0);
  const uint H   = x.size(1);
  const uint m   = x.size(3);
  const uint l_m = x.size(-1);

  torch::Tensor out = torch::empty({B, H, m, m, l_m}, x.options());

  // Launch configuration kept for the disabled kernel:
  // grid (H, m, m), 1024 threads per block.
  const dim3 grid(H, m, m);
  const dim3 block(1024, 1);

  // const_compute_kernel<<<grid, block, l_m * sizeof(complex_half_t)>>>(
  //     static_cast<complex_half_t *>(out.data_ptr()),
  //     B,
  //     H,
  //     m,
  //     l_m,
  //     out.stride(0),
  //     out.stride(1),
  //     out.stride(2),
  //     out.stride(3),
  //     out.stride(4)
  //     );

  CHECK_LAST_CUDA_ERROR();
  return out;
}