#ifndef GEMM_MASK_UTIL_H
#define GEMM_MASK_UTIL_H


/**
 * Initialises WMMA float accumulator fragments either from a 2-D float mask
 * (Mask == true) or with zeros (Mask == false).
 *
 * With Mask == true the helper first copies a TileM x TileN tile of the mask
 * from `mask` into the staging buffer `smem` using LoadType-wide accesses,
 * then loads each warp's fragments from that buffer.
 *
 * Template parameters, as used by the code below:
 *   Mask            compile-time switch between "load mask" and "fill zero".
 *   TileM, TileN    rows / columns of the staged tile; TileN is also the
 *                   leading dimension passed to wmma::load_matrix_sync.
 *   NumWarp         NumWarp * 32 threads are assumed to take part in the copy.
 *                   NOTE(review): presumably blockDim.x == NumWarp * 32 —
 *                   confirm against the kernel launch.
 *   LoadType        element type used for the global -> staging copy; its
 *                   size must be a multiple of sizeof(float).
 *   M, N, K         WMMA fragment shape.
 *   wTileM, wTileN  rows / columns handled by one warp.
 *   MWarps, NWarps  warp grid; only NWarps is read here.
 *
 * Preconditions implied by the arithmetic (not checked):
 *   - TileN is a multiple of AlignN, and NumWarp * 32 a multiple of SubWarpSize.
 *   - TileM is a multiple of NumSubWarp, otherwise trailing rows are not copied.
 *   - n * NumSubWarp is a multiple of AlignN (true whenever every mask row is
 *     LoadType-aligned).
 */
template <bool Mask, int TileM, int TileN, int NumWarp, typename LoadType, int M, int N, int K, int wTileM, int wTileN, int MWarps, int NWarps>
struct MaskUtil
{
    // Number of floats carried by one LoadType element.
    static constexpr int AlignN = sizeof(LoadType) / sizeof(float);
    // Threads needed to copy one TileN-wide row with LoadType accesses.
    static constexpr int SubWarpSize = TileN / AlignN;
    // Rows copied per step by NumWarp * 32 threads.
    static constexpr int NumSubWarp = NumWarp * 32 / SubWarpSize;
    // Staging-buffer advance per copy step, in LoadType elements
    // (NumSubWarp rows of TileN floats).
    static constexpr int SHM_STRIDE = TileN * NumSubWarp / AlignN;
    // Retained for backward compatibility only. The integer division truncates
    // (to 0 when NumSubWarp < AlignN), so load_mask() no longer relies on it.
    static constexpr int GLOBAL_STRIDE = NumSubWarp / AlignN;

    // Fragments per warp along each dimension.
    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;

    const LoadType* mask_ptr;      // this thread's read cursor into the mask
    LoadType* shm_mask_ptr;        // this thread's write cursor into the staging buffer
    const float* w_tile_ptr;       // top-left element of this warp's tile in the staging buffer
    int n;                         // row pitch of the mask, in floats

    /**
     * @param mask      base pointer of the mask matrix
     * @param smem      staging buffer holding at least TileM * TileN floats
     * @param m_offset  first mask row of this tile
     * @param n_offset  first mask column of this tile
     * @param n_        row pitch of the mask, in floats
     *
     * When Mask == false no member is initialised.
     */
    __device__ __forceinline__ MaskUtil(
        const float* mask, float* smem, int m_offset, int n_offset, int n_)
    {
        if (Mask){
            // Each group of SubWarpSize consecutive threads copies one row.
            int subwarpId = threadIdx.x / SubWarpSize;
            int sublaneId = threadIdx.x % SubWarpSize;

            n = n_;

            // Source: row (m_offset + subwarpId), column n_offset, then this
            // thread's LoadType slot. The row offset is computed in 64 bits so
            // that large masks do not overflow int.
            const long long row_offset =
                (static_cast<long long>(m_offset) + subwarpId) * n + n_offset;
            mask_ptr = reinterpret_cast<const LoadType*>(mask + row_offset) + sublaneId;

            // Destination: row subwarpId of the staging tile, same slot.
            shm_mask_ptr = reinterpret_cast<LoadType*>(smem + subwarpId * TileN) + sublaneId;

            // Locate this warp's tile inside the staging buffer.
            int warpId = threadIdx.x / 32;
            int MwarpId = warpId / NWarps;
            int NwarpId = warpId % NWarps;

            int mw = MwarpId * wTileM;
            int nw = NwarpId * wTileN;

            w_tile_ptr = smem + mw * TileN + nw;
        }
    }

