#include <torch/extension.h>

#include <vector>
#include <stdio.h>
#include <mma.h>
#include <cuda_fp16.h>
using namespace nvcuda;

// Shorthand for the complex half-precision element type used by every kernel
// in this file (also the element type of their dynamic shared-memory buffers).
using complex_half_t = typename c10::complex<at::Half>;

/**
 * Contracts the k axis of `x` against precomputed twiddle factors.
 *
 * With bx = blockIdx.x and by = blockIdx.y, each block computes, for every
 * batch b in [0, B), output row c in [0, m) and column j in [0, l_m):
 *
 *   out[b, bx, c, by, j] =
 *       sum_{k=0}^{m-1} x[b, bx, 0, by, k, j] * twiddle_factors_fft[0, 0, c, by, k, j]
 *
 * where each bracketed index is multiplied by the matching *_stride_* argument
 * (strides are in elements). Products and sums are evaluated in
 * complex_half_t; no wider accumulator is used.
 *
 * Launch requirements:
 *   - blockIdx.x / blockIdx.y select the slice along dims 1 and 3 of x and out
 *     (by also selects dim 3 of the twiddle tensor).
 *   - Dynamic shared memory must hold at least m * l_m complex_half_t elements.
 *   - Any block shape is accepted: all threads of the block cooperate on the
 *     l_m columns of one batch at a time.
 *
 * `N`, `x_stride_2`, `twiddle_factors_fft_stride_0` and
 * `twiddle_factors_fft_stride_1` are accepted but do not affect the result
 * (the corresponding indices are always 0).
 */
__global__ void butterfly_copy_cuda_kernel(
    const c10::complex<at::Half> *__restrict__ x,
    const c10::complex<at::Half> *__restrict__ twiddle_factors_fft,
    c10::complex<at::Half> *__restrict__ out,
    uint B,
    uint N,
    uint m,
    uint l_m,
    uint x_stride_0,
    uint x_stride_1,
    uint x_stride_2,
    uint x_stride_3,
    uint x_stride_4,
    uint x_stride_5,
    uint twiddle_factors_fft_stride_0,
    uint twiddle_factors_fft_stride_1,
    uint twiddle_factors_fft_stride_2,
    uint twiddle_factors_fft_stride_3,
    uint twiddle_factors_fft_stride_4,
    uint twiddle_factors_fft_stride_5,
    uint out_stride_0,
    uint out_stride_1,
    uint out_stride_2,
    uint out_stride_3,
    uint out_stride_4
)
{
    // Staging buffer for one (m x l_m) tile of x, laid out as temp[k * l_m + j].
    extern __shared__ complex_half_t temp[];

    // Linear thread id within the block. Every thread runs the same batch and
    // row loops, so each __syncthreads() below is reached by the whole block
    // and the single shared tile only ever holds one batch at a time.
    const uint tid = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
    const uint n_threads = blockDim.x * blockDim.y * blockDim.z;

    // Offsets are built in size_t so that index * stride cannot wrap at 2^32.
    const size_t x_block = static_cast<size_t>(blockIdx.x) * x_stride_1
                         + static_cast<size_t>(blockIdx.y) * x_stride_3;
    const size_t tw_block = static_cast<size_t>(blockIdx.y) * twiddle_factors_fft_stride_3;
    const size_t out_block = static_cast<size_t>(blockIdx.x) * out_stride_1
                           + static_cast<size_t>(blockIdx.y) * out_stride_3;

    for (uint b = 0; b < B; b++) {
        const size_t x_batch = x_block + static_cast<size_t>(b) * x_stride_0;
        const size_t out_batch = out_block + static_cast<size_t>(b) * out_stride_0;

        // Stage the tile: each element is read from global memory once and
        // then reused for all m output rows.
        for (uint k = 0; k < m; k++) {
            const size_t x_row = x_batch + static_cast<size_t>(k) * x_stride_4;
            for (uint j = tid; j < l_m; j += n_threads) {
                temp[k * l_m + j] = x[x_row + static_cast<size_t>(j) * x_stride_5];
            }
        }
        __syncthreads();

        for (uint c = 0; c < m; c++) {
            const size_t tw_row = tw_block + static_cast<size_t>(c) * twiddle_factors_fft_stride_2;
            const size_t out_row = out_batch + static_cast<size_t>(c) * out_stride_2;

            for (uint j = tid; j < l_m; j += n_threads) {
                const size_t tw_col = tw_row + static_cast<size_t>(j) * twiddle_factors_fft_stride_5;

                // Accumulate in a register; the k = 0 term seeds the sum.
                complex_half_t acc = temp[j] * twiddle_factors_fft[tw_col];
                for (uint k = 1; k < m; k++) {
                    acc += temp[k * l_m + j]
                         * twiddle_factors_fft[tw_col + static_cast<size_t>(k) * twiddle_factors_fft_stride_4];
                }

                out[out_row + static_cast<size_t>(j) * out_stride_4] = acc;
            }
        }

        // Do not let the next batch overwrite the tile while it is still in use.
        __syncthreads();
    }
}

