#ifndef META_UTIL_H
#define META_UTIL_H

#include <cstddef>
#include <cstdint>

#include <cuda_bf16.h>
#include <mma.h>
using namespace nvcuda;

// MetaUtil -- per-thread helper that prunes a warp's fp32 WMMA accumulator tile to
// one-of-two sparsity and writes out the compressed values plus selector metadata.
//
//   get_meta_data()  For each adjacent pair (x[2j], x[2j+1]) of an accumulator
//                    fragment keeps the larger value (signed comparison, not
//                    magnitude), stages it in shared memory and records a 4-bit
//                    selector (0x4 when the first element is kept, 0xE when the
//                    second is). Selectors are merged between neighbouring lanes
//                    with warp shuffles and stored to `metadata` as one 32-bit word
//                    per lane per 32x16 step.
//   write_nonzeros() Copies the staged values from shared memory to `nonzeros`
//                    (DRAM) using 16-byte float4 transfers.
//
// Preconditions (only the compile-time ones are checked):
//   * M == N == 16: the loops hard-code two 16-row fragments and one 16-column
//     fragment per 32x16 step.
//   * All 32 lanes of each warp call get_meta_data() together (full-mask shuffles).
//   * NOTE(review): blockDim.x is presumably MWarps * NWarps * 32 -- confirm at the
//     launch site.
//   * Assuming `nonzeros` itself is 16-byte aligned, n and n_offset must be
//     multiples of 8 so every float4 row access stays aligned.
//   * NOTE(review): pairing x[2j] with x[2j+1] relies on the per-lane layout of the
//     WMMA accumulator fragment, which CUDA does not document; verify per arch.
template<int MWarps, int NWarps, int wTileM, int wTileN, int M, int N, int K, int TileM, int TileN>
struct MetaUtil{

    static_assert(M == 16 && N == 16,
                  "MetaUtil assumes 16x16 accumulator fragments (two per 32-row step, one per 16-column step)");
    static_assert(wTileM % 32 == 0, "wTileM must be a multiple of the 32-row step");
    static_assert(wTileN % 16 == 0, "wTileN must be a multiple of the 16-column step");
    static_assert(TileN % 8 == 0, "the TileN / 2 staged floats per row must fill whole float4s");

    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;

    // Rows are processed in steps of 32 (two 16-row fragments per step).
    static constexpr int MSteps = wTileM / 32;
    static constexpr int MPerStep = 32 / M; // == 2 for M == 16

    // Columns are processed in steps of 16 (one 16-column fragment per step).
    static constexpr int NSteps = wTileN / 16;
    static constexpr int NPerStep = 16 / N; // == 1 for N == 16

    // Shared-memory row stride in floats: TileN / 2 kept values plus 8 extra floats
    // (NOTE(review): presumably padding against bank conflicts).
    static constexpr int SHM_STRIDE = TileN / 2 + 8;

    // Static members for writing the nonzeros to DRAM: a sub-warp of SubWarpSize
    // threads moves one staged row (TileN / 2 floats = SubWarpSize float4s).
    static constexpr int SubWarpSize = TileN / 8;
    static constexpr int NumSubWarp = MWarps * NWarps * (32 / SubWarpSize);

    //
    //  Member variables
    //

    int* metadata;          // this warp's base in the metadata output (32-bit words)
    float* shared_res_ptr;  // this lane's slot for staging kept values in shared memory
    int m;                  // matrix rows (m_); distance between 16-column metadata blocks
    int n;                  // matrix columns (n_); each nonzeros row holds n / 2 floats

    float4* global_res;     // this thread's current float4 destination in nonzeros
    float4* shared_res;     // this thread's current float4 source in the staging buffer

    // Packs two 16-bit metadata halves into one 32-bit word (lo -> bits 15:0,
    // hi -> bits 31:16). This is the bit pattern the little-endian GPU produces when
    // viewing an int16_t pair as an int, but without type-punning a 2-byte-aligned
    // local array through an int pointer (strict-aliasing / misalignment UB).
    static __device__ __forceinline__ int pack_meta_halves(int16_t lo, int16_t hi){
        const unsigned int bits =
            static_cast<unsigned int>(static_cast<unsigned short>(lo)) |
            (static_cast<unsigned int>(static_cast<unsigned short>(hi)) << 16);
        return static_cast<int>(bits);
    }