    /**
     * Fills the accumulator fragments `c[MWarpTiles][NWarpTiles]`.
     *
     * Mask == true : copies the mask tile into the staging buffer, then loads
     *                every fragment from it (row-major, leading dimension TileN).
     *                Advances mask_ptr / shm_mask_ptr as a side effect. Contains
     *                __syncthreads(), so every thread of the block must call it.
     * Mask == false: sets every fragment to 0.0f.
     */
    __device__ __forceinline__ void load_mask(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        if (Mask){
            // Advance NumSubWarp mask rows per step, expressed in LoadType
            // elements. Multiply before dividing so the step is not truncated
            // when NumSubWarp is not a multiple of AlignN.
            const long long global_step = static_cast<long long>(n) * NumSubWarp / AlignN;

            #pragma unroll
            for (int step_m = 0; step_m < TileM / NumSubWarp; step_m ++){
                *shm_mask_ptr = *mask_ptr;
                shm_mask_ptr += SHM_STRIDE;
                mask_ptr += global_step;
            }
            // All rows must be staged before any warp reads its tile.
            __syncthreads();

            #pragma unroll
            for (int i = 0; i < MWarpTiles; i++){
                const float* w_tile_ptr_ = w_tile_ptr + i * M * TileN;
                #pragma unroll
                for (int j=0; j < NWarpTiles; j++){
                    wmma::load_matrix_sync(c[i][j], w_tile_ptr_, TileN, wmma::mem_row_major);
                    w_tile_ptr_ += N;
                }
            }
            // Barrier after the reads, so later writes to the staging buffer
            // cannot race with them.
            __syncthreads();

        } else {
            #pragma unroll
            for (int i = 0; i < MWarpTiles; i++){
                #pragma unroll
                for (int j = 0; j < NWarpTiles; j++){
                    wmma::fill_fragment(c[i][j], 0.0f);
                }
            }
        }
    }
};


/**
 * Initialises WMMA float accumulator fragments from a single row of float
 * mask values (Mask == true) or with zeros (Mask == false).
 *
 * One row of TileN floats starting at `mask + n_offset` is staged in `smem`.
 * Each thread then reads its values directly into the fragment registers and
 * the same values are reused for every fragment row-tile.
 *
 * NOTE(review): load_mask() writes fragment elements x[0..7] by hand, i.e. it
 * assumes 8 elements per thread and a fixed element-to-column mapping. This
 * looks like the 16x16 float accumulator layout, but WMMA documents fragment
 * layout as unspecified — confirm for the target architecture.
 *
 * Template parameters, as used by the code below:
 *   Mask            compile-time switch between "load mask" and "fill zero".
 *   LoadType        element type used for the global -> staging copy.
 *   TileN           number of floats staged.
 *   NWarps          warps along N; selects this warp's column offset.
 *   wTileM, wTileN  rows / columns handled by one warp.
 *   M, N, K         WMMA fragment shape.
 *   MWarps          unused here.
 */
template <bool Mask, typename LoadType, int TileN, int MWarps, int NWarps, int wTileM, int wTileN, int M, int N, int K>
struct SeqMaskUtil
{
    // Number of floats carried by one LoadType element.
    static constexpr int AlignN = sizeof(LoadType) / sizeof(float);
    // Number of LoadType elements that make up the TileN staged floats.
    static constexpr int SubWarpSize = TileN / AlignN;

    // Fragments per warp along each dimension.
    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;

    const LoadType* mask_ptr;      // not used by this struct; kept for layout compatibility
    LoadType* shm_mask_ptr;        // not used by this struct; kept for layout compatibility
    const float2* w_tile_ptr;      // this thread's first float2 in the staging buffer