/**
 * Variant of the butterfly contraction that builds its per-column factor
 * in-kernel instead of reading a twiddle tensor.
 *
 * With bx = blockIdx.x and by = blockIdx.y, each block computes, for every
 * batch b in [0, B), output row c in [0, m) and column j in [0, l_m):
 *
 *   out[b, bx, c, by, j] = sum_{k=0}^{m-1} x[b, bx, 0, by, k, j] * (j + i*j)
 *
 * The factor currently depends on j only, so every output row c receives the
 * same values. NOTE(review): the factor looks like a placeholder for a real
 * on-the-fly twiddle computation (see the disabled draft below) - confirm.
 *
 * Each bracketed index is multiplied by the matching *_stride_* argument
 * (strides are in elements). Products and sums are evaluated in
 * complex_half_t.
 *
 * Launch requirements:
 *   - blockIdx.x / blockIdx.y select the slice along dims 1 and 3 of x and out.
 *   - Dynamic shared memory must hold at least m * l_m complex_half_t elements.
 *   - Any block shape is accepted: all threads of the block cooperate on the
 *     l_m columns of one batch at a time.
 *
 * `N` and `x_stride_2` are accepted but do not affect the result.
 */
__global__ void butterfly_transcendental_cuda_kernel(
    const c10::complex<at::Half> *__restrict__ x,
    c10::complex<at::Half> *__restrict__ out,
    uint B,
    uint N,
    uint m,
    uint l_m,
    uint x_stride_0,
    uint x_stride_1,
    uint x_stride_2,
    uint x_stride_3,
    uint x_stride_4,
    uint x_stride_5,
    uint out_stride_0,
    uint out_stride_1,
    uint out_stride_2,
    uint out_stride_3,
    uint out_stride_4
)
{

    // Disabled earlier draft, kept for reference:
    //
    // const __half2 half2_2_pi = __halves2half2(__float2half(6.28318530718), __float2half(6.28318530718));
    // extern __shared__ complex_half_t temp[];

    // for(int b = threadIdx.y; b < B; b++){
    //     for(int j = threadIdx.x; j < l_m; j+=blockDim.x){
    //         temp[j] = complex_half_t(0.0f, 0.0f);
    //     }

    //     __syncthreads();


    //     for(int k = 0; k < m; k++){
    //         for(int j = 2 * threadIdx.x; j <  l_m; j+= 2 * blockDim.x){
    //             __half2 num1 = __halves2half2(__float2half(j * 1.0), __float2half((j + 1) * 1.0));
    //             num1 = h2cos(__hmul2(half2_2_pi, num1));
    //             __half2 num2 = __halves2half2(__float2half(j * 1.0), __float2half((j + 1) * 1.0));
    //             num2 = h2sin(__hmul2(half2_2_pi, num2));

    //             temp[j] += x[0 * x_stride_0 + blockIdx.x * x_stride_1 + 0 * x_stride_2 + blockIdx.y * x_stride_3 + k * x_stride_4 + j * x_stride_5]
    //             * complex_half_t(num1.x, num2.x); 
    //             temp[j + 1 ] += x[0 * x_stride_0 + blockIdx.x * x_stride_1 + 0 * x_stride_2 + blockIdx.y * x_stride_3 + k * x_stride_4 + j * x_stride_5]
    //             * complex_half_t(num1.y, num2.y); 
    //         }
    //     }

    //     __syncthreads();

    //     for(int j = threadIdx.x; j < l_m; j+=blockDim.x){
    //         out[0 * out_stride_0 + blockIdx.x * out_stride_1 + 0 * out_stride_2 + blockIdx.y * out_stride_3 + j * out_stride_4] = temp[j];
    //     }
    // }


    // Staging buffer for one (m x l_m) tile of x, laid out as temp[k * l_m + j].
    extern __shared__ complex_half_t temp[];

    // Linear thread id within the block. Every thread runs the same batch and
    // row loops, so each __syncthreads() below is reached by the whole block
    // and the single shared tile only ever holds one batch at a time.
    const uint tid = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
    const uint n_threads = blockDim.x * blockDim.y * blockDim.z;

    // Offsets are built in size_t so that index * stride cannot wrap at 2^32.
    const size_t x_block = static_cast<size_t>(blockIdx.x) * x_stride_1
                         + static_cast<size_t>(blockIdx.y) * x_stride_3;
    const size_t out_block = static_cast<size_t>(blockIdx.x) * out_stride_1
                           + static_cast<size_t>(blockIdx.y) * out_stride_3;

    for (uint b = 0; b < B; b++) {
        const size_t x_batch = x_block + static_cast<size_t>(b) * x_stride_0;
        const size_t out_batch = out_block + static_cast<size_t>(b) * out_stride_0;

        // Stage the tile: each element is read from global memory once and
        // then reused for all m output rows.
        for (uint k = 0; k < m; k++) {
            const size_t x_row = x_batch + static_cast<size_t>(k) * x_stride_4;
            for (uint j = tid; j < l_m; j += n_threads) {
                temp[k * l_m + j] = x[x_row + static_cast<size_t>(j) * x_stride_5];
            }
        }
        __syncthreads();

        for (uint c = 0; c < m; c++) {
            const size_t out_row = out_batch + static_cast<size_t>(c) * out_stride_2;

            for (uint j = tid; j < l_m; j += n_threads) {
                // Per-column factor (j + i*j); it does not depend on k or c.
                const __half j_half = __float2half(static_cast<float>(j));
                const complex_half_t factor(j_half, j_half);

                // Accumulate in a register; the k = 0 term seeds the sum.
                complex_half_t acc = temp[j] * factor;
                for (uint k = 1; k < m; k++) {
                    acc += temp[k * l_m + j] * factor;
                }

                out[out_row + static_cast<size_t>(j) * out_stride_4] = acc;
            }
        }

        // Do not let the next batch overwrite the tile while it is still in use.
        __syncthreads();
    }
}