    // Computes this thread's fixed addresses for the tile whose top-left element is
    // (m_offset, n_offset) in an m_ x n_ matrix.
    //   metadata_ : metadata output, addressed as 32-bit words; consecutive
    //               16-column blocks are m_ words apart.
    //   smem      : shared-memory staging buffer with SHM_STRIDE floats per row
    //               (NOTE(review): presumably at least TileM rows -- confirm size).
    //   nonzeros  : compressed output, n_ / 2 floats per row.
    __device__ __forceinline__ MetaUtil(
        int16_t * metadata_, float* smem, float* nonzeros, int m_offset, int n_offset, int m_, int n_)
    {
        const int warpId = threadIdx.x / 32;
        const int laneId = threadIdx.x % 32;
        const int MwarpId = warpId / NWarps;
        const int NwarpId = warpId % NWarps;
        m = m_;
        n = n_;

        metadata = reinterpret_cast<int*>(metadata_) + ((n_offset / 16) + NwarpId * (wTileN / 16)) * m_ + m_offset + wTileM * MwarpId;

        // For writing the result from fragment to shared memory: row laneId / 4 of
        // this warp's rows, column laneId % 4 of this warp's (wTileN / 2)-float slice.
        shared_res_ptr = smem + (MwarpId * wTileM + laneId / 4) * SHM_STRIDE + NwarpId * wTileN / 2 + laneId % 4;

        // For copying the result from shared memory to global memory: each sub-warp
        // owns one row, each of its threads one float4 of that row.
        const int sublaneId = threadIdx.x % SubWarpSize;
        const int subwarpId = threadIdx.x / SubWarpSize;

        // Row offset in size_t so row * (n / 2) cannot overflow int on large matrices.
        const size_t row = static_cast<size_t>(m_offset + subwarpId);
        global_res = reinterpret_cast<float4 *>(nonzeros + row * (n / 2) + n_offset / 2) + sublaneId;
        shared_res = reinterpret_cast<float4 *>(smem + subwarpId * SHM_STRIDE) + sublaneId;
    }

    // Prunes this warp's accumulator fragments c[MWarpTiles][NWarpTiles]: stages the
    // kept values in shared memory and writes one metadata word per lane per 32x16
    // step. Must be executed by all 32 lanes of the warp (full-mask shuffles).
    __device__ __forceinline__ void get_meta_data(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        const int th_group_id = threadIdx.x % 4;
        const int laneId = threadIdx.x % 32;
        #pragma unroll
        for (int m_step = 0; m_step < MSteps; m_step ++){
            #pragma unroll
            for (int n_step = 0; n_step < NSteps; n_step ++){
                // Each step processes a 32 x 16 tile: fragments (2 * m_step + i, n_step).
                // meta[4 * i + j] holds a selector nibble shifted to bit 4 * th_group_id.
                int16_t meta[8] = {0};
                #pragma unroll
                for (int i = 0; i < 2; i++){
                    const int m_i = m_step * 2 + i;
                    const int n_j = n_step;

                    // Step 1: prune the result in the register file.
                    #pragma unroll
                    for (int j = 0; j < 4; j++){
                        const float first = c[m_i][n_j].x[2 * j];
                        const float second = c[m_i][n_j].x[2 * j + 1];
                        const bool keep_first = first > second;
                        // Selector: 0x4 keeps the first element, 0xE the second.
                        meta[4*i + j] = (keep_first ? 4 : 14) << (th_group_id * 4);
                        // Pair j goes to row half (j % 2) and column half (j / 2) of the fragment.
                        *(shared_res_ptr + (m_i * M + (j % 2) * (M/2)) * SHM_STRIDE + n_j * (N / 2) + (j / 2) * (N / 4)) = keep_first ? first : second;
                    }
                    // Collect the meta data at the target thread. i and j are the same on
                    // every lane, so all 32 lanes reach each full-mask shuffle.
                    #pragma unroll
                    for (int j = 0; j < 4; j++){
                        if (i == 0){
                            meta[j] |= __shfl_down_sync(0xffffffff, meta[j], 2);
                        }else{
                            meta[j + 4] |= __shfl_up_sync(0xffffffff, meta[j + 4], 2);
                        }
                        if (j < 2){
                            meta[i * 4 + j] |= __shfl_down_sync(0xffffffff, meta[i * 4 + j], 1);
                        }else{
                            meta[i * 4 + j] |= __shfl_up_sync(0xffffffff, meta[i * 4 + j], 1);
                        }
                    }
                }
                // All the meta data are now collected in meta[8]; no reordering is needed.
                // Lane th_group_id of each quad emits 16-bit words 2 * th_group_id and
                // 2 * th_group_id + 1 packed into one 32-bit word.
                int meta_word = pack_meta_halves(meta[0], meta[1]);
                if (th_group_id == 1) meta_word = pack_meta_halves(meta[2], meta[3]);
                else if (th_group_id == 2) meta_word = pack_meta_halves(meta[4], meta[5]);
                else if (th_group_id == 3) meta_word = pack_meta_halves(meta[6], meta[7]);

                *(metadata + m_step * 32 + n_step * m + laneId) = meta_word;
            }
        }
    }