    /**
     * @param mask      base pointer of the mask row
     * @param smem      staging buffer holding at least TileN floats
     * @param n_offset  first mask column of this tile
     *
     * Stages the mask values but does not synchronise; load_mask() issues the
     * barrier before reading them. When Mask == false nothing is initialised.
     */
    __device__ __forceinline__ SeqMaskUtil(
        const float* mask, float* smem, int n_offset)
    {
        if (Mask){
            const LoadType* src = reinterpret_cast<const LoadType*>(mask + n_offset);
            LoadType* dst = reinterpret_cast<LoadType*>(smem);

            // Block-stride copy: each element is written by exactly one
            // threadIdx.x value, rather than every thread copying the whole
            // tail of the row.
            for (int i = threadIdx.x; i < SubWarpSize; i += blockDim.x){
                dst[i] = src[i];
            }

            int warpId = threadIdx.x / 32;
            int laneId = threadIdx.x % 4;
            int NwarpId = warpId % NWarps;

            // Column offset of this warp's tile, then this thread's float2 slot.
            int nw = NwarpId * wTileN;
            w_tile_ptr = reinterpret_cast<const float2*>(smem + nw) + laneId;
        }
    }

    /**
     * Fills the accumulator fragments `c[MWarpTiles][NWarpTiles]`.
     *
     * Mask == true : reads this thread's mask values from the staging buffer
     *                into c[0][j] and copies them to every other row-tile.
     *                Contains __syncthreads(), so every thread of the block
     *                must call it.
     * Mask == false: sets every fragment to 0.0f.
     */
    __device__ __forceinline__ void load_mask(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        if (Mask){
            // Make the values staged by the constructor visible to all threads.
            __syncthreads();

            const float2* w_tile_ptr_ = w_tile_ptr;

            #pragma unroll
            for (int j=0; j < NWarpTiles; j++){
                // x[0..1] come from this thread's float2; x[4..5] from the
                // float2 four slots (8 floats) further on.
                *(reinterpret_cast<float2*>(c[0][j].x)) = *(w_tile_ptr_);
                *(reinterpret_cast<float2*>(c[0][j].x) + 2) = *(w_tile_ptr_ + 4);
                // Next fragment: N floats further along the row.
                w_tile_ptr_ += N/2;
            }

            // Duplicate the loaded pairs into the remaining four elements.
            #pragma unroll
            for (int j=0; j < NWarpTiles; j++){
                c[0][j].x[2] = c[0][j].x[0];
                c[0][j].x[3] = c[0][j].x[1];
                c[0][j].x[6] = c[0][j].x[4];
                c[0][j].x[7] = c[0][j].x[5];
            }

            // Every further row-tile gets the same values as row-tile 0.
            #pragma unroll
            for (int i=1; i < MWarpTiles; i++){
                #pragma unroll
                for (int j=0; j < NWarpTiles; j++){
                    #pragma unroll
                    for (int r=0; r < 8; r++){
                        c[i][j].x[r] = c[0][j].x[r];
                    }
                }
            }
            // Barrier after the reads, so later writes to the staging buffer
            // cannot race with them.
            __syncthreads();
        } else {
            #pragma unroll
            for (int i = 0; i < MWarpTiles; i++){
                #pragma unroll
                for (int j = 0; j < NWarpTiles; j++){
                    wmma::fill_fragment(c[i][j], 0.0f);
                }
            }
        }
    }
};


/**
 * bf16 counterpart of SeqMaskUtil: initialises WMMA float accumulator
 * fragments from a single row of nv_bfloat16 mask values (Mask == true) or
 * with zeros (Mask == false).
 *
 * TileN bf16 values starting at `mask + n_offset` are staged, still in bf16,
 * at the start of `smem`. Each thread then reads four of them per fragment,
 * converts them to float and reuses the result for every fragment row-tile.
 *
 * NOTE(review): load_mask() writes fragment elements x[0..7] by hand, i.e. it
 * assumes 8 elements per thread and a fixed element-to-column mapping; WMMA
 * documents fragment layout as unspecified — confirm for the target
 * architecture.
 * NOTE(review): each thread places four *consecutive* bf16 values into
 * x[0], x[1], x[4], x[5]. Whether that is the intended column mapping depends
 * on how the caller lays out the mask — confirm.
 *
 * Template parameters, as used by the code below:
 *   Mask            compile-time switch between "load mask" and "fill zero".
 *   LoadType        element type used for the global -> staging copy.
 *   TileN           number of bf16 values staged.
 *   NWarps          warps along N; selects this warp's column offset.
 *   wTileM, wTileN  rows / columns handled by one warp.
 *   M, N, K         WMMA fragment shape.
 *   MWarps          unused here.
 */
