#include <torch/extension.h>

#include <vector>
#include <stdio.h>
#include <mma.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "shared.h"

using namespace nvcuda;

// Tensor-core kernel used by butterfly_ifft_padded_cuda when d_f is 64x64.
//
// For each of TILE_H consecutive heads, a block
//   1. stages a 64-row x 32-half2 (64-half) tile of x_real / x_imag in shared memory,
//   2. multiplies it element-wise (complex multiply) by the matching twiddle-factor tile,
//   3. multiplies the result by the complex matrix d_f using 16x16x16 half WMMA fragments,
//      keeping only the real part (a_real * b_real - a_imag * b_imag),
//   4. copies the resulting tile to out_real, skipping every half2 whose flat per-head
//      index is >= M / 2.
//
// Indexing, as written in the code:
//   - blockIdx.x selects the 32-half2 wide column tile; rows are addressed with a
//     hard-coded stride of 32 * 512 half2.
//   - blockIdx.y and blockIdx.z select the batch entry and the group of TILE_H heads
//     (the launcher in this file sets gridDim.y to the batch size).
//   - threadIdx.x is used as the half2 column inside the tile and threadIdx.y * 16 as the
//     16-half column strip a warp owns, so the code assumes blockDim == (32, 4).
//
// Dynamic shared memory: seven N * N half buffers (N = 64), i.e. 7 * 64 * 64 * 2 = 57344 bytes.
//
// Parameters:
//   x_real, x_imag        input planes, read as half2
//   d_f                   64 * 64 complex matrix, read through .real() / .imag()
//   twiddle_factors_*     twiddle planes, read as half2 with the same row stride as x
//   out_real              output plane, written as half2
//   B                     unused inside the kernel
//   H                     multiplier of the batch (blockIdx.y) stride
//   M                     output length per head in halfs; only M / 2 half2 are written
//
// NOTE(review): the batch term of both offsets is blockIdx.y * H * TILE_H * ..., so the
// batch stride only matches the blockIdx.z * TILE_H head tiling if H is the number of
// TILE_H-sized head groups (total heads / TILE_H) - confirm against the launch site.
// NOTE(review): the offsets are stored in 32-bit ints; large B * H products may overflow.
template <int TILE_H>
__global__ void butterfly_ifft_padded_cuda_kernel_64(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    uint B,
    uint H,
    int M)
{
    // Exclusive upper bound on the flat per-head half2 index that gets written to out_real
    // (exclusive because the guard below uses '<').
    const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
    // Offsets, in half2 units, of the first head handled by this block.
    const int out_offset = blockIdx.y * H * TILE_H * M/2 + blockIdx.z * TILE_H * M/2;
    const int in_offset = blockIdx.y * H * TILE_H * 64 * 32 * 512 + blockIdx.z * TILE_H * 64 * 32 * 512;
    int idx;
    int t_offset;
    int out_t_offset;
    int shared_offset;
    const int N = 64;

    // Carve the dynamic shared memory into seven consecutive N x N half buffers.
    extern __shared__ half x_real_shared[];
    half *x_imag_shared = &x_real_shared[N * N];
    half *d_f_real = &x_imag_shared[N * N];
    half *d_f_imag = &d_f_real[N * N];
    half *twiddles_real_shared = &d_f_imag[N * N];
    half *twiddles_imag_shared = &twiddles_real_shared[N * N];
    half *out_real_shared = &twiddles_imag_shared[N * N];

    half tmp_real, tmp_imag;

    // a_*: all 4x4 tiles of d_f; tw_* / b_*: the four row tiles of this warp's 16-column strip.
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[4][4];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[4][4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[4];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[4];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[4];

    // Stage the head-independent data once: this block's twiddle tile and the whole d_f matrix.
    // #pragma unroll
    for (int i = threadIdx.y; i < N; i+=blockDim.y)
    {
        idx = i * 32 * 512 + blockIdx.x * 32 + threadIdx.x;
        shared_offset = i * 32 + threadIdx.x;
        reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
        reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];

        // Split row i of d_f into real and imaginary planes; each thread copies element
        // threadIdx.x and element threadIdx.x + blockDim.x of the 64-wide row.
        // #pragma unroll
        shared_offset = i * 64 + threadIdx.x;
        d_f_real[shared_offset] = d_f[shared_offset].real();
        d_f_imag[shared_offset] = d_f[shared_offset].imag();

        d_f_real[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].real();
        d_f_imag[shared_offset + blockDim.x] = d_f[shared_offset + blockDim.x].imag();
    }

    __syncthreads();

    // Load the head-independent fragments into registers. The a fragments are col_major with
    // leading dimension N, so element (r, c) of a_frag[i][j] comes from
    // d_f[(j * 16 + c) * N + (i * 16 + r)]: the product below applies d_f transposed relative
    // to its row-major storage.
    for (int i = 0; i < 4; i++)
    {
#pragma unroll
        for (int j = 0; j < 4; j++)
        {
            wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
            wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
        }
        wmma::load_matrix_sync(tw_frag_real[i], twiddles_real_shared + i * N * 16 + threadIdx.y * 16, N);
        wmma::load_matrix_sync(tw_frag_imag[i], twiddles_imag_shared + i * N * 16 + threadIdx.y * 16, N);
    }

    // One iteration per head of this block's head group.
    for (int t = 0; t < TILE_H; t++)
    {

        out_t_offset = t * M/2;
        t_offset = t * 64 * 32 * 512;

        // Stage this head's input tile (64 rows x 32 half2) in shared memory.
        for (int i = threadIdx.y; i < N; i+=blockDim.y)
        {
            idx = i * 32 * 512 + blockIdx.x * 32 + threadIdx.x;
            shared_offset = i * 32 + threadIdx.x;
            reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
            reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
        }

        __syncthreads();

        // Each warp loads the four row tiles of its own 16-column strip.
        for (int i = 0; i < 4; i++)
        {
            wmma::load_matrix_sync(b_frag_real[i], x_real_shared + i * N * 16 + threadIdx.y * 16, N);
            wmma::load_matrix_sync(b_frag_imag[i], x_imag_shared + i * N * 16 + threadIdx.y * 16, N);
        }

        // Element-wise complex multiply by the twiddles, done in registers. The tw and b
        // fragments have the same type and were loaded with the same offsets and stride,
        // so element k of each refers to the same matrix position.
        for (int j = 0; j < 4; j++)
        {
            for (int k = 0; k < tw_frag_real[j].num_elements; k++)
            {
                tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
                tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
                b_frag_real[j].x[k] = tmp_real;
                b_frag_imag[j].x[k] = tmp_imag;
            }
        }

        // First pass: accumulate the imaginary * imaginary product and negate it, giving -(b*d).
        for (int i = 0; i < 4; i++)
        {
            wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));

// bd
#pragma unroll
            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
            }

            for (int k = 0; k < acc_frag_real[i].num_elements; k++)
            {
                acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
            }
        }

        // Second pass: add the real * real product on top, giving a*c - b*d (the real part).
        for (int i = 0; i < 4; i++)
        {
// ac - bd
#pragma unroll
            for (int k = 0; k < 4; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
            }
        }

        // Write this warp's 16-column strip of the result into the shared output tile.