    // Copies the staged values to `nonzeros`: each thread moves TileM / NumSubWarp
    // float4s, stepping NumSubWarp rows at a time. Advances global_res/shared_res,
    // so it is meant to run once per object.
    // NOTE(review): no barrier here; staged rows are written by other warps in
    // get_meta_data(), so the caller presumably __syncthreads() in between.
    __device__ __forceinline__ void write_nonzeros(){
        #pragma unroll
        for (int i = 0; i < TileM / NumSubWarp; i++){
            *(global_res) = *(shared_res);
            global_res += NumSubWarp * n / 8;           // NumSubWarp rows of n / 2 floats
            shared_res += NumSubWarp * SHM_STRIDE / 4;  // NumSubWarp rows of SHM_STRIDE floats
        }
    }

};



// MetaUtil_bf16 -- per-thread helper that prunes a warp's fp32 WMMA accumulator
// tile to two-of-four sparsity, converts the kept values to bf16 and writes out the
// compressed values plus selector metadata.
//
//   get_meta_data()  For each group of four accumulator values
//                    (x[2k], x[2k+1], x[4+2k], x[5+2k]) keeps the two whose signed
//                    fp32 sum is largest; strict '>' makes the earlier candidate win
//                    ties, in the order (0,1),(0,2),(0,3),(1,2),(1,3),(2,3). The two
//                    kept bf16 values are packed into one 32-bit shared-memory slot
//                    (lower-index value in the low half) and a 4-bit selector
//                    idx_a | (idx_b << 2) is recorded. Selectors are merged between
//                    neighbouring lanes with warp shuffles and stored to `metadata`
//                    as one 32-bit word per lane per 32x32 step.
//   write_nonzeros() Copies the staged values from shared memory to `nonzeros`
//                    (DRAM) using 16-byte float4 transfers.
//
// Preconditions (only the compile-time ones are checked):
//   * M == N == 16: the loops hard-code two 16-row and two 16-column fragments per
//     32x32 step.
//   * All 32 lanes of each warp call get_meta_data() together (full-mask shuffles).
//   * NOTE(review): blockDim.x is presumably MWarps * NWarps * 32 -- confirm at the
//     launch site.
//   * Assuming `nonzeros` itself is 16-byte aligned, n and n_offset must be
//     multiples of 16 so every float4 row access stays aligned.
//   * NOTE(review): the grouping of fragment elements relies on the per-lane layout
//     of the WMMA accumulator fragment, which CUDA does not document; verify per arch.
template<int MWarps, int NWarps, int wTileM, int wTileN, int M, int N, int K, int TileM, int TileN>
struct MetaUtil_bf16{

    static_assert(M == 16 && N == 16,
                  "MetaUtil_bf16 assumes 16x16 accumulator fragments (two per 32-row and 32-column step)");
    static_assert(wTileM % 32 == 0, "wTileM must be a multiple of the 32-row step");
    static_assert(wTileN % 32 == 0, "wTileN must be a multiple of the 32-column step");
    static_assert(TileN % 16 == 0, "the TileN / 4 staged words per row must fill whole float4s");

    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;

    // Rows are processed in steps of 32 (two 16-row fragments per step).
    static constexpr int MSteps = wTileM / 32;
    static constexpr int MPerStep = 32 / M; // == 2 for M == 16

    // Columns are also processed in steps of 32 (two 16-column fragments per step).
    static constexpr int NSteps = wTileN / 32;
    static constexpr int NPerStep = 32 / N; // == 2 for N == 16

    // Shared-memory row stride in 32-bit words (each holding two bf16 values):
    // TileN / 4 words of kept values plus 8 extra words
    // (NOTE(review): presumably padding against bank conflicts).
    static constexpr int SHM_STRIDE = TileN / 4 + 8;

    // Static members for writing the nonzeros to DRAM: a sub-warp of SubWarpSize
    // threads moves one staged row (TileN / 4 words = SubWarpSize float4s).
    static constexpr int SubWarpSize = TileN / 16;
    static constexpr int NumSubWarp = MWarps * NWarps * (32 / SubWarpSize);

    //
    //  Member variables
    //