// __device__ __forceinline__ __half taylor_(__half x) {
//     return __float2half(sinf(__half2float(x)));
// }

// template <int BLOCK_DIM_X>
// __global__ void butterfly_taylor_cuda_kernel(
//     const c10::complex<at::Half> *__restrict__ x,
//     c10::complex<at::Half> *__restrict__ out,
//     uint B,
//     uint N,
//     uint m,
//     uint l_m,
//     uint x_stride_0,
//     uint x_stride_1,
//     uint x_stride_2,
//     uint x_stride_3,
//     uint x_stride_4,
//     uint x_stride_5,
//     uint out_stride_0,
//     uint out_stride_1,
//     uint out_stride_2,
//     uint out_stride_3,
//     uint out_stride_4
// )
// {

//     __shared__ complex_half_t temp[BLOCK_DIM_X];

//     for(int b = threadIdx.y; b < B; b++){
//         for(int j = threadIdx.x; j < l_m; j+=blockDim.x){
//             temp[j] = complex_half_t(0.0f, 0.0f);
//         }

//         __syncthreads();


//         for(int k = 0; k < m; k++){
//             for(int j = threadIdx.x; j <  l_m; j+=blockDim.x){
//                 temp[j] += x[0 * x_stride_0 + blockIdx.x * x_stride_1 + 0 * x_stride_2 + blockIdx.y * x_stride_3 + k * x_stride_4 + j * x_stride_5]
//                 * complex_half_t(hcos(__hmul(__hmul(__float2half(2.0f),  __float2half(3.141592654)) , __float2half(j * 1.0))), __hneg(hsin(__hmul(__hmul(__float2half(2.0f) , __float2half(3.141592654)), __float2half(j * 1.0))))); 
//             }
//         }