#pragma unroll
        for (int i = 0; i < 4; i++)
        {
            wmma::store_matrix_sync(out_real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Copy the tile to global memory as half2, dropping everything at or beyond max_idx.
#pragma unroll
        for (int i = threadIdx.y; i < N; i+=blockDim.y)
        {
            idx = i * 32 * 512 + blockIdx.x * 32 + threadIdx.x;
            shared_offset = i * 32 + threadIdx.x;

            if(idx < max_idx){
                out_real[out_offset + out_t_offset + idx] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
            }
        }

        // Keep the shared tiles intact until every thread has finished with this head.
        __syncthreads();
    }
}

// Tensor-core kernel used by butterfly_ifft_padded_cuda when d_f is 32x32.
//
// Each block handles one (blockIdx.y, blockIdx.z) pair and one 32-half2 (64-half) wide
// column tile selected by blockIdx.x. It
//   1. stages a 32-row x 64-half tile of x_real / x_imag and of the twiddle factors, plus the
//      32x32 matrix d_f split into real and imaginary planes, in static shared memory,
//   2. multiplies the input tile element-wise (complex multiply) by the twiddle tile,
//   3. multiplies the result by d_f using 16x16x16 half WMMA fragments, keeping only the
//      real part (a_real * b_real - a_imag * b_imag),
//   4. copies the tile to out_real, skipping every half2 whose flat per-head index is >= M / 2.
//
// Rows are addressed with a hard-coded stride of 32 * 512 half2. threadIdx.x is used as the
// half2 column inside the tile (and as the d_f column), so the code assumes blockDim.x == 32.
// Only warps with threadIdx.y < 2 perform the MMA work; each owns a 32-half wide strip.
//
// Static shared memory: 5 * (32 * 64) + 2 * (32 * 32) halfs = 24576 bytes.
//
// Parameters:
//   x_real, x_imag        input planes, read as half2
//   d_f                   32 * 32 complex matrix, read through .real() / .imag()
//   twiddle_factors_*     twiddle planes, read as half2 with the same row stride as x
//   out_real              output plane, written as half2
//   B                     unused inside the kernel
//   H                     multiplier of the blockIdx.y stride (heads per batch entry in the
//                         launcher of this file)
//   M                     output length per head in halfs; only M / 2 half2 are written
//
// NOTE(review): the offsets are stored in 32-bit ints; large B * H products may overflow.
__global__ void butterfly_ifft_padded_cuda_kernel_32(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    uint B,
    uint H,
    int M)
{
    // Exclusive upper bound on the flat per-head half2 index that gets written to out_real
    // (exclusive because the guard below uses '<').
    const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
    const int N  = 32;
    int idx;
    int shared_offset;

    // Offsets, in half2 units, of the head handled by this block.
    const int out_offset  =  blockIdx.y * H * M / 2 + blockIdx.z * M / 2; 
    const int in_offset = blockIdx.y * H * 32 * 32 * 512 + blockIdx.z * 32 * 32 * 512;


    // Tiles are 32 rows x 64 halfs; d_f planes are 32 x 32.
    __shared__ half x_real_shared[32 * 64];
    __shared__ half x_imag_shared[32 * 64];
    __shared__ half d_f_real[32 * 32];
    __shared__ half d_f_imag[32 * 32];
    __shared__ half twiddles_real_shared[32 * 64];
    __shared__ half twiddles_imag_shared[32 * 64];
    __shared__ half out_real_shared[32 * 64];

    // Stage the input tile, the twiddle tile and d_f in shared memory.
    // #pragma unroll
    for (int i = threadIdx.y; i < N; i+=blockDim.y)
    {
        idx = i * 32 * 512 + blockIdx.x * 32 + threadIdx.x;
        // Note: this declaration shadows the function-scope shared_offset.
        int shared_offset = i * 32 + threadIdx.x;
        reinterpret_cast<__half2 *>(x_real_shared)[shared_offset] = x_real[in_offset  + idx];
        reinterpret_cast<__half2 *>(x_imag_shared)[shared_offset] = x_imag[in_offset  + idx];
        reinterpret_cast<__half2 *>(twiddles_real_shared)[shared_offset] = twiddle_factors_real[idx];
        reinterpret_cast<__half2 *>(twiddles_imag_shared)[shared_offset] = twiddle_factors_imag[idx];

        // The same offset addresses element (i, threadIdx.x) of the 32-wide d_f matrix.
        // #pragma unroll
        d_f_real[shared_offset] = d_f[shared_offset].real();
        d_f_imag[shared_offset] = d_f[shared_offset].imag();
    }

    __syncthreads();

    // Only the first N / 16 = 2 warps (by threadIdx.y) compute; the branch is uniform per warp
    // and contains no block-level barrier.
    if (threadIdx.y < N / 16)
    {
        half tmp_real, tmp_imag;

        // 2x2 grids of 16x16 tiles. acc_frag_imag is declared but never used.
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[2][2];
        wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[2][2];
        wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[2][2];
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[2][2];
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[2][2];

        // Column offset (in halfs) of the 32-wide strip owned by this warp.
        int t = threadIdx.y * 32;

        // Load d_f tiles (col_major, leading dimension N, so element (r, c) of a_frag[i][j]
        // comes from d_f[(j * 16 + c) * N + (i * 16 + r)]) and the input / twiddle tiles
        // (row_major, leading dimension 2 * N = 64 halfs).
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                wmma::load_matrix_sync(a_frag_real[i][j], d_f_real + j * N * 16 + i * 16, N);
                wmma::load_matrix_sync(a_frag_imag[i][j], d_f_imag + j * N * 16 + i * 16, N);
                wmma::load_matrix_sync(b_frag_real[i][j], x_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
                wmma::load_matrix_sync(b_frag_imag[i][j], x_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
                wmma::load_matrix_sync(tw_frag_real[i][j], twiddles_real_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
                wmma::load_matrix_sync(tw_frag_imag[i][j], twiddles_imag_shared + i * 2 * N * 16 + j * 16 + t, 2 * N);
            }
        }

        // Element-wise complex multiply of the input by the twiddles, done in registers. The
        // tw and b fragments have the same type and were loaded with the same offsets and
        // stride, so element k of each refers to the same matrix position.
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                for (int k = 0; k < tw_frag_real[i][j].num_elements; k++)
                {
                    tmp_real = __hsub(__hmul(tw_frag_real[i][j].x[k], b_frag_real[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_imag[i][j].x[k]));
                    tmp_imag = __hadd(__hmul(tw_frag_real[i][j].x[k], b_frag_imag[i][j].x[k]), __hmul(tw_frag_imag[i][j].x[k], b_frag_real[i][j].x[k]));
                    b_frag_real[i][j].x[k] = tmp_real;
                    b_frag_imag[i][j].x[k] = tmp_imag;
                }
            }
        }

        // First pass: accumulate the imaginary * imaginary product and negate it, giving -(b*d).
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                wmma::fill_fragment(acc_frag_real[i][j], __float2half(0.0f));

                // bd
                for (int k = 0; k < 2; k++)
                {
                    wmma::mma_sync(acc_frag_real[i][j], a_frag_imag[i][k], b_frag_imag[k][j], acc_frag_real[i][j]);
                }

                for (int k = 0; k < acc_frag_real[i][j].num_elements; k++)
                {
                    acc_frag_real[i][j].x[k] = __hneg(acc_frag_real[i][j].x[k]);
                }
            }
        }

        // Second pass: add the real * real product on top, giving a*c - b*d (the real part).
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                // ac - bd
                for (int k = 0; k < 2; k++)
                {
                    wmma::mma_sync(acc_frag_real[i][j], a_frag_real[i][k], b_frag_real[k][j], acc_frag_real[i][j]);
                }
            }
        }

        // Write this warp's strip of the result into the shared output tile.
        for (int i = 0; i < 2; i++)
        {
            for (int j = 0; j < 2; j++)
            {
                wmma::store_matrix_sync(out_real_shared + i * 2 * N * 16 + j * 16 + t, acc_frag_real[i][j], 2 * N, wmma::mem_row_major);
            }
        }
    }

    // Reached by every thread of the block, including the warps that skipped the MMA work.
    __syncthreads();

    // Copy the tile to global memory as half2, dropping everything at or beyond max_idx.
