#include <cuda_runtime.h>
#include <mma.h>
#include <cstdio>
#include <cstdlib>
#include <cublas_v2.h>

#include "twiddle.cuh"

using namespace nvcuda;
// Number of threads in one warp.
constexpr int WARP_SIZE = 32;
// Tile dimensions (M x N x K) used for every wmma::fragment in this file.
constexpr int WMMA_M = 16;
constexpr int WMMA_N = 16;
constexpr int WMMA_K = 16;
// Only referenced by the disabled staging code below (as a multiplier of the
// per-block element count) -- exact meaning to be confirmed against callers.
constexpr int CONT_SIZE = 16;

/**
 * Fills the calling lane's share of two matrix_a fragments with entries of a
 * 16x16 table looked up from cufftdx::database::detail::lut_sp_2_16.
 *
 * For every fragment element e in [0, 8) the lane derives a coordinate pair
 * (p, q) from its lane id and e, reads LUT entry (p * q) mod 16, and stores
 * the .x component in frag_F_real and the .y component in frag_F_imag, both
 * converted to half.
 *
 * NOTE(review): only elements 0..7 of each fragment are written; this relies
 * on a target-specific per-lane fragment layout -- confirm before porting.
 * The LUT presumably holds the 16th roots of unity; verify in twiddle.cuh.
 *
 * @param frag_F_real  receives the .x (real) components
 * @param frag_F_imag  receives the .y (imaginary) components
 */
inline __device__ void mma_twiddle_16(
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> &frag_F_real,
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> &frag_F_imag
) {
    const int lane = threadIdx.x & 31;
    // Per-lane base coordinates: two consecutive values per lane-in-quad, and
    // one value per quad of lanes.
    const int base_p = (lane & 3) << 1;
    const int base_q = lane >> 2;

    #pragma unroll
    for (int e = 0; e < 8; ++e) {
        // bit 0 of e -> +1 on p, bit 1 of e -> +8 on p, bit 2 of e -> +8 on q
        const int p = base_p + (((e >> 1) & 1) << 3) + (e & 1);
        const int q = base_q + ((e >> 2) << 3);
        // p * q is non-negative, so masking with 15 equals "% 16".
        float2 entry = cufftdx::database::detail::lut_sp_2_16[(p * q) & 15];
        frag_F_real.x[e] = __float2half(entry.x);
        frag_F_imag.x[e] = __float2half(entry.y);
    }
}

/**
 * Complex 16x16 matrix product on WMMA fragments: out = F * in, with the real
 * and imaginary parts held in separate fragments.
 *
 *   out_real = F_real * in_real - F_imag * in_imag
 *   out_imag = F_real * in_imag + F_imag * in_real
 *
 * Both output fragments are zeroed first, so any previous content is dropped.
 * Every mma_sync is a warp-wide operation; all lanes of the warp must call
 * this function together.
 */
__device__ inline void
complex_mul(wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> &frag_F_real, wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> &frag_F_imag,
            wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> &frag_in_real, wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> &frag_in_imag,
            wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> &frag_out_real, wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> &frag_out_imag)
{
    const half zero = __float2half(0.f);
    wmma::fill_fragment(frag_out_real, zero);
    wmma::fill_fragment(frag_out_imag, zero);

    // Real part: accumulate F_imag * in_imag, flip its sign, then add
    // F_real * in_real on top of the negated partial result.
    wmma::mma_sync(frag_out_real, frag_F_imag, frag_in_imag, frag_out_real);
    const int count = frag_out_real.num_elements;
    for (int e = 0; e < count; ++e)
    {
        frag_out_real.x[e] = __hneg(frag_out_real.x[e]);
    }
    wmma::mma_sync(frag_out_real, frag_F_real, frag_in_real, frag_out_real);

    // Imaginary part: the two cross terms simply accumulate.
    wmma::mma_sync(frag_out_imag, frag_F_real, frag_in_imag, frag_out_imag);
    wmma::mma_sync(frag_out_imag, frag_F_imag, frag_in_real, frag_out_imag);
}

/**
 * Returns the complex value exp(-2*pi*i*K/N) as a half2 (x = real part,
 * y = imaginary part).
 *
 * The angle is evaluated in single precision with sincospif and each
 * component is then rounded to half.
 *
 * @param N  denominator of the angle fraction (transform length)
 * @param K  numerator of the angle fraction (exponent)
 */
__device__ __host__ inline half2 W_N_K(int N, int K) {
    float sin_part;
    float cos_part;
    // sincospif takes its argument in units of pi, hence 2*K/N.
    const float turns_of_pi = float(2 * K) / N;
    sincospif(turns_of_pi, &sin_part, &cos_part);
    half2 w = {__float2half(cos_part), __float2half(-sin_part)};
    return w;
}