template <bool Mask, typename LoadType, int TileN, int MWarps, int NWarps, int wTileM, int wTileN, int M, int N, int K>
struct SeqMaskUtil_bf16
{
    // Number of bf16 values carried by one LoadType element.
    static constexpr int AlignN = sizeof(LoadType) / sizeof(nv_bfloat16);
    // Number of LoadType elements that make up the TileN staged bf16 values.
    static constexpr int SubWarpSize = TileN / AlignN;

    // Fragments per warp along each dimension.
    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;

    const LoadType* mask_ptr;      // not used by this struct; kept for layout compatibility
    LoadType* shm_mask_ptr;        // not used by this struct; kept for layout compatibility
    const float2* w_tile_ptr;      // this thread's first float2 (= 4 bf16) in the staging buffer

    /**
     * @param mask      base pointer of the bf16 mask row
     * @param smem      staging buffer; receives TileN bf16 values
     * @param n_offset  first mask column of this tile
     *
     * Stages the mask values but does not synchronise; load_mask() issues the
     * barrier before reading them. When Mask == false nothing is initialised.
     */
    __device__ __forceinline__ SeqMaskUtil_bf16(
        const nv_bfloat16* mask, float* smem, int n_offset)
    {
        if (Mask){
            const LoadType* src = reinterpret_cast<const LoadType*>(mask + n_offset);
            LoadType* dst = reinterpret_cast<LoadType*>(smem);

            // Block-stride copy: each element is written by exactly one
            // threadIdx.x value, rather than every thread copying the whole
            // tail of the row.
            for (int i = threadIdx.x; i < SubWarpSize; i += blockDim.x){
                dst[i] = src[i];
            }

            int warpId = threadIdx.x / 32;
            int laneId = threadIdx.x % 4;
            int NwarpId = warpId % NWarps;

            // smem is addressed as float, so an offset of wTileN bf16 values is
            // wTileN / 2 floats.
            int nw = NwarpId * wTileN / 2;
            w_tile_ptr = reinterpret_cast<const float2*>(smem + nw) + laneId;
        }
    }

    /**
     * Fills the accumulator fragments `c[MWarpTiles][NWarpTiles]`.
     *
     * Mask == true : reads this thread's bf16 mask values from the staging
     *                buffer, converts them to float into c[0][j] and copies
     *                them to every other row-tile. Contains __syncthreads(),
     *                so every thread of the block must call it.
     * Mask == false: sets every fragment to 0.0f.
     */
    __device__ __forceinline__ void load_mask(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        if (Mask){
            // Make the values staged by the constructor visible to all threads.
            __syncthreads();

            const float2* w_tile_ptr_ = w_tile_ptr;

            #pragma unroll
            for (int j=0; j < NWarpTiles; j++){
                // Park the raw 8 bytes (four bf16 values) in x[2..3]; they are
                // decoded in the next loop.
                *(reinterpret_cast<float2*>(c[0][j].x)+1) = *(w_tile_ptr_);
                // Next fragment: N bf16 values (= N/4 float2) further on.
                w_tile_ptr_ += N/4;
            }

            #pragma unroll
            for (int j=0; j < NWarpTiles; j++){
                // View of the four bf16 values parked in x[2..3].
                nv_bfloat16* c_bf16 = reinterpret_cast<nv_bfloat16*>(c[0][j].x + 2);

                // Decode into x[0], x[1], x[4], x[5] first: x[2..3] still hold
                // the raw bf16 data and must not be overwritten until all four
                // values have been read.
                c[0][j].x[0] = __bfloat162float(c_bf16[0]);
                c[0][j].x[1] = __bfloat162float(c_bf16[1]);
                c[0][j].x[4] = __bfloat162float(c_bf16[2]);
                c[0][j].x[5] = __bfloat162float(c_bf16[3]);
                // Duplicate the decoded pairs into the remaining four elements.
                c[0][j].x[2] = c[0][j].x[0];
                c[0][j].x[3] = c[0][j].x[1];
                c[0][j].x[6] = c[0][j].x[4];
                c[0][j].x[7] = c[0][j].x[5];
            }

            // Every further row-tile gets the same values as row-tile 0.
            #pragma unroll
            for (int i=1; i < MWarpTiles; i++){
                #pragma unroll
                for (int j=0; j < NWarpTiles; j++){
                    #pragma unroll
                    for (int r=0; r < 8; r++){
                        c[i][j].x[r] = c[0][j].x[r];
                    }
                }
            }
            // Barrier after the reads, so later writes to the staging buffer
            // cannot race with them.
            __syncthreads();
        } else {
            #pragma unroll
            for (int i = 0; i < MWarpTiles; i++){
                #pragma unroll
                for (int j = 0; j < NWarpTiles; j++){
                    wmma::fill_fragment(c[i][j], 0.0f);
                }
            }
        }
    }
};