    int* metadata;          // this warp's base in the metadata output (32-bit words)
    float* shared_res_ptr;  // this lane's 32-bit staging slot in shared memory
    int m;                  // matrix rows (m_); distance between 32-column metadata blocks
    int n;                  // matrix columns (n_); each nonzeros row holds n / 2 bf16 values

    float4* global_res;     // this thread's current float4 destination in nonzeros
    float4* shared_res;     // this thread's current float4 source in the staging buffer

    // Packs two 16-bit metadata halves into one 32-bit word (lo -> bits 15:0,
    // hi -> bits 31:16). This is the bit pattern the little-endian GPU produces when
    // viewing an int16_t pair as an int, but without type-punning a 2-byte-aligned
    // local array through an int pointer (strict-aliasing / misalignment UB).
    static __device__ __forceinline__ int pack_meta_halves(int16_t lo, int16_t hi){
        const unsigned int bits =
            static_cast<unsigned int>(static_cast<unsigned short>(lo)) |
            (static_cast<unsigned int>(static_cast<unsigned short>(hi)) << 16);
        return static_cast<int>(bits);
    }

    // Computes this thread's fixed addresses for the tile whose top-left element is
    // (m_offset, n_offset) in an m_ x n_ matrix.
    //   metadata_ : metadata output, addressed as 32-bit words; consecutive
    //               32-column blocks are m_ words apart.
    //   smem      : shared-memory staging buffer with SHM_STRIDE 32-bit words per row
    //               (NOTE(review): presumably at least TileM rows -- confirm size).
    //   nonzeros  : compressed output, n_ / 2 bf16 values per row.
    __device__ __forceinline__ MetaUtil_bf16(
        int16_t * metadata_, float* smem, nv_bfloat16* nonzeros, int m_offset, int n_offset, int m_, int n_)
    {
        const int warpId = threadIdx.x / 32;
        const int laneId = threadIdx.x % 32;
        const int MwarpId = warpId / NWarps;
        const int NwarpId = warpId % NWarps;
        m = m_;
        n = n_;

        metadata = reinterpret_cast<int*>(metadata_) + ((n_offset / 32) + NwarpId * (wTileN / 32)) * m_ + m_offset + wTileM * MwarpId;

        // For writing the result from fragment to shared memory: row laneId / 4 of
        // this warp's rows, word laneId % 4 of this warp's (wTileN / 4)-word slice.
        shared_res_ptr = smem + (MwarpId * wTileM + laneId / 4) * SHM_STRIDE + NwarpId * wTileN / 4 + laneId % 4;

        // For copying the result from shared memory to global memory: each sub-warp
        // owns one row, each of its threads one float4 of that row.
        const int sublaneId = threadIdx.x % SubWarpSize;
        const int subwarpId = threadIdx.x / SubWarpSize;

        // Row offset in size_t so row * (n / 2) cannot overflow int on large matrices.
        const size_t row = static_cast<size_t>(m_offset + subwarpId);
        global_res = reinterpret_cast<float4 *>(nonzeros + row * (n / 2) + n_offset / 2) + sublaneId;
        shared_res = reinterpret_cast<float4 *>(smem + subwarpId * SHM_STRIDE) + sublaneId;
    }