/**
 * Multiplies two complex numbers stored as half2 (x = real, y = imaginary):
 * (a.x + i*a.y) * (b.x + i*b.y), evaluated entirely in half precision.
 */
__device__ inline half2 const cmul(const half2 &a, const half2 &b) {
    const half real_part = __hsub(__hmul(a.x, b.x), __hmul(a.y, b.y));
    const half imag_part = __hadd(__hmul(a.x, b.y), __hmul(a.y, b.x));
    return {real_part, imag_part};
}

// template <int NUM_WARP>
/**
 * Two-pass radix-16 transform of one 256-element complex tile per warp,
 * performed in place on `in` using half-precision WMMA operations.
 *
 * The warp with index w = threadIdx.x / 32 works on in[w*256 .. w*256 + 255]:
 *   1. The tile is read as a 16x16 matrix X with X(r, c) = in[r + 16*c],
 *      rounded to half, and multiplied on the left by the 16x16 matrix F
 *      produced by mma_twiddle_16():            Y = F * X.
 *   2. Y is transposed, scaled element-wise by W_N_K(256, r*c) and multiplied
 *      by F again:                              Z = F * (T .* Y^T).
 *   3. Z is converted back to float and stored as in[16*r + c] = Z(r, c).
 *
 * Preconditions visible from the code:
 *   - Must be called by every thread of the block (it contains
 *     __syncthreads()), and by complete warps (mma_sync is warp-wide).
 *   - `in` must hold at least (threadIdx.x / 32 + 1) * 256 elements for the
 *     highest-numbered warp; indexing uses threadIdx.x only.
 *   - The element-wise loops rely on a specific per-lane layout of the
 *     m16n16k16 half fragments. The WMMA API does not specify that layout;
 *     judging by the function name this targets A100 -- TODO confirm before
 *     building for another architecture.
 *   - NOTE(review): the read pattern (r + 16*c) suggests the caller supplies
 *     the input already reordered for this decomposition -- verify at the
 *     call site.
 *
 * @param in  tile storage, read and overwritten in place
 */
__device__ void layer_256_0_A100(float2 *in)
{
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_F_real;
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_F_imag;
    mma_twiddle_16(frag_F_real, frag_F_imag);

    int raw_row = threadIdx.x % 4 * 2;
    int raw_col = (threadIdx.x % 32) / 4;
    int warp_idx = threadIdx.x / 32;

    /* opt test
    for (int i = 0; i < 256 * CONT_SIZE; i += NUM_WARP * 32)
    {
        int eid = i + t_block;
        smem_in[eid] = in[block_start + eid];
    }
    __syncthreads();
    for (int i = 0; i < 256 * CONT_SIZE; i += NUM_WARP * 32)
    {
        int eid = i + t_block;
        smem_in[eid / 32 * 32 + eid % 32 / 2 + eid % 32 % 2 * 16] = smem_in[eid];
    }
    __syncthreads();
    */

    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> frag_out_real;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> frag_out_imag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> frag_in_real;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> frag_in_imag;

    int warp_start = warp_idx * 256;

    __syncthreads();

    // Pass 1 operand: element j of the matrix_b fragment is treated as
    // X(row, col) and fetched from the column-major tile.
    #pragma unroll
    for (int j = 0; j < 8; ++j) {
        int row = raw_row + j % 4 / 2 * 8 + j % 2;
        int col = raw_col + j / 4 * 8;
        half2 ele = __float22half2_rn(in[warp_start + row + col * 16]);
        // // half2 ele = smem_in[warp_start + row + col * 16]; // opt test
        frag_in_real.x[j] = ele.x;
        frag_in_imag.x[j] = ele.y;
    }

    // Y = F * X
    complex_mul(frag_F_real, frag_F_imag, frag_in_real, frag_in_imag, frag_out_real, frag_out_imag);

    __syncthreads();

    // Pass 2 operand: B(row, col) = Y(col, row) * W_256^(row * col).
    //
    // The transpose is done in registers. With the layouts used by the load
    // loop above and the store loop below, accumulator element s of a lane
    // holds Y(raw_col + 8*bit1(s), raw_row + bit0(s) + 8*bit2(s)), while
    // matrix_b element j of the same lane is
    // B(raw_row + bit0(j) + 8*bit1(j), raw_col + 8*bit2(j)).
    // Hence B element j must come from accumulator element s = j with bits 1
    // and 2 exchanged. This is the same data movement as storing the
    // accumulator with mem_row_major and reloading it as a col_major
    // matrix_b, without the memory round trip.
    //
    // Previously this loop re-read frag_in_* (still the raw input), so the
    // first product was computed and then discarded.
    #pragma unroll
    for (int j = 0; j < 8; ++j)
    {
        int row = raw_row + j % 4 / 2 * 8 + j % 2;
        int col = raw_col + j / 4 * 8;
        const int src = (j & 1) | ((j & 2) << 1) | ((j & 4) >> 1);
        half2 in_ele = {frag_out_real.x[src], frag_out_imag.x[src]};
        in_ele = cmul(in_ele, W_N_K(256, row * col));
        frag_in_real.x[j] = in_ele.x;
        frag_in_imag.x[j] = in_ele.y;
    }

    // Z = F * B. complex_mul zeroes frag_out_* first; the values needed from
    // pass 1 have already been copied into frag_in_*.
    complex_mul(frag_F_real, frag_F_imag, frag_in_real, frag_in_imag, frag_out_real, frag_out_imag);

    __syncthreads();

    // Write Z back row-major. Accumulator element j is treated as Z(row, col);
    // both indices stay within 0..15 (7 + 8 and 6 + 8 + 1 at most), so no
    // wrap-around is required.
    #pragma unroll
    for (int j = 0; j < 8; ++j)
    {
        int col = raw_row + j / 4 * 8 + j % 2;
        int row = raw_col + j % 4 / 2 * 8;
        in[warp_start + row * 16 + col] = __half22float2({frag_out_real.x[j], frag_out_imag.x[j]});
        // smem_in[warp_start + row * 16 + col] = {frag_out_real.x[j], frag_out_imag.x[j]}; //opt test
    }

    /* opt test
    __syncthreads();
    for (int i = 0; i < 256 * CONT_SIZE; i += NUM_WARP * 32)
    {
        int eid = i + t_block;
        in[block_start + eid] = smem_in[eid];
    }
    */
}

