#ifndef GEMM_STORE_UTIL_H
#define GEMM_STORE_UTIL_H

#include <mma.h>
using namespace nvcuda;

// Per-warp helper that writes a warp's WMMA accumulator fragments into a
// row-major float result buffer using wmma::store_matrix_sync.
//
// Template parameters:
//   MWarps, NWarps  - warp-grid shape; only NWarps is used here, to split the
//                     linear warp index into (MwarpId, NwarpId).
//   wTileM, wTileN  - rows / columns of the sub-tile owned by one warp.
//   TileN           - column count used to derive the row stride of the buffer.
//   M, N, K         - WMMA fragment shape.
//
// NOTE(review): only threadIdx.x is consulted, so a 1-D thread block is
// assumed — confirm against the launching kernel.
template<int MWarps, int NWarps, int wTileM, int wTileN, int TileN, int M, int N, int K>
struct ResStore{

    // Without these, the integer divisions below would silently truncate and
    // part of the warp tile would never be stored.
    static_assert(wTileM % M == 0, "wTileM must be a multiple of the fragment height M");
    static_assert(wTileN % N == 0, "wTileN must be a multiple of the fragment width N");

    // Row stride (in floats) of the result buffer.
    // NOTE(review): the +8 is presumably padding against shared-memory bank
    // conflicts — confirm. wmma::store_matrix_sync additionally requires the
    // stride to be a multiple of 4 floats.
    static constexpr int SHM_STRIDE = TileN + 8;
    // Number of fragments per warp tile along M and N.
    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;
    //
    //  Member variables
    //

    // Address of the top-left element of this warp's sub-tile. It is fixed at
    // construction time and is not modified by store_block_res().
    float* shared_res_ptr;

    // smem: base address of the whole result buffer (row stride SHM_STRIDE).
    __device__ __forceinline__ ResStore(
        float* smem)
    {
        const int warpId  = threadIdx.x / 32;
        const int MwarpId = warpId / NWarps;   // warp row in the warp grid
        const int NwarpId = warpId % NWarps;   // warp column in the warp grid

        shared_res_ptr = smem + MwarpId * wTileM * SHM_STRIDE + NwarpId * wTileN;
    }

    // Stores all MWarpTiles x NWarpTiles accumulator fragments of this warp.
    // Fragment c[i][j] lands at row offset i * M and column offset j * N
    // relative to shared_res_ptr.
    //
    // The member pointer is left untouched (a local cursor is advanced
    // instead), so the method may safely be called more than once on the same
    // object; previously every call shifted the base pointer by wTileM rows.
    __device__ __forceinline__ void store_block_res(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        float* row_ptr = shared_res_ptr;
        #pragma unroll
        for (int i = 0; i < MWarpTiles; i++){
            #pragma unroll
            for (int j = 0; j < NWarpTiles; j++){
                wmma::store_matrix_sync(row_ptr + j * N, c[i][j], SHM_STRIDE, wmma::mem_row_major);
            }
            row_ptr += M * SHM_STRIDE;
        }
    }
};

// **********************************************************************
// * Store interleaved results from the register file to shared memory *
// **********************************************************************

// Each warp stores its own tiles into the shared memory

// Per-warp helper that writes a warp's WMMA accumulator fragments into a
// row-major float result buffer by reading the fragment registers (c.x[])
// directly, instead of going through wmma::store_matrix_sync.
//
// Per fragment, each lane writes two groups of four consecutive floats:
//   row  laneId / 4      : x[0], x[1], x[4], x[5]
//   row  laneId / 4 + 8  : x[2], x[3], x[6], x[7]
// both starting at column (laneId % 4) * 4 of the fragment's region.
//
// NOTE(review): the 8 accumulator elements per lane and the fixed +8 row
// offset give a 16-row x 16-column footprint per fragment, so this appears to
// assume M == N == 16. It also relies on the register-to-element mapping of
// the accumulator fragment, which the CUDA documentation leaves unspecified —
// confirm for the targeted architectures.
//
// Template parameters:
//   MWarps, NWarps  - warp-grid shape; only NWarps is used here, to split the
//                     linear warp index into (MwarpId, NwarpId).
//   wTileM, wTileN  - rows / columns of the sub-tile owned by one warp.
//   TileN           - column count used to derive the row stride of the buffer.
//   M, N, K         - WMMA fragment shape.
//
// NOTE(review): only threadIdx.x is consulted, so a 1-D thread block is
// assumed — confirm against the launching kernel.
template<int MWarps, int NWarps, int wTileM, int wTileN, int TileN, int M, int N, int K>
struct ResStoreInterleaved{

    // Without these, the integer divisions below would silently truncate and
    // part of the warp tile would never be stored.
    static_assert(wTileM % M == 0, "wTileM must be a multiple of the fragment height M");
    static_assert(wTileN % N == 0, "wTileN must be a multiple of the fragment width N");

    // Row stride (in floats) of the result buffer.
    // NOTE(review): the +8 is presumably padding against shared-memory bank
    // conflicts — confirm.
    static constexpr int SHM_STRIDE = TileN + 8;
    // Number of fragments per warp tile along M and N.
    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;
    //
    //  Member variables
    //

    // Address of the first element this lane writes inside its warp's
    // sub-tile. It is fixed at construction time and is not modified by
    // store_block_res().
    float* shared_res_ptr;

    // smem: base address of the whole result buffer (row stride SHM_STRIDE).
    __device__ __forceinline__ ResStoreInterleaved(
        float* smem)
    {
        const int warpId  = threadIdx.x / 32;
        const int MwarpId = warpId / NWarps;   // warp row in the warp grid
        const int NwarpId = warpId % NWarps;   // warp column in the warp grid
        const int laneId  = threadIdx.x % 32;

        const int laneRow = laneId / 4;        // 0..7
        const int laneCol = (laneId % 4) * 4;  // 0, 4, 8, 12

        shared_res_ptr = smem + (MwarpId * wTileM + laneRow) * SHM_STRIDE + NwarpId * wTileN + laneCol;
    }

    // Stores this lane's share of all MWarpTiles x NWarpTiles fragments.
    // Fragment c[i][j] is placed at row offset i * M and column offset j * N
    // relative to shared_res_ptr.
    //
    // The member pointer is left untouched (a local cursor is advanced
    // instead), so the method may safely be called more than once on the same
    // object; previously every call shifted the base pointer by wTileM rows.
    __device__ __forceinline__ void store_block_res(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles]){
        float* row_ptr = shared_res_ptr;
        #pragma unroll
        for (int i = 0; i < MWarpTiles; i++){
            #pragma unroll
            for (int j = 0; j < NWarpTiles; j++){
                float* upper = row_ptr + j * N;          // row laneId / 4
                float* lower = upper + SHM_STRIDE * 8;   // eight rows further down

                upper[0] = c[i][j].x[0];
                upper[1] = c[i][j].x[1];
                upper[2] = c[i][j].x[4];
                upper[3] = c[i][j].x[5];

                lower[0] = c[i][j].x[2];
                lower[1] = c[i][j].x[3];
                lower[2] = c[i][j].x[6];
                lower[3] = c[i][j].x[7];
            }
            row_ptr += M * SHM_STRIDE;
        }
    }
};

#endif