    // Prunes this warp's accumulator fragments c[MWarpTiles][NWarpTiles]: stages the
    // kept bf16 pairs in shared memory and writes one metadata word per lane per
    // 32x32 step. Must be executed by all 32 lanes of the warp (full-mask shuffles).
    __device__ __forceinline__ void get_meta_data(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        const int th_group_id = threadIdx.x % 4;
        const int laneId = threadIdx.x % 32;
        #pragma unroll
        for (int m_step = 0; m_step < MSteps; m_step ++){
            #pragma unroll
            for (int n_step = 0; n_step < NSteps; n_step ++){
                // Each step processes a 32 x 32 tile: fragments (2 * m_step + i, 2 * n_step + j).
                // meta[4 * i + 2 * j + k] holds a selector nibble shifted to bit 4 * th_group_id.
                int16_t meta[8] = {0};
                #pragma unroll
                for (int i = 0; i < 2; i++){
                    #pragma unroll
                    for (int j = 0; j < 2; j++){
                        const int m_i = m_step * 2 + i;
                        const int n_j = n_step * 2 + j;

                        // Step 1: prune the result in the register file.
                        #pragma unroll
                        for (int k = 0; k < 2; k++){

                            float data[4] = {c[m_i][n_j].x[0 + 2 * k], c[m_i][n_j].x[1 + 2 * k], c[m_i][n_j].x[4 + 2 * k], c[m_i][n_j].x[5 + 2 * k]};
                            nv_bfloat16 data_bf16[4] = {
                                __float2bfloat16(data[0]),  __float2bfloat16(data[1]), __float2bfloat16(data[2]),  __float2bfloat16(data[3])
                            };

                            // Try every pair of the four candidates; strict '>' keeps the
                            // earlier pair on ties. meta_bit = idx_a | (idx_b << 2).
                            nv_bfloat16 value[2] = {data_bf16[0], data_bf16[1]};
                            int16_t meta_bit = 4;   // pair (0,1)
                            float max_val = data[0] + data[1];

                            if(data[0] + data[2] > max_val){
                                meta_bit = 8;       // pair (0,2)
                                value[1] = data_bf16[2];
                                max_val = data[0] + data[2];
                            }

                            if(data[0] + data[3] > max_val){
                                meta_bit = 12;      // pair (0,3)
                                value[1] = data_bf16[3];
                                max_val = data[0] + data[3];
                            }

                            if(data[1] + data[2] > max_val){
                                meta_bit = 9;       // pair (1,2)
                                value[0] = data_bf16[1];
                                value[1] = data_bf16[2];
                                max_val = data[1] + data[2];
                            }

                            if(data[1] + data[3] > max_val){
                                meta_bit = 13;      // pair (1,3)
                                value[0] = data_bf16[1];
                                value[1] = data_bf16[3];
                                max_val = data[1] + data[3];
                            }

                            if(data[2] + data[3] > max_val){
                                meta_bit = 14;      // pair (2,3)
                                value[0] = data_bf16[2];
                                value[1] = data_bf16[3];
                            }

                            meta[4 * i + 2 * j + k] = meta_bit << (th_group_id * 4);

                            // Write the two kept bf16 values to shared memory packed into one
                            // 32-bit word (value[0] in the low half) -- same bits as viewing the
                            // pair as a float, without punning a 2-byte-aligned array.
                            const unsigned int packed =
                                static_cast<unsigned int>(__bfloat16_as_ushort(value[0])) |
                                (static_cast<unsigned int>(__bfloat16_as_ushort(value[1])) << 16);
                            *(shared_res_ptr + (m_i * M + k * (M/2)) * SHM_STRIDE + n_j * (N/4)) = __uint_as_float(packed);
                        }

                        // Collect the meta data at the target thread. i, j and k are the same
                        // on every lane, so all 32 lanes reach each full-mask shuffle.
                        #pragma unroll
                        for (int k = 0; k < 2; k++){
                            if (i == 0){
                                meta[2 * j + k] |= __shfl_down_sync(0xffffffff, meta[2 * j + k], 2);
                            }else{
                                meta[2 * j + k + 4] |= __shfl_up_sync(0xffffffff, meta[2 * j + k + 4], 2);
                            }
                            if (j == 0){
                                meta[i * 4 + k] |= __shfl_down_sync(0xffffffff, meta[i * 4 + k], 1);
                            }else{
                                meta[i * 4 + 2 + k] |= __shfl_up_sync(0xffffffff, meta[i * 4 + 2 + k], 1);
                            }
                        }
                    }
                }
                // All the meta data are now collected in meta[8]; no reordering is needed.
                // Lane th_group_id of each quad emits 16-bit words 2 * th_group_id and
                // 2 * th_group_id + 1 packed into one 32-bit word.
                int meta_word = pack_meta_halves(meta[0], meta[1]);
                if (th_group_id == 1) meta_word = pack_meta_halves(meta[2], meta[3]);
                else if (th_group_id == 2) meta_word = pack_meta_halves(meta[4], meta[5]);
                else if (th_group_id == 3) meta_word = pack_meta_halves(meta[6], meta[7]);

                *(metadata + m_step * 32 + n_step * m + laneId) = meta_word;
            }
        }
    }

    // Copies the staged values to `nonzeros`: each thread moves TileM / NumSubWarp
    // float4s, stepping NumSubWarp rows at a time. Advances global_res/shared_res,
    // so it is meant to run once per object.
    // NOTE(review): no barrier here; staged rows are written by other warps in
    // get_meta_data(), so the caller presumably __syncthreads() in between.
    __device__ __forceinline__ void write_nonzeros(){
        #pragma unroll
        for (int i = 0; i < TileM / NumSubWarp; i++){
            *(global_res) = *(shared_res);
            global_res += NumSubWarp * n / 16;          // NumSubWarp rows of n / 2 bf16 (n bytes)
            shared_res += NumSubWarp * SHM_STRIDE / 4;  // NumSubWarp rows of SHM_STRIDE words
        }
    }

};




#endif