#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cufft.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>

#include <torch/extension.h>

#include <cuda.h>
#include <cuda_runtime.h>
#include <Python.h>
#include <ATen/Operators.h>
#include <torch/all.h>
#include <torch/library.h>

#include <limits>
#include <vector>

// #define CREAL crealf
// #define CIMAG cimagf
// #define CEXP cexpf
// The three commented-out macros above are not used anywhere in this file.
// MATH_PI is an alias for M_PI, used by fft_stage_kernel to build twiddle angles.
// NOTE(review): M_PI is not part of ISO C++; some toolchains (e.g. MSVC) only
// define it when _USE_MATH_DEFINES is set before <cmath> — confirm for the build.
#define MATH_PI M_PI

extern "C" {
  // Init hook that CPython calls when Python imports the `_C` module.
  // The module it returns is deliberately empty (no methods, no docstring).
  // Importing it only serves to load this shared library, which runs the
  // TORCH_LIBRARY / TORCH_LIBRARY_IMPL static registrations defined below.
  PyObject* PyInit__C(void)
  {
      // Single-phase initialisation via PyModule_Create. m_size == -1 means
      // the module keeps no per-interpreter state. Trailing fields are zero.
      static PyModuleDef kModuleDef = {
          PyModuleDef_HEAD_INIT,
          "_C",     // m_name
          nullptr,  // m_doc
          -1,       // m_size
          nullptr,  // m_methods
      };
      PyObject* created = PyModule_Create(&kModuleDef);
      return created;
  }
}

namespace extension_cpp {
  // Overloaded cuComplex helpers, so templated kernels (e.g. mul_matrix) can be
  // written once for both cuFloatComplex and cuDoubleComplex.

  // Additive identity. The pointer parameter is only an overload tag and is
  // never read; callers pass a null pointer of the desired complex type.
  inline __device__ cuFloatComplex make_zero(cuFloatComplex* /*type_tag*/) {
      const float zero = 0.0f;
      return make_cuFloatComplex(zero, zero);
  }
  inline __device__ cuDoubleComplex make_zero(cuDoubleComplex* /*type_tag*/) {
      const double zero = 0.0;
      return make_cuDoubleComplex(zero, zero);
  }

  // lhs + rhs, forwarding to cuCaddf / cuCadd.
  inline __device__ cuFloatComplex complex_add(cuFloatComplex lhs, cuFloatComplex rhs) {
      const cuFloatComplex sum = cuCaddf(lhs, rhs);
      return sum;
  }
  inline __device__ cuDoubleComplex complex_add(cuDoubleComplex lhs, cuDoubleComplex rhs) {
      const cuDoubleComplex sum = cuCadd(lhs, rhs);
      return sum;
  }

  // lhs * rhs, forwarding to cuCmulf / cuCmul.
  inline __device__ cuFloatComplex complex_mul(cuFloatComplex lhs, cuFloatComplex rhs) {
      const cuFloatComplex product = cuCmulf(lhs, rhs);
      return product;
  }
  inline __device__ cuDoubleComplex complex_mul(cuDoubleComplex lhs, cuDoubleComplex rhs) {
      const cuDoubleComplex product = cuCmul(lhs, rhs);
      return product;
  }

  // Scales x in place by a diagonal along its last axis:
  //   x[b, r, k] *= diag[k]
  // x is addressed as a dense row-major [batch][seq_len][_] array, and diag as
  // a dense vector with at least `_` entries.
  // Expected launch: a 3-D grid whose x/y/z axes cover batch / seq_len / `_`.
  // Threads outside that box do nothing.
  // The flat offset is computed in 64-bit so tensors with more than 2^31
  // elements do not overflow the index.
  template<typename scalar_t>
  __global__ void diagmul_matrix(const scalar_t* diag, scalar_t* x, int batch, int seq_len, int _) {
    const int b = blockIdx.x * blockDim.x + threadIdx.x;  // `batch` axis
    const int r = blockIdx.y * blockDim.y + threadIdx.y;  // `seq_len` axis
    const int k = blockIdx.z * blockDim.z + threadIdx.z;  // `_` (last) axis

    if (b < batch && r < seq_len && k < _) {
      const int64_t x_idx = (static_cast<int64_t>(b) * seq_len + r) * _ + k;
      x[x_idx] = x[x_idx] * diag[k];
    }
  }
  