/* Disabled: bf16 variant of MaskUtil. Commented out, not compiled.
template <bool Mask, int TileM, int TileN, int NumWarp, typename LoadType, int M, int N, int K, int wTileM, int wTileN, int MWarps, int NWarps>
struct MaskUtil_bf16
{
    static constexpr int AlignN = sizeof(LoadType) / sizeof(nv_bfloat16);
    static constexpr int SubWarpSize = TileN / AlignN;
    static constexpr int NumSubWarp = NumWarp * 32 / SubWarpSize;
    static constexpr int SHM_STRIDE = TileN * NumSubWarp / AlignN;
    static constexpr int GLOBAL_STRIDE = NumSubWarp / AlignN;

    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;

    const LoadType* mask_ptr;
    LoadType* shm_mask_ptr;
    const nv_bfloat16* w_tile_ptr;
    int n;

    __device__ __forceinline__ MaskUtil_bf16(
        const nv_bfloat16* mask, float* smem, int m_offset, int n_offset, int n_)
    {
        if (Mask){
            int subwarpId = threadIdx.x / SubWarpSize;
            int sublaneId = threadIdx.x % SubWarpSize;

            // The pointer to the mask matrix to copy memory from to shared memory
            n = n_;
            mask_ptr = reinterpret_cast<const LoadType*>(mask + (m_offset + subwarpId) * n + n_offset) + sublaneId;
            shm_mask_ptr = reinterpret_cast<LoadType*>(smem + SubWarpSize * TileN + sublaneId * AlignN);

            // load the mask from the shared memory to fragment
            int warpId = threadIdx.x / 32;
            int MwarpId = warpId / NWarps;
            int NwarpId = warpId % NWarps;

            int m = MwarpId * wTileM;
            int n = NwarpId * wTileN;

            w_tile_ptr = reinterpret_cast<const nv_bfloat16*>(smem) + m * TileN + n;
        }
    }

    __device__ __forceinline__ void load_mask(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        if (Mask){
            #pragma unroll
            for (int step_m = 0; step_m < TileM / NumSubWarp; step_m ++){
                *shm_mask_ptr = *mask_ptr;
                shm_mask_ptr += SHM_STRIDE;
                mask_ptr += n * GLOBAL_STRIDE;
            }
            __syncthreads();

            wmma::fragment<wmma::accumulator, M, N, K, half> c_bf16;

            #pragma unroll
            for (int i = 0; i < MWarpTiles; i++){
                const nv_bfloat16* w_tile_ptr_ = w_tile_ptr + i * M * TileN;
                #pragma unroll
                for (int j=0; j < NWarpTiles; j++){
                    wmma::load_matrix_sync(c_bf16, w_tile_ptr_, TileN, wmma::mem_row_major);
                    w_tile_ptr_ += N;
                    #pragma unroll
                    for (int t=c[i][j].num_elements - 1; t >= 0; t--){
                        c[i][j].x[t] = __bfloat162float(c_bf16.x[t]);
                    }
                }
            }
        } else {
            #pragma unroll
            for (int i = 0; i < MWarpTiles; i++){
                #pragma unroll
                for (int j = 0; j < NWarpTiles; j++){
                    wmma::fill_fragment(c[i][j], 0.0f);
                }
            }
        }
    }
};
*/

#endif