#pragma unroll
    for (int i = threadIdx.y; i < N; i+=blockDim.y)
    {
        idx = i * 32 * 512 + blockIdx.x * 32 + threadIdx.x;
        shared_offset = i * 32 + threadIdx.x;

        if(idx < max_idx)
            out_real[idx +  out_offset] = reinterpret_cast<__half2 *>(out_real_shared)[shared_offset];
        
    }
}


// Tensor-core kernel used by butterfly_ifft_padded_cuda when d_f is 128x128.
//
// Same computation as the 64-point variant, but with only two shared buffers that are
// reused in turn for d_f, the twiddle tile, the input tile and (real_shared only) the
// output tile. For each of TILE_H consecutive heads, a block
//   1. stages a 128-row x 64-half2 (128-half) tile of x_real / x_imag in shared memory,
//   2. multiplies it element-wise (complex multiply) by the twiddle fragments held in registers,
//   3. multiplies the result by d_f using 16x16x16 half WMMA fragments, keeping only the
//      real part (a_real * b_real - a_imag * b_imag),
//   4. copies the tile to out_real, skipping every half2 whose flat per-head index is >= M / 2.
//
// Indexing, as written in the code:
//   - blockIdx.x selects the 64-half2 wide column tile; rows are addressed with a hard-coded
//     stride of 32 * 512 half2.
//   - blockIdx.y and blockIdx.z select the batch entry and the group of TILE_H heads
//     (the launcher in this file sets gridDim.y to the batch size).
//   - Each thread copies elements threadIdx.x + j * blockDim.x of a row and each warp owns
//     the 16-half column strip at threadIdx.y * 16, so the code assumes blockDim == (32, 8).
//
// Dynamic shared memory: two 128 * 128 half buffers, i.e. 2 * 128 * 128 * 2 = 65536 bytes;
// the launch must provide at least that much.
//
// Parameters:
//   x_real, x_imag        input planes, read as half2
//   d_f                   128 * 128 complex matrix, read through .real() / .imag()
//   twiddle_factors_*     twiddle planes, read as half2 with the same row stride as x
//   out_real              output plane, written as half2
//   B                     unused inside the kernel
//   H                     multiplier of the batch (blockIdx.y) stride
//   M                     output length per head in halfs; only M / 2 half2 are written
//
// NOTE(review): the batch term of both offsets is blockIdx.y * H * TILE_H * ..., so the
// batch stride only matches the blockIdx.z * TILE_H head tiling if H is the number of
// TILE_H-sized head groups (total heads / TILE_H) - confirm against the launch site.
// NOTE(review): the offsets are stored in 32-bit ints; large B * H products may overflow.
// NOTE(review): every thread holds 2 * 64 matrix_a fragments plus 32 matrix_b fragments;
// register spilling looks likely - worth checking with the compiler's resource usage output.
template <int TILE_H>
__global__ void butterfly_ifft_padded_cuda_kernel_128(
    const __half2 *__restrict__ x_real,
    const __half2 *__restrict__ x_imag,
    const complex_half_t *__restrict__ d_f,
    const __half2 *__restrict__ twiddle_factors_real,
    const __half2 *__restrict__ twiddle_factors_imag,
    __half2 *__restrict__ out_real,
    uint B,
    uint H,
    int M)
{
    // Exclusive upper bound on the flat per-head half2 index that gets written to out_real
    // (exclusive because the guard below uses '<').
    const int max_idx = M / 2; //actually should be -1 since indices are 0-based but we are using < instead of <=
    // Offsets, in half2 units, of the first head handled by this block.
    const int out_offset = blockIdx.y * H * TILE_H * M/2 + blockIdx.z * TILE_H *  M/2;
    const int in_offset = blockIdx.y * H * TILE_H * 128 * 32 * 512 + blockIdx.z * TILE_H * 128 * 32 * 512;
    const int N = 128;
    int idx;
    int t_offset;
    int out_t_offset;
    int shared_offset;


    // Two 128 x 128 half buffers, reused for every staging step below.
    extern __shared__ half real_shared[];
    half *imag_shared = &real_shared[128 * 128];

    half tmp_real, tmp_imag;

    // a_*: all 8x8 tiles of d_f; tw_* / b_*: the eight row tiles of this warp's 16-column strip.
    // acc_frag_imag is declared but never used.
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_real[8][8];
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::col_major> a_frag_imag[8][8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> tw_frag_imag[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_real[8];
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> b_frag_imag[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_real[8];
    wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc_frag_imag[8];

    // Step 1: stage d_f, split into real and imaginary planes; each thread copies four
    // elements of every row it visits (128 columns with blockDim.x == 32).
    for (int i = threadIdx.y; i < N; i+=blockDim.y)
    {
        for(int j=0; j< 4; j++){
            shared_offset = i * 128 + threadIdx.x + j * blockDim.x;
            real_shared[shared_offset] = d_f[shared_offset].real();
            imag_shared[shared_offset] = d_f[shared_offset].imag();
        }
    }

    __syncthreads();


    // Pull the whole matrix into register fragments. They are col_major with leading
    // dimension 128, so element (r, c) of a_frag[i][j] comes from
    // d_f[(j * 16 + c) * 128 + (i * 16 + r)]: the product below applies d_f transposed
    // relative to its row-major storage.
    for (int i = 0; i < 8; i++){
        for (int j = 0; j < 8; j++){
            wmma::load_matrix_sync(a_frag_real[i][j], real_shared + j * 128 * 16 + i * 16, 128);
            wmma::load_matrix_sync(a_frag_imag[i][j], imag_shared + j * 128 * 16 + i * 16, 128);
        }
    }


    // All warps must have finished reading d_f before the buffers are overwritten.
    __syncthreads();

    // Step 2: stage this block's twiddle tile (128 rows x 64 half2) in the same buffers.
    for (int i = threadIdx.y; i < N; i+=blockDim.y)
    {
        for(int j=0; j< 2; j++){
            idx = i * 32 * 512 + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
            shared_offset = i * 64 + threadIdx.x + j * blockDim.x;    
            reinterpret_cast<__half2*>(real_shared)[shared_offset] = twiddle_factors_real[idx];
            reinterpret_cast<__half2*>(imag_shared)[shared_offset] = twiddle_factors_imag[idx];
        }
    }

    __syncthreads();


    // Each warp keeps the twiddles of its own 16-column strip in registers.
    for (int i = 0; i < 8; i++){
        wmma::load_matrix_sync(tw_frag_real[i], real_shared + i * 128 * 16 + threadIdx.y * 16, 128);
        wmma::load_matrix_sync(tw_frag_imag[i], imag_shared + i * 128 * 16 + threadIdx.y * 16, 128);
    }

    // The twiddle tile must be fully consumed before input tiles overwrite the buffers.
    __syncthreads();

    // One iteration per head of this block's head group.
    for (int t = 0; t < TILE_H; t++)
    {

        out_t_offset = t * M/2;
        t_offset = t * 128 * 32 * 512;

        // Stage this head's input tile (128 rows x 64 half2).
        for (int i = threadIdx.y; i < N; i+=blockDim.y)
        {
            for(int j=0; j< 2; j++){
                idx = i * 32 * 512 + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
                shared_offset = i * 64 + threadIdx.x + j * blockDim.x;  
                reinterpret_cast<__half2*>(real_shared)[shared_offset] = x_real[idx + in_offset + t_offset];
                reinterpret_cast<__half2*>(imag_shared)[shared_offset] = x_imag[idx + in_offset + t_offset];
            }
        }

        __syncthreads();

        // Each warp loads the eight row tiles of its own 16-column strip.
        for (int i = 0; i < 8; i++)
        {
            wmma::load_matrix_sync(b_frag_real[i], real_shared + i * N * 16 + threadIdx.y * 16, N);
            wmma::load_matrix_sync(b_frag_imag[i], imag_shared + i * N * 16 + threadIdx.y * 16, N);
        }


        // Element-wise complex multiply by the twiddles, done in registers. The tw and b
        // fragments have the same type and were loaded with the same offsets and stride,
        // so element k of each refers to the same matrix position.
        for (int j = 0; j < 8; j++)
        {
            for (int k = 0; k < tw_frag_real[j].num_elements; k++)
            {
                tmp_real = __hsub(__hmul(tw_frag_real[j].x[k], b_frag_real[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_imag[j].x[k]));
                tmp_imag = __hadd(__hmul(tw_frag_real[j].x[k], b_frag_imag[j].x[k]), __hmul(tw_frag_imag[j].x[k], b_frag_real[j].x[k]));
                b_frag_real[j].x[k] = tmp_real;
                b_frag_imag[j].x[k] = tmp_imag;
            }
        }

        // First pass: accumulate the imaginary * imaginary product and negate it, giving -(b*d).
        for (int i = 0; i < 8; i++)
        {
            wmma::fill_fragment(acc_frag_real[i], __float2half(0.0f));

// bd
#pragma unroll
            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_imag[i][k], b_frag_imag[k], acc_frag_real[i]);
            }

            for (int k = 0; k < acc_frag_real[i].num_elements; k++)
            {
                acc_frag_real[i].x[k] = __hneg(acc_frag_real[i].x[k]);
            }
        }

        // Second pass: add the real * real product on top, giving a*c - b*d (the real part).
        for (int i = 0; i < 8; i++)
        {
// ac - bd
#pragma unroll
            for (int k = 0; k < 8; k++)
            {
                wmma::mma_sync(acc_frag_real[i], a_frag_real[i][k], b_frag_real[k], acc_frag_real[i]);
            }
        }

        // Store the result over the input tile in real_shared. No barrier is placed before this
        // store: each warp writes the same 16-column strip (threadIdx.y * 16) it loaded its
        // b fragments from, so it does not touch columns another warp reads.
#pragma unroll
        for (int i = 0; i < 8; i++)
        {
            //wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
            wmma::store_matrix_sync(real_shared + i * N * 16 + threadIdx.y * 16, acc_frag_real[i], N, wmma::mem_row_major);
        }

        __syncthreads();

        // Copy the tile to global memory as half2, dropping everything at or beyond max_idx.
#pragma unroll
        for (int i = threadIdx.y; i < N; i+=blockDim.y)
        {
            for(int j=0; j< 2; j++){
                idx = i * 32 * 512 + j * blockDim.x + blockIdx.x * 64 + threadIdx.x;
                shared_offset = i * 64 + threadIdx.x + j * blockDim.x;
                if(idx < max_idx) 
                    out_real[idx + out_offset + out_t_offset] = reinterpret_cast<__half2*>(real_shared)[shared_offset];
            }
        }

        // The output tile must be fully written out before the next head's input overwrites it.
        __syncthreads();
    }
}

/**
 * Host entry point: picks the 32-, 64- or 128-point kernel from d_f.size(0), launches it and
 * returns the real output plane.
 *
 * @param x_real, x_imag          Input planes. size(0) is used as the batch count B and
 *                                size(1) as the head count H. The data is handed to the
 *                                kernels as __half2, which index it as a flat buffer with
 *                                d_f.size(0) rows of 32 * 512 half2 per head.
 * @param d_f                     Square complex matrix; d_f.size(0) must be 32, 64 or 128.
 * @param twiddle_factors_real,
 *        twiddle_factors_imag    Twiddle planes, handed to the kernels as __half2.
 * @param N                       Output length per head, in halfs.
 * @return                        Tensor of shape {B, H, N} with the options of x_real.
 *
 * Raises a c10::Error (via TORCH_CHECK) when an input is not a contiguous CUDA tensor, when
 * d_f has an unsupported size, when H is not a multiple of the head tile for the 64/128
 * paths, or when CUDA reports a configuration / launch error.
 *
 * NOTE(review): kernels are launched without a stream argument, i.e. on the default stream
 * rather than on PyTorch's current stream - confirm that this is intended.
 */
torch::Tensor butterfly_ifft_padded_cuda(
    torch::Tensor x_real,
    torch::Tensor x_imag,
    torch::Tensor d_f,
    torch::Tensor twiddle_factors_real,
    torch::Tensor twiddle_factors_imag,
    int N
    )
{
    // Heads processed per block by the 64- and 128-point kernels.
    constexpr int TILE_H = 16;
    // Dynamic shared memory the kernels actually index:
    //   64-point : seven 64 x 64 half buffers
    //   128-point: two 128 x 128 half buffers (exactly 64 KiB)
    constexpr int kSmemBytes64 = static_cast<int>(7 * 64 * 64 * sizeof(half));
    constexpr int kSmemBytes128 = static_cast<int>(2 * 128 * 128 * sizeof(half));

    // The kernels dereference raw device pointers with flat indexing.
    for (const torch::Tensor &t : {x_real, x_imag, d_f, twiddle_factors_real, twiddle_factors_imag})
    {
        TORCH_CHECK(t.is_cuda(), "butterfly_ifft_padded_cuda: all inputs must be CUDA tensors");
        TORCH_CHECK(t.is_contiguous(), "butterfly_ifft_padded_cuda: all inputs must be contiguous");
    }

    const uint B = x_real.size(0);
    const uint H = x_real.size(1);
    const int d_f_size = d_f.size(0);

    // Fail loudly instead of returning an uninitialised tensor.
    TORCH_CHECK(d_f_size == 32 || d_f_size == 64 || d_f_size == 128,
                "Invalid d_f_size: ", d_f_size);

    // The tiled kernels always process TILE_H heads per block; a remainder would be left
    // uncomputed in the (uninitialised) output.
    TORCH_CHECK(d_f_size == 32 || H % TILE_H == 0,
                "butterfly_ifft_padded_cuda: H (", H, ") must be a multiple of ", TILE_H,
                " when d_f has size ", d_f_size);

    torch::Tensor out_real = torch::empty(
        {static_cast<int64_t>(B), static_cast<int64_t>(H), static_cast<int64_t>(N)},
        x_real.options());

    // Nothing to compute; also avoids launching with a zero-sized grid dimension.
    if (out_real.numel() == 0)
    {
        return out_real;
    }

    const __half2 *x_real_ptr = static_cast<__half2 *>(x_real.data_ptr());
    const __half2 *x_imag_ptr = static_cast<__half2 *>(x_imag.data_ptr());
    const complex_half_t *d_f_ptr = static_cast<complex_half_t *>(d_f.data_ptr());
    const __half2 *tw_real_ptr = static_cast<__half2 *>(twiddle_factors_real.data_ptr());
    const __half2 *tw_imag_ptr = static_cast<__half2 *>(twiddle_factors_imag.data_ptr());
    __half2 *out_ptr = static_cast<__half2 *>(out_real.data_ptr());

    // Default configuration (32- and 64-point kernels): 4 warps, 512 column tiles of 32 half2.
    dim3 gridDim(512, B, H);
    dim3 blockDim(32, 4);

    // The tiled kernels scale their batch stride by TILE_H themselves, so they must be given
    // the number of head tiles rather than the number of heads.
    const uint head_tiles = H / TILE_H;

    cudaError_t status = cudaSuccess;

    switch(d_f_size){
        case 32:
            butterfly_ifft_padded_cuda_kernel_32<<<gridDim, blockDim>>>(
                x_real_ptr, x_imag_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_ptr,
                B,
                H,
                N);
            break;

        case 64:
            gridDim.z = head_tiles;
            // More than 48 KiB of dynamic shared memory needs an explicit opt-in.
            status = cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_64<TILE_H>, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes64);
            TORCH_CHECK(status == cudaSuccess,
                        "butterfly_ifft_padded_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(status));
            butterfly_ifft_padded_cuda_kernel_64<TILE_H><<<gridDim, blockDim, kSmemBytes64>>>(
                x_real_ptr, x_imag_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_ptr,
                B,
                head_tiles,
                N);
            break;

        case 128:
            // 8 warps cover the 128-half wide tile; 256 column tiles of 64 half2.
            blockDim.x = 32;
            blockDim.y = 8;
            gridDim.x = 256;
            gridDim.z = head_tiles;
            status = cudaFuncSetAttribute(&butterfly_ifft_padded_cuda_kernel_128<TILE_H>, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemBytes128);
            TORCH_CHECK(status == cudaSuccess,
                        "butterfly_ifft_padded_cuda: cudaFuncSetAttribute failed: ", cudaGetErrorString(status));
            butterfly_ifft_padded_cuda_kernel_128<TILE_H><<<gridDim, blockDim, kSmemBytes128>>>(
                x_real_ptr, x_imag_ptr, d_f_ptr, tw_real_ptr, tw_imag_ptr, out_ptr,
                B,
                head_tiles,
                N);
            break;

        default:
            // Unreachable: d_f_size was validated above.
            TORCH_CHECK(false, "Invalid d_f_size: ", d_f_size);
    }

    // Kernel launches do not return a status; configuration errors surface here.
    status = cudaGetLastError();
    TORCH_CHECK(status == cudaSuccess,
                "butterfly_ifft_padded_cuda: kernel launch failed: ", cudaGetErrorString(status));

    return out_real;
}