  // Host wrapper for diagmul_matrix: x[b, r, k] *= diag[k], performed in place on x.
  //
  // Args:
  //   diag: float32/float64 CUDA tensor with at least x.size(2) elements. Its
  //         first x.size(2) elements, in contiguous order, are used.
  //   x:    3-D CUDA tensor with the same dtype and device; modified in place.
  //         Strided inputs are handled through a contiguous copy that is
  //         written back.
  // Returns: x, the same tensor, now scaled.
  // Raises (c10::Error) on dtype/device/shape mismatches and on CUDA errors.
  at::Tensor mydiagmul_cuda(const at::Tensor& diag, at::Tensor& x) {
    TORCH_CHECK(
      diag.dtype() == at::kFloat || diag.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(
      x.dtype() == at::kFloat || x.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(diag.dtype() == x.dtype(), "diag and x must have the same dtype");
    TORCH_CHECK(diag.is_cuda(), "diag must be on CUDA device");
    TORCH_CHECK(x.is_cuda(), "x must be on CUDA device");
    TORCH_CHECK(diag.device() == x.device(), "diag and x must be on the same CUDA device");
    TORCH_CHECK(x.dim() == 3, "Input x must be a 3D tensor");

    const int64_t batch = x.size(0);
    const int64_t seq_len = x.size(1);
    const int64_t dim = x.size(2);
    // The kernel reads diag[0 .. dim-1]; a shorter diag would be read out of bounds.
    TORCH_CHECK(diag.numel() >= dim,
                "diag must have at least x.size(2) = ", dim, " elements, got ", diag.numel());

    // Nothing to scale. A zero-sized grid would also be an invalid launch.
    if (x.numel() == 0) {
      return x;
    }

    // Launch on x's device and on PyTorch's current stream, so the kernel is
    // ordered after the work that produced x. PyTorch's side streams do not
    // synchronise with the legacy default stream.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // The kernel indexes raw memory as dense row-major data.
    const at::Tensor diag_c = diag.contiguous();
    at::Tensor x_c = x.contiguous();

    const dim3 block(8, 8, 8);
    const dim3 grid((batch + block.x - 1) / block.x,
                    (seq_len + block.y - 1) / block.y,
                    (dim + block.z - 1) / block.z);
    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "mydiagmul_cuda", [&] {
      diagmul_matrix<scalar_t><<<grid, block, 0, stream>>>(
        diag_c.data_ptr<scalar_t>(), x_c.data_ptr<scalar_t>(), batch, seq_len, dim);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    // x was strided: publish the result back into the caller's tensor.
    if (!x_c.is_same(x)) {
      x.copy_(x_c);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return x;
  }


  // In-place forward complex-to-complex FFT along the last dimension of x,
  // batched over all leading dimensions. Uses cuFFT with CUFFT_FORWARD, which
  // produces unnormalised output.
  //
  // Args:
  //   x: complex64/complex128 CUDA tensor with at least one dimension. It is
  //      overwritten with its transform; x is const& but its data is modified.
  // Returns: x.
  // Raises (c10::Error) on invalid input, cuFFT failures or CUDA errors.
  at::Tensor myfft_cuda(const at::Tensor& x) {
    TORCH_CHECK(
      x.dtype() == at::kComplexFloat || x.dtype() == at::kComplexDouble,
      "Input must be complex64 or complex128"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(x.dim() >= 1, "Input tensor must have at least 1 dimension");

    // Nothing to transform. This also avoids dividing by a zero-length last dimension.
    if (x.numel() == 0) {
      return x;
    }

    const int64_t block_size = x.size(-1);  // length of each transform (last dim)
    const int64_t total_batches = x.numel() / block_size;
    // cufftPlan1d takes `int` sizes; refuse rather than silently truncate.
    TORCH_CHECK(block_size <= std::numeric_limits<int>::max() &&
                total_batches <= std::numeric_limits<int>::max(),
                "Input too large for cufftPlan1d: ", x.sizes());

    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // cuFFT reads raw memory, so the batched rows must be densely packed.
    at::Tensor work = x.contiguous();
    const bool is_double = x.dtype() == at::kComplexDouble;

    cufftHandle plan;
    cufftResult status = cufftPlan1d(&plan, static_cast<int>(block_size),
                                     is_double ? CUFFT_Z2Z : CUFFT_C2C,
                                     static_cast<int>(total_batches));
    TORCH_CHECK(status == CUFFT_SUCCESS, "CUFFT plan creation failed");

    // Execute on PyTorch's current stream so the transform is ordered after
    // the work that produced x.
    status = cufftSetStream(plan, stream);
    if (status == CUFFT_SUCCESS) {
      if (is_double) {
        auto* d_data = reinterpret_cast<cufftDoubleComplex*>(work.data_ptr<c10::complex<double>>());
        status = cufftExecZ2Z(plan, d_data, d_data, CUFFT_FORWARD);
      } else {
        auto* d_data = reinterpret_cast<cufftComplex*>(work.data_ptr<c10::complex<float>>());
        status = cufftExecC2C(plan, d_data, d_data, CUFFT_FORWARD);
      }
    }

    // Wait for the transform before releasing the plan. The plan is destroyed
    // on every path after creation, and only then are errors reported.
    const cudaError_t sync_status = cudaDeviceSynchronize();
    cufftDestroy(plan);
    TORCH_CHECK(status == CUFFT_SUCCESS, "CUFFT execution failed");
    C10_CUDA_CHECK(sync_status);

    // x was strided: publish the result back into the caller's tensor.
    if (!work.is_same(x)) {
      x.copy_(work);
      C10_CUDA_CHECK(cudaDeviceSynchronize());
    }
    return x;
  }

  // In-place bit-reversal permutation of every length-N row of a dense
  // [_][rows][cols][N] array. One thread handles one row, serially.
  // Expected launch: a 3-D grid whose x/y/z axes cover _ / rows / cols.
  // The host wrappers run it before myrfft_* and after myirfft_*.
  //
  // N is expected to be a power of two. The `bit > 0` guard stops the
  // reversed-increment loop from spinning forever for other N. For example,
  // with N == 3, `bit` reaches 0 while `j >= bit` still holds.
  // For powers of two, `bit` only reaches 0 when j == N-1, which never happens
  // inside the loop, so results are unchanged.
  template<typename complex_t>
  __global__ void mytransfor_matrix(complex_t* x, int _, int rows, int cols, int N) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` axis
    int r = blockIdx.y * blockDim.y + threadIdx.y;  // `rows` axis
    int c = blockIdx.z * blockDim.z + threadIdx.z;  // `cols` axis

    if (b < _ && r < rows && c < cols) {
      int x_idx  = b * rows * cols * N + r * cols * N +  c * N;
      int j = 0;  // bit-reversed counterpart of i, maintained incrementally
      for (int i = 1; i < N; i++){
        // Add 1 to j in reversed bit order: clear the leading run of set bits,
        // then set the next bit.
        int bit = N >> 1;
        while (bit > 0 && j >= bit){
          j -= bit;
          bit >>= 1;
        }
        j += bit;
        // Swap each pair once, from the smaller index.
        if (j > i){
          complex_t tmp = x[x_idx+i];
          x[x_idx+i] = x[x_idx+j];
          x[x_idx+j] = tmp;
        }
      }
    }
  }

  // Serial in-place real FFT butterflies for every length-N row of a dense
  // [_][rows][cols][N] float array. One thread handles one row.
  // Expected launch: a 3-D grid over (_, rows, cols).
  // myrfft_cuda runs mytransfor_matrix (bit reversal) on the same data first.
  //
  // Stages use L = 2, 4, ..., N. Each stage walks the row in groups starting
  // at k = 0, L, 2L, ... and touches up to index k + L - 1. N must therefore
  // be a power of two; otherwise the row is read and written out of bounds.
  // Within one (L, k) group, butterfly j touches indices
  // {j, L/2-j, L/2+j, L-j} (offset by k). These sets are disjoint for
  // different j, so each element is updated once per group.
  // NOTE(review): the output appears to be a packed real-FFT layout (real part
  // at k+j, imaginary part at k+L-j), which is how rmul_matrix/rgimul_matrix
  // read indices n and N-n — confirm.
  __global__ void myrfft_float(float* x, int _, int rows, int cols, int N) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` dimension
    int r = blockIdx.y * blockDim.y + threadIdx.y;  // `rows` dimension
    int c = blockIdx.z * blockDim.z + threadIdx.z;  // `cols` dimension

    // Offset of this thread's row. It is only dereferenced inside the bounds check below.
    int x_idx  = b * rows * cols * N + r * cols * N +  c * N;

    if (b < _ && r < rows && c < cols) {
      for (int L = 2; L <= N; L <<= 1){
        for (int k = 0; k < N; k+=L){
          for (int j = 0; j <= L/4; j++){
            // Twiddle angles -2*pi*j/L and -pi*(L-2j)/L. They are computed in
            // double (M_PI is double) and then narrowed to float.
            float angle1 = -2 * M_PI * j / L;
            float angle2 = -1 * M_PI * (L-2*j) / L;
            float t1, t2, t3, t4;
            if (j == 0){
              // Sum/difference butterfly on (k, k+L/2).
              t1 = x[x_idx+k+j]+x[x_idx+k+j+L/2];
              t2 = x[x_idx+k+j]-x[x_idx+k+j+L/2];
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            } else if (j == L/4){
              // Quarter point: twiddle by angle -pi/2 on (k+L/4, k+3L/4).
              t1 = x[x_idx+k+j]+x[x_idx+k+j+L/2]*cosf(angle1);
              t2 = x[x_idx+k+j+L/2]*sinf(angle1);
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            } else{
              // General case: the four values are read into t1..t4 before any write.
              t1 = x[x_idx+k+j]+x[x_idx+k+j+L/2]*cosf(angle1)-x[x_idx+k+L-j]*sinf(angle1);
              t2 = x[x_idx+k+L/2-j]+x[x_idx+k+j+L/2]*sinf(angle1)+x[x_idx+k+L-j]*cosf(angle1);
              t3 = x[x_idx+k+j]+x[x_idx+k+j+L/2]*cosf(angle2)+x[x_idx+k+L-j]*sinf(angle2);
              t4 = -x[x_idx+k+L/2-j]+x[x_idx+k+j+L/2]*sinf(angle2)-x[x_idx+k+L-j]*cosf(angle2);
              x[x_idx+k+j] = t1;
              x[x_idx+k+L/2-j] = t3;
              x[x_idx+k+j+L/2] = t4;
              x[x_idx+k+L-j] = t2;}
          }
        }
      }
    }
  }

  // Double-precision twin of myrfft_float: the same serial in-place real FFT
  // butterflies on every length-N row of a dense [_][rows][cols][N] array.
  // One thread handles one row. Expected launch: a 3-D grid over (_, rows, cols).
  // myrfft_cuda runs mytransfor_matrix (bit reversal) first.
  //
  // N must be a power of two. Each stage touches indices up to k + L - 1 for
  // k = 0, L, 2L, ..., so any other N leads to out-of-bounds accesses.
  // Within one (L, k) group, butterfly j touches the disjoint index set
  // {j, L/2-j, L/2+j, L-j} (offset by k).
  __global__ void myrfft_double(double* x, int _, int rows, int cols, int N) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` dimension
    int r = blockIdx.y * blockDim.y + threadIdx.y;  // `rows` dimension
    int c = blockIdx.z * blockDim.z + threadIdx.z;  // `cols` dimension

    // Offset of this thread's row. It is only dereferenced inside the bounds check below.
    int x_idx  = b * rows * cols * N + r * cols * N +  c * N;

    if (b < _ && r < rows && c < cols) {
      for (int L = 2; L <= N; L <<= 1){
        for (int k = 0; k < N; k+=L){
          for (int j = 0; j <= L/4; j++){
            // Twiddle angles -2*pi*j/L and -pi*(L-2j)/L.
            double angle1 = -2 * M_PI * j / L;
            double angle2 = -1 * M_PI * (L-2*j) / L;
            double t1, t2, t3, t4;
            if (j == 0){
              // Sum/difference butterfly on (k, k+L/2).
              t1 = x[x_idx+k+j]+x[x_idx+k+j+L/2];
              t2 = x[x_idx+k+j]-x[x_idx+k+j+L/2];
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            } else if (j == L/4){
              // Quarter point: twiddle by angle -pi/2 on (k+L/4, k+3L/4).
              t1 = x[x_idx+k+j]+x[x_idx+k+j+L/2]*cos(angle1);
              t2 = x[x_idx+k+j+L/2]*sin(angle1);
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            } else{
              // General case: all four inputs are consumed before any write.
              t1 = x[x_idx+k+j]+x[x_idx+k+j+L/2]*cos(angle1)-x[x_idx+k+L-j]*sin(angle1);
              t2 = x[x_idx+k+L/2-j]+x[x_idx+k+j+L/2]*sin(angle1)+x[x_idx+k+L-j]*cos(angle1);
              t3 = x[x_idx+k+j]+x[x_idx+k+j+L/2]*cos(angle2)+x[x_idx+k+L-j]*sin(angle2);
              t4 = -x[x_idx+k+L/2-j]+x[x_idx+k+j+L/2]*sin(angle2)-x[x_idx+k+L-j]*cos(angle2);
              x[x_idx+k+j] = t1;
              x[x_idx+k+L/2-j] = t3;
              x[x_idx+k+j+L/2] = t4;
              x[x_idx+k+L-j] = t2;}
          }
        }
      }
    }
  }

  // In-place real FFT along the last dimension of a 4-D float/double tensor,
  // using the hand-written kernels in this file:
  //   1. mytransfor_matrix bit-reverses every length-N row.
  //   2. myrfft_float / myrfft_double apply the per-row butterflies.
  //
  // Args:
  //   x: 4-D float32/float64 CUDA tensor whose last dimension N is a power of
  //      two. It is overwritten with the layout produced by myrfft_*.
  // Returns: x.
  // Raises (c10::Error) on invalid input or CUDA errors.
  at::Tensor myrfft_cuda(at::Tensor& x) {
    TORCH_CHECK(
      x.dtype() == at::kFloat || x.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(x.dim() == 4, "Input must be 4D tensor");

    const int64_t batch = x.size(0);
    const int64_t rows = x.size(1);
    const int64_t cols = x.size(2);
    const int64_t n = x.size(3);

    // Nothing to transform. A zero-sized grid would also be an invalid launch.
    if (x.numel() == 0) {
      return x;
    }
    // The butterflies step through each row in groups of L = 2, 4, ..., N and
    // touch up to index k + L - 1. The bit-reversal step also assumes a power of two.
    TORCH_CHECK((n & (n - 1)) == 0, "Last dimension must be a power of two, got ", n);

    // Run on x's device and PyTorch's current stream. Both kernels share it,
    // so the butterflies start only after the permutation finishes.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // The kernels index raw memory as a dense [batch][rows][cols][n] array.
    at::Tensor work = x.contiguous();

    const dim3 block(8, 8, 8);
    const dim3 grid((batch + block.x - 1) / block.x,
                    (rows + block.y - 1) / block.y,
                    (cols + block.z - 1) / block.z);

    if (x.dtype() == at::kFloat) {
      float* d_data = work.data_ptr<float>();
      mytransfor_matrix<float><<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      myrfft_float<<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      double* d_data = work.data_ptr<double>();
      mytransfor_matrix<double><<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      myrfft_double<<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // x was strided: publish the result back into the caller's tensor.
    if (!work.is_same(x)) {
      x.copy_(work);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return x;
  }

  // Returns x with its 32 bits in reverse order: bit 0 <-> bit 31,
  // bit 1 <-> bit 30, and so on. The __brev hardware intrinsic computes
  // exactly this mapping in a single instruction.
  __device__ uint32_t reverse_bits(uint32_t x)
  {
      return __brev(x);
  }

  // In-place bit-reversal permutation of x[0 .. N-1], where N == 2^log2N.
  // Uses one thread per index i, with a 1-D launch of at least N threads.
  // Each pair (i, rev(i)) is swapped once, by the thread holding the smaller index.
  //
  // When log2N <= 0 (N <= 1) the permutation is the identity. Returning early
  // also avoids `>> 32`, which is undefined behaviour for a 32-bit operand.
  // The `j < N` check only matters when log2N disagrees with N. It keeps the
  // swap partner inside the array instead of writing out of bounds.
  template<typename real_t>
  __global__ void bit_reverse_permute_kernel(real_t* x, int N, int log2N) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= N || log2N <= 0)
        return;
    const uint32_t j = reverse_bits(static_cast<uint32_t>(i)) >> (32 - log2N);
    if (j > static_cast<uint32_t>(i) && j < static_cast<uint32_t>(N))
    {
        real_t tmp = x[i];
        x[i] = x[j];
        x[j] = tmp;
    }
  }

  // One radix-2 stage (butterfly length L) of the in-place real FFT over
  // x[0 .. N-1]. myrffte_cuda launches it for L = 2, 4, ..., N, after
  // bit_reverse_permute_kernel has run.
  // Each thread handles one work item: butterfly j of the L-sized group `group_id`.
  //   L > 4 : L/4 + 1 items per group, so j = 0 .. L/4
  //   L <= 4: L/2 items per group,     so j = 0 .. L/2 - 1
  // Launch with at least num_work_items threads (1-D). Extra threads exit.
  // The index sets touched by different work items are disjoint:
  //   j = 0          touches {k, k+L/2}
  //   j = L/4        touches {k+L/4, k+3L/4}
  //   other j        touches {k+j, k+L/2-j, k+j+L/2, k+L-j}
  // So a stage needs no intra-stage synchronisation.
  // N is not read by this kernel.
  // Angles are computed in double (2.0 and MATH_PI are double) and then
  // narrowed to real_t.
  template<typename real_t>
  __global__ void fft_stage_kernel(real_t *x, int N, int L, int num_work_items)
  {
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid >= num_work_items)
        return;

    int group_id; // which FFT block.
    int j;        // butterfly index.

    if (L > 4)
    {
        group_id = tid / (L / 4 + 1);
        j = tid % (L / 4 + 1);
    }
    else
    {
        group_id = tid / (L / 2);
        j = tid % (L / 2);
    }

    int k = group_id * L; // starting offset.

    if (L > 4)
    {
        if (j == 0)
        {
            // Sum/difference butterfly on (k, k+L/2).
            real_t x1 = x[k + j] + x[k + j + L / 2];
            real_t x2 = x[k + j] - x[k + j + L / 2];
            x[k + j] = x1;
            x[k + j + L / 2] = x2;
        }
        else if (j == L / 4)
        {
            // Quarter point: twiddle by angle -pi/2.
            real_t theta = -2.0 * MATH_PI * j / L;
            real_t c = cos(theta);
            real_t s = sin(theta);

            real_t real = x[k + j] + c * x[k + j + L / 2];
            real_t imag = s * x[k + j + L / 2];

            x[k + j] = real;
            x[k + j + L / 2] = imag;
        }
        else
        {
            // General butterfly. All four inputs are loaded before any store.
            // The variable names suggest a packed layout (real part at k+j,
            // imaginary part at k+L-j). NOTE(review): confirm.
            real_t a = x[k + j];
            real_t b = x[k + L / 2 - j];
            real_t c_val = x[k + j + L / 2];
            real_t d = x[k + L - j];

            real_t theta1 = -2.0 * MATH_PI * j / L;
            real_t cos1 = cos(theta1);
            real_t sin1 = sin(theta1);

            real_t add1_real = a + (cos1 * c_val - sin1 * d);
            real_t add1_imag = b + (sin1 * c_val + cos1 * d);

            // theta2 = -pi*(L-2j)/L, the same angle as angle2 in myrfft_*.
            real_t theta2 = -MATH_PI * (1.0 - 2.0 * j / L);
            real_t cos2 = cos(theta2);
            real_t sin2 = sin(theta2);

            real_t add2_real = a + (cos2 * c_val + sin2 * d);
            real_t add2_imag = -b + (sin2 * c_val - cos2 * d);

            x[k + j] = add1_real;
            x[k + L / 2 - j] = add2_real;
            x[k + j + L / 2] = add2_imag;
            x[k + L - j] = add1_imag;
        }
    }
    else
    {
        // L == 2 uses only j == 0. L == 4 also has j == 1 == L/4, handled by
        // the quarter-point formula in the else branch.
        if (j == 0)
        {
            real_t x1 = x[k + j] + x[k + j + L / 2];
            real_t x2 = x[k + j] - x[k + j + L / 2];
            x[k + j] = x1;
            x[k + j + L / 2] = x2;
        }
        else
        {
            real_t theta = -2.0 * MATH_PI * j / L;
            real_t c = cos(theta);
            real_t s = sin(theta);

            real_t real = x[k + j] + c * x[k + j + L / 2];
            real_t imag = s * x[k + j + L / 2];

            x[k + j] = real;
            x[k + j + L / 2] = imag;
        }
    }
  }

  // In-place real FFT of the first N = x.size(0) elements of x, using one
  // kernel launch per radix-2 stage:
  //   1. bit_reverse_permute_kernel
  //   2. log2(N) fft_stage_kernel passes with L = 2, 4, ..., N
  //
  // Args:
  //   x: float32/float64 CUDA tensor; x.size(0) must be a power of two.
  //      NOTE(review): only the first x.size(0) elements of the contiguous
  //      data are transformed, so this looks intended for 1-D input — confirm.
  // Returns: x, overwritten with the transform.
  // Raises (c10::Error) on invalid input or CUDA errors.
  at::Tensor myrffte_cuda(at::Tensor& x) {
    TORCH_CHECK(
      x.dtype() == at::kFloat || x.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");

    const int64_t N = x.size(0);
    // Length 0 or 1: the transform is the identity. This also avoids a
    // zero-sized launch and the shift by 32 that log2N == 0 would cause.
    if (N <= 1) {
      return x;
    }
    TORCH_CHECK((N & (N - 1)) == 0, "x.size(0) must be a power of two, got ", N);
    TORCH_CHECK(N <= std::numeric_limits<int>::max(), "x.size(0) is too large: ", N);

    // Exact log2 for a power of two, with no floating-point log2 involved.
    int num_steps = 0;
    while ((int64_t{1} << num_steps) < N) {
      ++num_steps;
    }
    const int n = static_cast<int>(N);

    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    at::Tensor work = x.contiguous();

    const int threads = 256;
    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "myrffte_cuda", [&] {
      scalar_t* d_data = work.data_ptr<scalar_t>();
      bit_reverse_permute_kernel<scalar_t><<<(n + threads - 1) / threads, threads, 0, stream>>>(
        d_data, n, num_steps);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      // All stages go to one stream, so each launch already starts after the
      // previous stage finishes. No per-stage device synchronisation is needed.
      for (int s = 1; s <= num_steps; ++s) {
        const int L = 1 << s;
        const int num_work_items = (L > 4) ? (n / L) * (L / 4 + 1) : (n / L) * (L / 2);
        const int blocks = (num_work_items + threads - 1) / threads;
        fft_stage_kernel<scalar_t><<<blocks, threads, 0, stream>>>(d_data, n, L, num_work_items);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });

    // x was strided: publish the result back into the caller's tensor.
    if (!work.is_same(x)) {
      x.copy_(work);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return x;
  }

  // Serial in-place inverse of the myrfft_float butterflies, applied to every
  // length-N row of a dense [_][rows][cols][N] float array. One thread
  // handles one row. Expected launch: a 3-D grid over (_, rows, cols).
  // Stages run in reverse order (L = N, N/2, ..., 2). The output is still in
  // bit-reversed order; myirfft_cuda runs mytransfor_matrix afterwards.
  // N must be a power of two, for the same indexing reason as myrfft_float.
  //
  // For j == 0 and j == L/4 each step exactly undoes the matching forward step:
  //   j == 0:    forward maps (a, c) -> (a+c, a-c); halving the sum and
  //              difference restores (a, c).
  //   j == L/4:  forward maps (a, c) -> (a + c*cos(-pi/2), c*sin(-pi/2)),
  //              which is approximately (a, -c).
  // NOTE(review): the general-j branch is assumed to invert its forward
  // counterpart the same way — confirm numerically.
  __global__ void myirfft_float(float* x, int _, int rows, int cols, int N) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` dimension
    int r = blockIdx.y * blockDim.y + threadIdx.y;  // `rows` dimension
    int c = blockIdx.z * blockDim.z + threadIdx.z;  // `cols` dimension

    // Offset of this thread's row. It is only dereferenced inside the bounds check below.
    int x_idx  = b * rows * cols * N + r * cols * N +  c * N;

    if (b < _ && r < rows && c < cols) {
      for (int L = N; L > 1; L >>= 1){
        for (int k = 0; k < N; k+=L){
          for (int j = 0; j <= L/4; j++){
            // Positive twiddle angle 2*pi*j/L: the conjugate of the forward angle1.
            float angle = M_PI * 2 * j / L;
            float t1, t2, t3, t4;
            if (j == 0){
              t1 = (x[x_idx+k+j]+x[x_idx+k+j+L/2])/2;
              t2 = (x[x_idx+k+j]-x[x_idx+k+j+L/2])/2;
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            }else if (j == L/4){
              // No /2 here: this step is the exact inverse of the forward one without scaling.
              t1 = x[x_idx+k+j];
              t2 = -x[x_idx+k+j+L/2]*sinf(angle);
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            } else{
              // All four inputs are consumed into t1..t4 before any write.
              t1 = (x[x_idx+k+j]+x[x_idx+k+L/2-j])/2;
              t2 = (x[x_idx+k+L-j]-x[x_idx+k+j+L/2])/2;
              t3 = ((x[x_idx+k+j]-x[x_idx+k+L/2-j])*cosf(angle)-(x[x_idx+k+L-j]+x[x_idx+k+j+L/2])*sinf(angle))/2;
              t4 = ((x[x_idx+k+j]-x[x_idx+k+L/2-j])*sinf(angle)+(x[x_idx+k+L-j]+x[x_idx+k+j+L/2])*cosf(angle))/2;
              x[x_idx+k+j] = t1;
              x[x_idx+k+L/2-j] = t2;
              x[x_idx+k+j+L/2] = t3;
              x[x_idx+k+L-j] = t4;}
          }
        }
      }
    }
  }

  // Double-precision twin of myirfft_float: the serial in-place inverse of
  // the myrfft_double butterflies for every length-N row of a dense
  // [_][rows][cols][N] array. One thread handles one row.
  // Expected launch: a 3-D grid over (_, rows, cols).
  // Stages run L = N, ..., 2, and the output is left in bit-reversed order;
  // myirfft_cuda runs mytransfor_matrix afterwards.
  // N must be a power of two.
  __global__ void myirfft_double(double* x, int _, int rows, int cols, int N) {
    int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` dimension
    int r = blockIdx.y * blockDim.y + threadIdx.y;  // `rows` dimension
    int c = blockIdx.z * blockDim.z + threadIdx.z;  // `cols` dimension

    // Offset of this thread's row. It is only dereferenced inside the bounds check below.
    int x_idx  = b * rows * cols * N + r * cols * N +  c * N;

    if (b < _ && r < rows && c < cols) {
      for (int L = N; L > 1; L >>= 1){
        for (int k = 0; k < N; k+=L){
          for (int j = 0; j <= L/4; j++){
            // Positive twiddle angle 2*pi*j/L.
            double angle = M_PI * 2 * j / L;
            double t1, t2, t3, t4;
            if (j == 0){
              // Halved sum/difference: undoes the forward j == 0 butterfly.
              t1 = (x[x_idx+k+j]+x[x_idx+k+j+L/2])/2;
              t2 = (x[x_idx+k+j]-x[x_idx+k+j+L/2])/2;
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            }else if (j == L/4){
              // Undoes the forward quarter-point step without scaling.
              t1 = x[x_idx+k+j];
              t2 = -x[x_idx+k+j+L/2]*sin(angle);
              x[x_idx+k+j] = t1;
              x[x_idx+k+j+L/2] = t2;
            } else{
              // All four inputs are consumed into t1..t4 before any write.
              t1 = (x[x_idx+k+j]+x[x_idx+k+L/2-j])/2;
              t2 = (x[x_idx+k+L-j]-x[x_idx+k+j+L/2])/2;
              t3 = ((x[x_idx+k+j]-x[x_idx+k+L/2-j])*cos(angle)-(x[x_idx+k+L-j]+x[x_idx+k+j+L/2])*sin(angle))/2;
              t4 = ((x[x_idx+k+j]-x[x_idx+k+L/2-j])*sin(angle)+(x[x_idx+k+L-j]+x[x_idx+k+j+L/2])*cos(angle))/2;
              x[x_idx+k+j] = t1;
              x[x_idx+k+L/2-j] = t2;
              x[x_idx+k+j+L/2] = t3;
              x[x_idx+k+L-j] = t4;}
          }
        }
      }
    }
  }

  // In-place inverse of myrfft_cuda along the last dimension of a 4-D
  // float/double tensor:
  //   1. myirfft_float / myirfft_double apply the inverse butterflies.
  //   2. mytransfor_matrix undoes the bit-reversed ordering.
  //
  // Args:
  //   x: 4-D float32/float64 CUDA tensor whose last dimension N is a power of
  //      two, holding data in the layout produced by myrfft_cuda. Overwritten.
  // Returns: x.
  // Raises (c10::Error) on invalid input or CUDA errors.
  at::Tensor myirfft_cuda(at::Tensor& x) {
    TORCH_CHECK(
      x.dtype() == at::kFloat || x.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(x.dim() == 4, "Input must be 4D tensor");

    const int64_t batch = x.size(0);
    const int64_t rows = x.size(1);
    const int64_t cols = x.size(2);
    const int64_t n = x.size(3);

    // Nothing to transform. A zero-sized grid would also be an invalid launch.
    if (x.numel() == 0) {
      return x;
    }
    // The butterflies walk each row in groups of L = N, ..., 2 and touch up
    // to index k + L - 1. The bit-reversal step also assumes a power of two.
    TORCH_CHECK((n & (n - 1)) == 0, "Last dimension must be a power of two, got ", n);

    // Run on x's device and PyTorch's current stream. Both kernels share it,
    // so the permutation starts only after the butterflies finish.
    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // The kernels index raw memory as a dense [batch][rows][cols][n] array.
    at::Tensor work = x.contiguous();

    const dim3 block(8, 8, 8);
    const dim3 grid((batch + block.x - 1) / block.x,
                    (rows + block.y - 1) / block.y,
                    (cols + block.z - 1) / block.z);

    if (x.dtype() == at::kFloat) {
      float* d_data = work.data_ptr<float>();
      myirfft_float<<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      mytransfor_matrix<float><<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    } else {
      double* d_data = work.data_ptr<double>();
      myirfft_double<<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
      mytransfor_matrix<double><<<grid, block, 0, stream>>>(d_data, batch, rows, cols, n);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // x was strided: publish the result back into the caller's tensor.
    if (!work.is_same(x)) {
      x.copy_(work);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return x;
  }

  // Batched complex matrix-vector product, computed independently per frequency bin k:
  //   out[b, r, k] = sum_c fft_w[r, c, k] * fft_x[b, c, k]
  // Layouts (all dense):
  //   fft_w is read as [rows][cols][K]
  //   fft_x is read as [_][cols][K]
  //   out   is written as [_][rows][K]
  // Expected launch: a 3-D grid over (_, rows, K).
  //
  // Input and output must be separate buffers. Computing in place is a data
  // race: the thread producing out[b, r, k] can overwrite fft_x[b, r, k]
  // before the threads for other rows have read it.
  template<typename complex_t>
  __global__ void mul_matrix(const complex_t* fft_w, const complex_t* fft_x, complex_t* out,
                             int _, int rows, int cols, int K) {
    const int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` (batch) axis
    const int r = blockIdx.y * blockDim.y + threadIdx.y;  // `rows` axis
    const int k = blockIdx.z * blockDim.z + threadIdx.z;  // `K` (frequency) axis

    if (b < _ && r < rows && k < K) {
      complex_t acc = make_zero((complex_t*)nullptr);
      for (int c = 0; c < cols; c++) {
        const int w_idx = r * cols * K + c * K + k;
        const int x_idx = b * cols * K + c * K + k;
        acc = complex_add(acc, complex_mul(fft_w[w_idx], fft_x[x_idx]));
      }
      out[b * rows * K + r * K + k] = acc;
    }
  }

  // Host wrapper for mul_matrix.
  //
  // Args:
  //   w: 4-D complex CUDA tensor of shape (*, rows, cols, K). It is read as
  //      dense [rows][cols][K], and the result depends only on w's leading rows*cols*K elements.
  //   x: 4-D complex CUDA tensor with the same dtype and device.
  //      batch = x.size(0), cols = x.size(2), K = x.size(3).
  // Effect: the (batch, rows, K) result overwrites the leading batch*rows*K
  //   elements of x, in contiguous order. This is the same placement the
  //   in-place kernel used; the remaining elements are left untouched.
  // Returns: x.
  // Raises (c10::Error) on dtype/device/shape mismatches or CUDA errors.
  at::Tensor mymul_cuda(const at::Tensor& w, at::Tensor& x) {
    TORCH_CHECK(
      w.dtype() == at::kComplexFloat || w.dtype() == at::kComplexDouble,
      "Input must be complex64 or complex128"
    );
    TORCH_CHECK(w.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(w.dim() == 4, "Input must be 4D tensor");

    TORCH_CHECK(
      x.dtype() == at::kComplexFloat || x.dtype() == at::kComplexDouble,
      "Input must be complex64 or complex128"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(x.dim() == 4, "Input must be 4D tensor");
    TORCH_CHECK(x.dtype() == w.dtype(), "w and x must have the same dtype");
    TORCH_CHECK(w.device() == x.device(), "w and x must be on the same CUDA device");

    const int64_t batch = x.size(0);
    const int64_t rows = w.size(1);
    const int64_t cols = x.size(2);
    const int64_t K = x.size(3);
    // w is indexed with x's cols/K as strides, so those sizes must agree and
    // w must hold at least rows*cols*K elements.
    TORCH_CHECK(w.size(2) == cols && w.size(3) == K,
                "w must have shape (*, rows, x.size(2), x.size(3)), got ", w.sizes(),
                " and x ", x.sizes());
    TORCH_CHECK(w.numel() >= rows * cols * K, "w is too small: ", w.sizes());
    // The (batch, rows, K) result is stored in x's leading elements, so x
    // must be large enough. Without this check, rows > cols writes past x.
    const int64_t out_numel = batch * rows * K;
    TORCH_CHECK(out_numel <= x.numel(),
                "x ", x.sizes(), " is too small to hold the (batch, rows, K) result");
    if (out_numel == 0) {
      return x;
    }

    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const at::Tensor w_c = w.contiguous();
    at::Tensor x_c = x.contiguous();
    at::Tensor out = at::empty({batch, rows, K}, x.options());

    const dim3 block(8, 8, 8);
    const dim3 grid((batch + block.x - 1) / block.x,
                    (rows + block.y - 1) / block.y,
                    (K + block.z - 1) / block.z);

    if (x.dtype() == at::kComplexFloat) {
      using complex_t = cuFloatComplex;
      const auto* d_w = reinterpret_cast<const complex_t*>(w_c.data_ptr<c10::complex<float>>());
      const auto* d_x = reinterpret_cast<const complex_t*>(x_c.data_ptr<c10::complex<float>>());
      auto* d_out = reinterpret_cast<complex_t*>(out.data_ptr<c10::complex<float>>());
      mul_matrix<complex_t><<<grid, block, 0, stream>>>(d_w, d_x, d_out, batch, rows, cols, K);
    } else {
      using complex_t = cuDoubleComplex;
      const auto* d_w = reinterpret_cast<const complex_t*>(w_c.data_ptr<c10::complex<double>>());
      const auto* d_x = reinterpret_cast<const complex_t*>(x_c.data_ptr<c10::complex<double>>());
      auto* d_out = reinterpret_cast<complex_t*>(out.data_ptr<c10::complex<double>>());
      mul_matrix<complex_t><<<grid, block, 0, stream>>>(d_w, d_x, d_out, batch, rows, cols, K);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // Place the result in x's leading batch*rows*K elements. The copy runs on
    // the same stream, so it starts after the kernel has finished reading x.
    x_c.view(-1).narrow(0, 0, out_numel).copy_(out.view(-1));
    if (!x_c.is_same(x)) {
      x.copy_(x_c);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return x;
  }

  // Per-bin complex product for real-FFT data stored in the packed layout used
  // by this file. For 0 < n < N/2, the arithmetic treats index n as the real
  // part and index N-n as the imaginary part of bin n. Bins 0 and N/2 are real-only.
  //   out[b, r, :] = sum_c fft_w[r, c, :] * fft_x[b, c, :]   (complex product per bin)
  // Layouts (all dense):
  //   fft_w is read as [rows][cols][N]
  //   fft_x is read as [_][cols][N]
  //   out   is written as [_][rows][N]
  // Expected launch: a 3-D grid over (_, rows, N). The thread for n < N/2 also
  // writes slot N-n; threads with n > N/2 do nothing.
  //
  // Input and output must be separate buffers. In place, one thread's store to
  // row r can be read by another thread as input column c == r, which is a race.
  template<typename complex_t>
  __global__ void rmul_matrix(const complex_t* fft_w, const complex_t* fft_x, complex_t* out,
                              int _, int rows, int cols, int N) {
    const int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` (batch) axis
    const int r = blockIdx.y * blockDim.y + threadIdx.y;  // `rows` axis
    const int n = blockIdx.z * blockDim.z + threadIdx.z;  // frequency axis (length N)

    if (b < _ && r < rows && n < N) {
      int w_idx, x_idx;
      const int y_idx = b * rows * N + r * N;
      if (n == 0 || n == N/2) {
        complex_t acc0 = 0.0f;
        for (int c = 0; c < cols; c++) {
          w_idx  = r * cols * N + c * N;
          x_idx  = b * cols * N + c * N;
          acc0 = acc0 +  fft_w[w_idx+n]*fft_x[x_idx+n];
        }
        out[y_idx+n] = acc0;
      } else if (n > 0 && n < N/2) {
        complex_t acc1 = 0.0f;  // real part of bin n
        complex_t acc2 = 0.0f;  // imaginary part of bin n
        for (int c = 0; c < cols; c++) {
          w_idx  = r * cols * N + c * N;
          x_idx  = b * cols * N + c * N;
          acc1 = acc1 +  fft_w[w_idx+n]*fft_x[x_idx+n]-fft_w[w_idx+N-n]*fft_x[x_idx+N-n];
          acc2 = acc2 +  fft_w[w_idx+n]*fft_x[x_idx+N-n]+fft_w[w_idx+N-n]*fft_x[x_idx+n];
        }
        out[y_idx+n] = acc1;
        out[y_idx+N-n] = acc2;
      }
    }
  }

  // Host wrapper for rmul_matrix.
  //
  // Args:
  //   w: 4-D float/double CUDA tensor of shape (*, rows, cols, N), read as dense [rows][cols][N].
  //   x: 4-D CUDA tensor with the same dtype and device.
  //      batch = x.size(0), cols = x.size(2), N = x.size(3).
  // Effect: the (batch, rows, N) result overwrites the leading batch*rows*N
  //   elements of x, in contiguous order. Slots the kernel never writes keep
  //   their previous values; for odd N this is index (N+1)/2 of each row.
  // Returns: x.
  // Raises (c10::Error) on dtype/device/shape mismatches or CUDA errors.
  at::Tensor myrmul_cuda(const at::Tensor& w, at::Tensor& x) {
    TORCH_CHECK(
      w.dtype() == at::kFloat || w.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(w.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(w.dim() == 4, "Input must be 4D tensor");

    TORCH_CHECK(
      x.dtype() == at::kFloat || x.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(x.dim() == 4, "Input must be 4D tensor");
    TORCH_CHECK(x.dtype() == w.dtype(), "w and x must have the same dtype");
    TORCH_CHECK(w.device() == x.device(), "w and x must be on the same CUDA device");

    const int64_t batch = x.size(0);
    const int64_t rows = w.size(1);
    const int64_t cols = x.size(2);
    const int64_t N = x.size(3);
    // w is indexed with x's cols/N as strides.
    TORCH_CHECK(w.size(2) == cols && w.size(3) == N,
                "w must have shape (*, rows, x.size(2), x.size(3)), got ", w.sizes(),
                " and x ", x.sizes());
    TORCH_CHECK(w.numel() >= rows * cols * N, "w is too small: ", w.sizes());
    // The result is stored in x's leading elements. Without this check,
    // rows > cols writes past x.
    const int64_t out_numel = batch * rows * N;
    TORCH_CHECK(out_numel <= x.numel(),
                "x ", x.sizes(), " is too small to hold the (batch, rows, N) result");
    if (out_numel == 0) {
      return x;
    }

    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const at::Tensor w_c = w.contiguous();
    at::Tensor x_c = x.contiguous();
    // Seed the output with x's current leading values, so slots the kernel
    // does not write keep their old contents, as they did in place.
    at::Tensor out = x_c.view(-1).narrow(0, 0, out_numel).clone();

    const dim3 block(8, 8, 8);
    const dim3 grid((batch + block.x - 1) / block.x,
                    (rows + block.y - 1) / block.y,
                    (N + block.z - 1) / block.z);

    AT_DISPATCH_FLOATING_TYPES(x.scalar_type(), "myrmul_cuda", [&] {
      rmul_matrix<scalar_t><<<grid, block, 0, stream>>>(
        w_c.data_ptr<scalar_t>(), x_c.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
        batch, rows, cols, N);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    // Same stream: this copy starts after the kernel has finished reading x.
    x_c.view(-1).narrow(0, 0, out_numel).copy_(out);
    if (!x_c.is_same(x)) {
      x.copy_(x_c);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return x;
  }

  // In-place inverse complex-to-complex FFT along the last dimension of x,
  // batched over all leading dimensions. Uses cuFFT with CUFFT_INVERSE, which
  // produces unnormalised output; callers divide by x.size(-1) if needed.
  //
  // Args:
  //   x: complex64/complex128 CUDA tensor with at least one dimension. It is
  //      overwritten with its transform; x is const& but its data is modified.
  // Returns: x.
  // Raises (c10::Error) on invalid input, cuFFT failures or CUDA errors.
  at::Tensor myifft_cuda(const at::Tensor& x) {
    TORCH_CHECK(
      x.dtype() == at::kComplexFloat || x.dtype() == at::kComplexDouble,
      "Input must be complex64 or complex128"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(x.dim() >= 1, "Input tensor must have at least 1 dimension");

    // Nothing to transform. This also avoids dividing by a zero-length last dimension.
    if (x.numel() == 0) {
      return x;
    }

    const int64_t block_size = x.size(-1);  // length of each transform (last dim)
    const int64_t total_batches = x.numel() / block_size;
    // cufftPlan1d takes `int` sizes; refuse rather than silently truncate.
    TORCH_CHECK(block_size <= std::numeric_limits<int>::max() &&
                total_batches <= std::numeric_limits<int>::max(),
                "Input too large for cufftPlan1d: ", x.sizes());

    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    // cuFFT reads raw memory, so the batched rows must be densely packed.
    at::Tensor work = x.contiguous();
    const bool is_double = x.dtype() == at::kComplexDouble;

    cufftHandle plan;
    cufftResult status = cufftPlan1d(&plan, static_cast<int>(block_size),
                                     is_double ? CUFFT_Z2Z : CUFFT_C2C,
                                     static_cast<int>(total_batches));
    TORCH_CHECK(status == CUFFT_SUCCESS, "CUIFFT plan creation failed");

    // Execute on PyTorch's current stream so the transform is ordered after
    // the work that produced x.
    status = cufftSetStream(plan, stream);
    if (status == CUFFT_SUCCESS) {
      if (is_double) {
        auto* d_data = reinterpret_cast<cufftDoubleComplex*>(work.data_ptr<c10::complex<double>>());
        status = cufftExecZ2Z(plan, d_data, d_data, CUFFT_INVERSE);
      } else {
        auto* d_data = reinterpret_cast<cufftComplex*>(work.data_ptr<c10::complex<float>>());
        status = cufftExecC2C(plan, d_data, d_data, CUFFT_INVERSE);
      }
    }

    // Wait for the transform before releasing the plan. The plan is destroyed
    // on every path after creation, and only then are errors reported.
    const cudaError_t sync_status = cudaDeviceSynchronize();
    cufftDestroy(plan);
    TORCH_CHECK(status == CUFFT_SUCCESS, "CUIFFT execution failed");
    C10_CUDA_CHECK(sync_status);

    // x was strided: publish the result back into the caller's tensor.
    if (!work.is_same(x)) {
      x.copy_(work);
      C10_CUDA_CHECK(cudaDeviceSynchronize());
    }
    return x;
  }

  
  // Conjugates x[0 .. size-1] in place, one element per thread.
  // Uses a 1-D launch with at least `size` threads.
  // complex_t needs a device-callable thrust::conj overload; the host wrapper
  // passes thrust::complex<float/double>.
  // The index is computed in size_t so tensors with more than 2^31 elements
  // do not overflow it, and the bounds check compares unsigned to unsigned.
  template<typename complex_t>
  __global__ void conj_matrix(complex_t* x, size_t size) {
    const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx < size) {
      x[idx] = thrust::conj(x[idx]);
    }
  }

  // Complex-conjugates every element of x in place.
  //
  // Args:
  //   x: complex64/complex128 CUDA tensor. Strided inputs are handled through
  //      a contiguous copy that is written back.
  // Returns: x.
  // Raises (c10::Error) on invalid input or CUDA errors.
  at::Tensor myconj_cuda(at::Tensor& x) {
    TORCH_CHECK(
      x.dtype() == at::kComplexFloat || x.dtype() == at::kComplexDouble,
      "Input must be complex64 or complex128"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");

    const size_t size = static_cast<size_t>(x.numel());
    // Nothing to do. A zero-block grid would also be an invalid launch.
    if (size == 0) {
      return x;
    }

    const c10::cuda::CUDAGuard device_guard(x.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // The kernel walks raw memory linearly, so it needs dense storage.
    at::Tensor work = x.contiguous();

    const int threads = 256;
    const unsigned int blocks = static_cast<unsigned int>((size + threads - 1) / threads);

    if (x.dtype() == at::kComplexFloat) {
      auto* d_data = reinterpret_cast<thrust::complex<float>*>(work.data_ptr<c10::complex<float>>());
      conj_matrix<thrust::complex<float>><<<blocks, threads, 0, stream>>>(d_data, size);
    } else {
      auto* d_data = reinterpret_cast<thrust::complex<double>*>(work.data_ptr<c10::complex<double>>());
      conj_matrix<thrust::complex<double>><<<blocks, threads, 0, stream>>>(d_data, size);
    }
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    // x was strided: publish the result back into the caller's tensor.
    if (!work.is_same(x)) {
      x.copy_(work);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return x;
  }

  // Input-gradient style product for packed real-FFT data. For 0 < n < N/2,
  // index n holds the real part and index N-n the imaginary part of bin n;
  // bins 0 and N/2 are real-only.
  //   out[b, c, :] = sum_r conj(fft_w[r, c, :]) * fft_o[b, r, :]   (per bin)
  // The acc1/acc2 formulas below are exactly Re/Im of conj(w) * o.
  // Layouts (all dense):
  //   fft_w is read as [rows][cols][N]
  //   fft_o is read as [_][rows][N]
  //   out   is written as [_][cols][N]
  // Expected launch: a 3-D grid over (_, cols, N). The thread for n < N/2 also
  // writes slot N-n; threads with n > N/2 do nothing.
  //
  // Input and output must be separate buffers. In place, one thread's store
  // can be read by another thread as input row r == c, which is a race.
  template<typename complex_t>
  __global__ void rgimul_matrix(const complex_t* fft_w, const complex_t* fft_o, complex_t* out,
                                int _, int rows, int cols, int N) {
    const int b = blockIdx.x * blockDim.x + threadIdx.x;  // `_` (batch) axis
    const int c = blockIdx.y * blockDim.y + threadIdx.y;  // `cols` axis
    const int n = blockIdx.z * blockDim.z + threadIdx.z;  // frequency axis (length N)

    if (b < _ && c < cols && n < N) {
      int w_idx, o_idx;
      const int y_idx = b * cols * N + c * N;
      if (n == 0 || n == N/2) {
        complex_t acc0 = 0.0f;
        for (int r = 0; r < rows; r++) {
          w_idx  = r * cols * N + c * N;
          o_idx  = b * rows * N + r * N;
          acc0 = acc0 +  fft_w[w_idx+n]*fft_o[o_idx+n];
        }
        out[y_idx+n] = acc0;
      } else if (n > 0 && n < N/2) {
        complex_t acc1 = 0.0f;  // real part of bin n
        complex_t acc2 = 0.0f;  // imaginary part of bin n
        for (int r = 0; r < rows; r++) {
          w_idx  = r * cols * N + c * N;
          o_idx  = b * rows * N + r * N;
          acc1 = acc1 +  fft_w[w_idx+n]*fft_o[o_idx+n]+fft_w[w_idx+N-n]*fft_o[o_idx+N-n];
          acc2 = acc2 +  fft_w[w_idx+n]*fft_o[o_idx+N-n]-fft_w[w_idx+N-n]*fft_o[o_idx+n];
        }
        out[y_idx+n] = acc1;
        out[y_idx+N-n] = acc2;
      }
    }
  }

  // Host wrapper for rgimul_matrix.
  //
  // Args:
  //   w: 4-D float/double CUDA tensor with cols = w.size(2) and w.size(3) == N.
  //      Its leading rows*cols*N elements are read as dense [rows][cols][N].
  //   o: 4-D CUDA tensor with the same dtype and device.
  //      batch = o.size(0), rows = o.size(1), N = o.size(3).
  // Effect: the (batch, cols, N) result overwrites the leading batch*cols*N
  //   elements of o, in contiguous order. Slots the kernel never writes keep
  //   their previous values; for odd N this is index (N+1)/2 of each row.
  // Returns: o.
  // Raises (c10::Error) on dtype/device/shape mismatches or CUDA errors.
  at::Tensor myrgimul_cuda(const at::Tensor& w, at::Tensor& o) {
    TORCH_CHECK(
      w.dtype() == at::kFloat || w.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(w.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(w.dim() == 4, "Input must be 4D tensor");

    TORCH_CHECK(
      o.dtype() == at::kFloat || o.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(o.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(o.dim() == 4, "Input must be 4D tensor");
    TORCH_CHECK(o.dtype() == w.dtype(), "w and o must have the same dtype");
    TORCH_CHECK(w.device() == o.device(), "w and o must be on the same CUDA device");

    const int64_t batch = o.size(0);
    const int64_t rows = o.size(1);
    const int64_t cols = w.size(2);
    const int64_t N = o.size(3);
    // w is indexed with o's N as the innermost stride and must cover every row r < rows.
    TORCH_CHECK(w.size(3) == N,
                "w.size(3) must equal o.size(3), got ", w.sizes(), " and ", o.sizes());
    TORCH_CHECK(w.numel() >= rows * cols * N, "w is too small: ", w.sizes());
    // The result is stored in o's leading elements. Without this check,
    // cols > rows writes past o.
    const int64_t out_numel = batch * cols * N;
    TORCH_CHECK(out_numel <= o.numel(),
                "o ", o.sizes(), " is too small to hold the (batch, cols, N) result");
    if (out_numel == 0) {
      return o;
    }

    const c10::cuda::CUDAGuard device_guard(o.device());
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    const at::Tensor w_c = w.contiguous();
    at::Tensor o_c = o.contiguous();
    // Seed the output with o's current leading values, so slots the kernel
    // does not write keep their old contents, as they did in place.
    at::Tensor out = o_c.view(-1).narrow(0, 0, out_numel).clone();

    const dim3 block(8, 8, 8);
    const dim3 grid((batch + block.x - 1) / block.x,
                    (cols + block.y - 1) / block.y,
                    (N + block.z - 1) / block.z);

    AT_DISPATCH_FLOATING_TYPES(o.scalar_type(), "myrgimul_cuda", [&] {
      rgimul_matrix<scalar_t><<<grid, block, 0, stream>>>(
        w_c.data_ptr<scalar_t>(), o_c.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(),
        batch, rows, cols, N);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });

    // Same stream: this copy starts after the kernel has finished reading o.
    o_c.view(-1).narrow(0, 0, out_numel).copy_(out);
    if (!o_c.is_same(o)) {
      o.copy_(o_c);
    }
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return o;
  }

  // Weight-gradient style reduction over the batch for packed real-FFT data.
  // For 0 < n < N/2, index n holds the real part and index N-n the imaginary
  // part of bin n; bins 0 and N/2 are real-only.
  //   res[r, c, :] = sum_b conj(fft_x[b, c, :]) * fft_o[b, r, :]   (per bin)
  // The `re`/`im` formulas below are exactly Re/Im of conj(x) * o.
  // Layouts (all dense):
  //   fft_x is read as [_][cols][N]
  //   fft_o is read as [_][rows][N]
  //   res   is written as [rows][cols][N]
  // res is separate from both inputs, so there is no read/write race.
  // Expected launch: a 3-D grid over (rows, cols, N).
  template<typename complex_t>
  __global__ void rgwmul_matrix(const complex_t* fft_x, const complex_t* fft_o, int _, int rows, int cols, int N, complex_t* res) {
    const int r = blockIdx.x * blockDim.x + threadIdx.x;  // `rows` axis
    const int c = blockIdx.y * blockDim.y + threadIdx.y;  // `cols` axis
    const int n = blockIdx.z * blockDim.z + threadIdx.z;  // frequency axis (length N)

    // Threads outside the (rows, cols, N) box do nothing.
    if (r >= rows || c >= cols || n >= N) {
      return;
    }

    const int half = N / 2;
    complex_t* dst = res + (r * cols * N + c * N);

    if (n == 0 || n == half) {
      // Real-only bins: a plain product, accumulated over the batch.
      complex_t sum = 0.0f;
      for (int b = 0; b < _; ++b) {
        const complex_t* x_row = fft_x + (b * cols * N + c * N);
        const complex_t* o_row = fft_o + (b * rows * N + r * N);
        sum = sum + x_row[n] * o_row[n];
      }
      dst[n] = sum;
    } else if (n < half) {
      // n is nonzero here, so this is exactly the 0 < n < N/2 case.
      complex_t re = 0.0f;
      complex_t im = 0.0f;
      for (int b = 0; b < _; ++b) {
        const complex_t* x_row = fft_x + (b * cols * N + c * N);
        const complex_t* o_row = fft_o + (b * rows * N + r * N);
        re = re + x_row[n] * o_row[n] + x_row[N - n] * o_row[N - n];
        im = im + x_row[n] * o_row[N - n] - x_row[N - n] * o_row[n];
      }
      dst[n] = re;
      dst[N - n] = im;
    }
    // n > N/2: that slot is written by the thread that handles N - n.
  }

  // Host wrapper for rgwmul_matrix. Returns a new tensor; the inputs are not modified.
  //
  // Args:
  //   x: 4-D float/double CUDA tensor with cols = x.size(2) and x.size(3) == N.
  //      Its leading batch*cols*N elements are read as dense [batch][cols][N].
  //   o: 4-D CUDA tensor with the same dtype and device.
  //      batch = o.size(0), rows = o.size(1), N = o.size(3).
  // Returns: a new (1, rows, cols, N) tensor with the batch-reduced product.
  //   It starts from zeros, so any slot the kernel does not write stays 0.
  // Raises (c10::Error) on dtype/device/shape mismatches or CUDA errors.
  at::Tensor myrgwmul_cuda(const at::Tensor& x, const at::Tensor& o) {
    TORCH_CHECK(
      x.dtype() == at::kFloat || x.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(x.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(x.dim() == 4, "Input must be 4D tensor");

    TORCH_CHECK(
      o.dtype() == at::kFloat || o.dtype() == at::kDouble,
      "Input must be float or double"
    );
    TORCH_CHECK(o.is_cuda(), "Input must be on CUDA device");
    TORCH_CHECK(o.dim() == 4, "Input must be 4D tensor");
    TORCH_CHECK(o.dtype() == x.dtype(), "x and o must have the same dtype");
    TORCH_CHECK(o.device() == x.device(), "x and o must be on the same CUDA device");

    const int64_t batch = o.size(0);
    const int64_t rows = o.size(1);
    const int64_t cols = x.size(2);
    const int64_t N = o.size(3);
    // The kernel reads x[b] for every b < o.size(0), with o's N as the
    // innermost stride. A smaller x would be read out of bounds.
    TORCH_CHECK(x.size(3) == N,
                "x.size(3) must equal o.size(3), got ", x.sizes(), " and ", o.sizes());
    TORCH_CHECK(x.numel() >= batch * cols * N,
                "x ", x.sizes(), " has fewer batches than o ", o.sizes());

    const c10::cuda::CUDAGuard device_guard(o.device());
    at::Tensor res = at::zeros({1, rows, cols, N}, x.options());
    // Empty result, or nothing to sum: the zero-initialised tensor is already correct.
    if (res.numel() == 0 || batch == 0) {
      return res;
    }

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // The kernel indexes raw memory as dense arrays.
    const at::Tensor x_c = x.contiguous();
    const at::Tensor o_c = o.contiguous();

    const dim3 block(8, 8, 8);
    const dim3 grid((rows + block.x - 1) / block.x,
                    (cols + block.y - 1) / block.y,
                    (N + block.z - 1) / block.z);

    AT_DISPATCH_FLOATING_TYPES(o.scalar_type(), "myrgwmul_cuda", [&] {
      rgwmul_matrix<scalar_t><<<grid, block, 0, stream>>>(
        x_c.data_ptr<scalar_t>(), o_c.data_ptr<scalar_t>(), batch, rows, cols, N,
        res.data_ptr<scalar_t>());
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
    C10_CUDA_CHECK(cudaDeviceSynchronize());
    return res;
  }

  // Defines the operators
  // These are the schemas of the custom ops exposed as torch.ops.extension_cpp.<name>.
  // Their CUDA implementations are bound in the TORCH_LIBRARY_IMPL block below.
  // NOTE(review): every op except myrgwmul writes into one of its inputs and
  // returns that same tensor (see the *_cuda wrappers). The schemas, however,
  // declare plain functional `Tensor` arguments and results. PyTorch's
  // custom-op rules expect mutated inputs and aliased outputs to be annotated,
  // e.g. `Tensor(a!) b -> Tensor(a!)`. Without that, functionalization and
  // torch.compile may assume the inputs are unchanged. Adding the annotations
  // also changes autograd-fallback behaviour, so confirm with callers first.
  TORCH_LIBRARY(extension_cpp, m) {
    m.def("mydiagmul(Tensor a, Tensor b) -> Tensor");   // scales and returns b in place
    m.def("myfft(Tensor a) -> Tensor");                 // transforms and returns a in place
    m.def("myrfft(Tensor a) -> Tensor");                // transforms and returns a in place
    m.def("myrffte(Tensor a) -> Tensor");               // transforms and returns a in place
    m.def("mymul(Tensor a, Tensor b) -> Tensor");       // result written into and returned as b
    m.def("myrmul(Tensor a, Tensor b) -> Tensor");      // result written into and returned as b
    m.def("myrgimul(Tensor a, Tensor b) -> Tensor");    // result written into and returned as b
    m.def("myrgwmul(Tensor a, Tensor b) -> Tensor");    // returns a freshly allocated tensor
    m.def("myifft(Tensor a) -> Tensor");                // transforms and returns a in place
    m.def("myirfft(Tensor a) -> Tensor");               // transforms and returns a in place
    m.def("myconj(Tensor a) -> Tensor");                // conjugates and returns a in place
  }

  // Registers CUDA implementations
  // Binds each schema declared in TORCH_LIBRARY(extension_cpp, ...) to its CUDA
  // host wrapper. These run when an op is dispatched with CUDA tensors.
  TORCH_LIBRARY_IMPL(extension_cpp, CUDA, m) {
    // Complex-valued ops: cuFFT transforms, spectral product, conjugation.
    m.impl("myfft", &myfft_cuda);
    m.impl("myifft", &myifft_cuda);
    m.impl("mymul", &mymul_cuda);
    m.impl("myconj", &myconj_cuda);

    // Real-valued ops: diagonal scaling, hand-written real FFTs, packed-spectrum products.
    m.impl("mydiagmul", &mydiagmul_cuda);
    m.impl("myrfft", &myrfft_cuda);
    m.impl("myrffte", &myrffte_cuda);
    m.impl("myirfft", &myirfft_cuda);
    m.impl("myrmul", &myrmul_cuda);
    m.impl("myrgimul", &myrgimul_cuda);
    m.impl("myrgwmul", &myrgwmul_cuda);
  }

}