// template <int CONT_SIZE, int NUM_WARP>
// __global__ void layer_512_0_A100(half2 *in, half *F_real, half *F_imag)
// {
//     extern __shared__ half2 smem_in[];
//     int t_block = threadIdx.x + threadIdx.y * blockDim.x;
//     int block_start = blockIdx.x * 512 * CONT_SIZE;

//     wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_F_real;
//     wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_F_imag;
//     wmma::load_matrix_sync(frag_F_real, F_real, 16);
//     wmma::load_matrix_sync(frag_F_imag, F_imag, 16);

//     int raw_row = threadIdx.x % 4 * 2;
//     int raw_col = threadIdx.x / 4;
//     half2 twiddle_two = W_N_K(512, t_block);

//     for (int i = 0; i < 512 * CONT_SIZE; i += NUM_WARP * 16 * 16)
//     {
//         wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> frag_out_real;
//         wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> frag_out_imag;
//         wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> frag_in_real;
//         wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> frag_in_imag;

//         int warp_start = i + threadIdx.y * 256;

//         for (int j = 0; j < 8; ++j)
//         {
//             int row = raw_row + j % 4 / 2 * 8 + j % 2;
//             int col = raw_col + j / 4 * 8;
//             half2 ele = in[block_start + warp_start + row + col * 16];
//             frag_in_real.x[8 + j] = frag_in_real.x[j] = ele.x;
//             frag_in_imag.x[8 + j] = frag_in_imag.x[j] = ele.y;
//         }

//         complex_mul(frag_F_real, frag_F_imag, frag_in_real, frag_in_imag, frag_out_real, frag_out_imag);

//         wmma::store_matrix_sync((half *)(smem_in + warp_start), frag_out_real, 16, wmma::mem_row_major);
//         wmma::store_matrix_sync((half *)(smem_in + warp_start) + 256, frag_out_imag, 16, wmma::mem_row_major);

//         wmma::load_matrix_sync(frag_in_real, (half *)(smem_in + warp_start), 16);
//         wmma::load_matrix_sync(frag_in_imag, (half *)(smem_in + warp_start) + 256, 16);

//         for (int j = 0; j < 8; ++j)
//         {
//             int row = raw_row + j % 4 / 2 * 8 + j % 2;
//             int col = raw_col + j / 4 * 8;
//             half2 in_ele = {frag_in_real.x[j], frag_in_imag.x[j]};
//             in_ele = cmul(in_ele, W_N_K(256, row * col));
//             frag_in_real.x[8 + j] = frag_in_real.x[j] = in_ele.x;
//             frag_in_imag.x[8 + j] = frag_in_imag.x[j] = in_ele.y;
//         }

//         complex_mul(frag_F_real, frag_F_imag, frag_in_real, frag_in_imag, frag_out_real, frag_out_imag);

//         for (int j = 0; j < 8; ++j)
//         {
//             int col = raw_row + j / 4 * 8 + j % 2;
//             int row = raw_col + j % 4 / 2 * 8;
//             smem_in[warp_start + row * 16 + col] = {frag_out_real.x[j], frag_out_imag.x[j]};
//         }
//     }