//         __syncthreads();

//         for(int j = threadIdx.x; j < l_m; j+=blockDim.x){
//             out[0 * out_stride_0 + blockIdx.x * out_stride_1 + 0 * out_stride_2 + blockIdx.y * out_stride_3 + j * out_stride_4] = temp[j];
//         }
//     }
// }


/**
 * Compute-only kernel: evaluates half-precision cos/sin pairs and accumulates
 * their complex product, without reading any input tensor.
 *
 * For every even column e < l_m, with
 *   term(e) = (cos(2*pi*(e+1)) + i*sin(2*pi*(e+1))) * (cos(2*pi*e) + i*sin(2*pi*e))
 * evaluated with h2cos / h2sin on __half2 operands, both temp[e] and
 * temp[e + 1] (when e + 1 < l_m) accumulate term(e) once per k in [0, m).
 * The accumulated row is then written to
 *
 *   out[0, blockIdx.x, 0, blockIdx.y, j]   for j in [0, l_m)
 *
 * The output offset does not depend on the batch index, so each of the B
 * batch iterations recomputes and rewrites the same values.
 * NOTE(review): this looks like a throughput benchmark rather than a
 * numerically meaningful transform - confirm with the caller.
 *
 * Launch requirements:
 *   - Dynamic shared memory must hold at least l_m complex_half_t elements.
 *   - Any block shape is accepted: all threads of the block cooperate on the
 *     l_m columns.
 *
 * `N`, `out_stride_0` and `out_stride_2` are accepted but do not affect the
 * result.
 */
__global__ void const_compute_kernel(
    c10::complex<at::Half> *__restrict__ out,
    uint B,
    uint N,
    uint m,
    uint l_m,
    uint out_stride_0,
    uint out_stride_1,
    uint out_stride_2,
    uint out_stride_3,
    uint out_stride_4
)
{
    // 2*pi replicated in both lanes; float literal avoids double arithmetic.
    const __half two_pi = __float2half(6.28318530718f);
    const __half2 half2_2_pi = __halves2half2(two_pi, two_pi);

    // Accumulation row, one element per output column.
    extern __shared__ complex_half_t temp[];

    // Linear thread id within the block. Every thread runs the same batch
    // loop, so each __syncthreads() below is reached by the whole block.
    const uint tid = (threadIdx.z * blockDim.y + threadIdx.y) * blockDim.x + threadIdx.x;
    const uint n_threads = blockDim.x * blockDim.y * blockDim.z;

    // Offset is built in size_t so that index * stride cannot wrap at 2^32.
    const size_t out_block = static_cast<size_t>(blockIdx.x) * out_stride_1
                           + static_cast<size_t>(blockIdx.y) * out_stride_3;

    for (uint b = 0; b < B; b++) {
        for (uint j = tid; j < l_m; j += n_threads) {
            temp[j] = complex_half_t(0.0f, 0.0f);
        }

        // The accumulation below uses a different thread-to-column mapping
        // (two columns per thread) than the zero-fill above.
        __syncthreads();

        for (uint k = 0; k < m; k++) {
            for (uint j = 2 * tid; j < l_m; j += 2 * n_threads) {
                // Lane x carries column j, lane y carries column j + 1.
                const __half2 cols = __halves2half2(__float2half(static_cast<float>(j)),
                                                    __float2half(static_cast<float>(j + 1)));
                const __half2 angle = __hmul2(half2_2_pi, cols);
                const __half2 cos_pair = h2cos(angle);
                const __half2 sin_pair = h2sin(angle);

                const complex_half_t term = complex_half_t(cos_pair.y, sin_pair.y)
                                          * complex_half_t(cos_pair.x, sin_pair.x);

                temp[j] += term;
                // When l_m is odd the last pair has no second column; skip it
                // instead of writing past the l_m-element row.
                if (j + 1 < l_m) {
                    temp[j + 1] += term;
                }
            }
        }

        __syncthreads();

        for (uint j = tid; j < l_m; j += n_threads) {
            out[out_block + static_cast<size_t>(j) * out_stride_4] = temp[j];
        }
    }
}