//     __syncthreads();
//     for (int i = 0; i < 512 * CONT_SIZE; i += NUM_WARP * 32 * 2)
//     {
//         int eid = i + t_block;
//         half2 ele_0 = smem_in[eid];
//         half2 ele_1 = cmul(smem_in[eid + 256], twiddle_two);
//         in[block_start + eid] = __hadd2(ele_0, ele_1);
//         in[block_start + eid + 256] = __hsub2(ele_0, ele_1);
//     }
// }

// template <int CONT_SIZE, int NUM_WARP>
// __global__ void layer_1024_0_A100(half2 *in, half *F_real, half *F_imag)
// {
//     extern __shared__ half2 smem_in[];
//     int t_block = threadIdx.x + threadIdx.y * blockDim.x;
//     int block_start = blockIdx.x * 1024 * CONT_SIZE;

//     wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_F_real;
//     wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> frag_F_imag;
//     wmma::load_matrix_sync(frag_F_real, F_real, 16);
//     wmma::load_matrix_sync(frag_F_imag, F_imag, 16);

//     int raw_row = threadIdx.x % 4 * 2;
//     int raw_col = threadIdx.x / 4;

//     for (int i = 0; i < 1024 * CONT_SIZE; i += NUM_WARP * 16 * 16)
//     {
//         wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> frag_out_real;
//         wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> frag_out_imag;
//         wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> frag_in_real;
//         wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> frag_in_imag;

//         int warp_start = i + threadIdx.y * 256;

//         for (int j = 0; j < 8; ++j)
//         {
//             int row = raw_row + j % 4 / 2 * 8 + j % 2;
//             int col = raw_col + j / 4 * 8;
//             half2 ele = in[block_start + warp_start + row + col * 16];
//             frag_in_real.x[8 + j] = frag_in_real.x[j] = ele.x;
//             frag_in_imag.x[8 + j] = frag_in_imag.x[j] = ele.y;
//         }

//         complex_mul(frag_F_real, frag_F_imag, frag_in_real, frag_in_imag, frag_out_real, frag_out_imag);

//         wmma::store_matrix_sync((half *)(smem_in + warp_start), frag_out_real, 16, wmma::mem_row_major);
//         wmma::store_matrix_sync((half *)(smem_in + warp_start) + 256, frag_out_imag, 16, wmma::mem_row_major);

//         wmma::load_matrix_sync(frag_in_real, (half *)(smem_in + warp_start), 16);
//         wmma::load_matrix_sync(frag_in_imag, (half *)(smem_in + warp_start) + 256, 16);

//         for (int j = 0; j < 8; ++j)
//         {
//             int row = raw_row + j % 4 / 2 * 8 + j % 2;
//             int col = raw_col + j / 4 * 8;
//             half2 in_ele = {frag_in_real.x[j], frag_in_imag.x[j]};
//             in_ele = cmul(in_ele, W_N_K(256, row * col));
//             frag_in_real.x[8 + j] = frag_in_real.x[j] = in_ele.x;
//             frag_in_imag.x[8 + j] = frag_in_imag.x[j] = in_ele.y;
//         }

//         complex_mul(frag_F_real, frag_F_imag, frag_in_real, frag_in_imag, frag_out_real, frag_out_imag);

//         for (int j = 0; j < 8; ++j)
//         {
//             int col = raw_row + j / 4 * 8 + j % 2;
//             int row = raw_col + j % 4 / 2 * 8;
//             smem_in[warp_start + row * 16 + col] = {frag_out_real.x[j], frag_out_imag.x[j]};
//         }
//     }

//     half2 twiddle_1024_1 = W_N_K(1024, t_block);
//     half2 twiddle_1024_2 = cmul(twiddle_1024_1, twiddle_1024_1);
//     half2 twiddle_1024_3 = cmul(twiddle_1024_2, twiddle_1024_1);

//     __syncthreads();
//     for (int i = 0; i < 1024 * CONT_SIZE; i += NUM_WARP * 32 * 4)
//     {
//         int eid = i + t_block;
//         half2 ele0 = smem_in[eid];
//         half2 ele1 = cmul(smem_in[eid + 256], twiddle_1024_1);
//         half2 ele2 = cmul(smem_in[eid + 512], twiddle_1024_2);
//         half2 ele3 = cmul(smem_in[eid + 768], twiddle_1024_3);
//         in[block_start + eid] = ele0 + ele1 + ele2 + ele3;
//         in[block_start + eid + 256] = ele0 + half2({ele1.y, -ele1.x}) - ele2 + half2({-ele3.y, ele3.x});
//         in[block_start + eid + 512] = ele0 - ele1 + ele2 - ele3;
//         in[block_start + eid + 768] = ele0 + half2({-ele1.y, ele1.x}) - ele2 + half2({ele3.y, -ele3.x});
//     }
// }