#ifndef GEMM_SRC_ITERATOR_H
#define GEMM_SRC_ITERATOR_H

#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm75.h"
#include <cuda/pipeline>



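// Template parameters shared by the iterators below:
// TileM, TileN, TileK: the tile shape is TileM x TileN x TileK. Each Load_async call copies
//   TileM rows of the LHS and TileN rows of the RHS, each TileK elements long; all NumWarp
//   warps of the thread block cooperate on the copy.
// Element: the scalar type of the elements {float, half, ...}. NOTE: this is not currently a
//   template parameter; both iterators are hard-coded to float.
// AlignK: the length of the vector type used per copy, e.g. AlignK = 4 when loading float
//   with float4.
// NumWarp: number of warps per thread block.
// STAGES: number of shared-memory buffers; K-slice batch_idx is written to buffer
//   batch_idx % STAGES.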


template <int TileM, int TileN, int TileK, int AlignK, int NumWarp, int STAGES>
struct SRCIterator
{
    // Cooperative global->shared loader for a pair of row-major float operand tiles.
    //
    // Each Load_async(batch_idx) call copies two pieces, both taken from columns
    // [batch_idx * TileK, (batch_idx + 1) * TileK):
    //   - TileM rows of the LHS matrix, starting at row m_offset;
    //   - TileN rows of the RHS matrix, starting at row n_offset.
    // They are written into shared-memory stage (batch_idx % STAGES).
    //
    // Thread mapping: consecutive threads are grouped into "subwarps" of SubWarpSize
    // (= TileK / AlignK) threads. Each subwarp copies one TileK-long row per step, and each
    // thread moves AlignK floats (16 bytes) with one cp.async. Subwarp s handles rows
    // s, s + NumSubWarp, s + 2 * NumSubWarp, ...
    //
    // Shared-memory layout of one stage (BatchOffset floats):
    //   [LHSFragOffset, RHSFragOffset) : TileM rows of TileK floats
    //   [RHSFragOffset, BatchOffset)   : TileN rows of TileK floats
    // Inside a row, the 16-byte chunk of sublane l is stored at chunk l ^ (subwarpId % 8).
    // This is an XOR swizzle, presumably to avoid bank conflicts for the reader, which must
    // apply the same swizzle.
    //
    // Preconditions (not checked at runtime):
    //   - SM80+ (cp.async).
    //   - Assumes blockDim.x == NumWarp * 32; the mapping is derived from threadIdx.x only.
    //   - smem holds at least STAGES * BatchOffset floats of shared memory.
    //   - Both matrices are row-major with leading dimension k.
    //   - k % AlignK == 0 and the base pointers are 16-byte aligned. The 16-byte cp.async
    //     requires 16-byte-aligned addresses.
    //   - No bounds checking: every loaded row/column must be in range.
    //   - Copies are only issued here. Committing/waiting on them and the block barrier are
    //     presumably done by the caller.

    //
    //  Static Members
    //
    // static_assert(STAGES == 2, "Currently only support 2 stages");

    // Each subwarp will load a row with length TileK from the global memory
    static constexpr int SubWarpSize = TileK / AlignK;
    static_assert(SubWarpSize <= 32, "TileK / AlignK should be smaller than 32");
    static_assert(TileK % AlignK == 0, "TileK should be multiple of AlignK");
    static_assert(TileK == 32, "Currently only support TileK = 32");
    static_assert(AlignK == 4, "Currently only support float4");

    static constexpr int NumSubWarp = NumWarp * 32 / SubWarpSize;
    static_assert(NumSubWarp <= TileM, "Number of subwarp should be <= TileM");
    static_assert(NumSubWarp <= TileN, "Number of subwarp should be <= TileN");
    // Load_async walks TileM / NumSubWarp (resp. TileN / NumSubWarp) steps. A remainder
    // would silently leave the last rows of the tile unloaded.
    static_assert(TileM % NumSubWarp == 0, "TileM should be a multiple of the number of subwarps");
    static_assert(TileN % NumSubWarp == 0, "TileN should be a multiple of the number of subwarps");

    static constexpr int LHSFragOffset = 0;
    static constexpr int RHSFragOffset = TileM * TileK;
    static constexpr int BatchOffset = (TileM + TileN) * TileK;
    //
    //  Member variables
    //

    const float* lhs_matrix_ptr;   // this thread's first LHS source element (K-slice 0)
    const float* rhs_matrix_ptr;   // this thread's first RHS source element (K-slice 0)
    float* lhs_fragment_ptr;       // this thread's first LHS destination in stage 0
    float* rhs_fragment_ptr;       // this thread's first RHS destination in stage 0
    int k;                         // leading dimension (row length) of both matrices

    // Computes this thread's source/destination base pointers.
    //   lhs_matrix, rhs_matrix : row-major global matrices with leading dimension k_
    //   smem                   : base of the STAGES-deep shared-memory buffer
    //   m_offset, n_offset     : first LHS / RHS row of this tile
    //   k_                     : leading dimension of both matrices
    __device__ __forceinline__ SRCIterator(
        const float* lhs_matrix, const float* rhs_matrix, float* smem,
        int m_offset, int n_offset, int k_)
    {
        k = k_;
        int subwarpId = threadIdx.x / SubWarpSize;
        int sublaneId = threadIdx.x % SubWarpSize;

        // XOR swizzle of the chunk position within a shared row, keyed on subwarpId % 8.
        // NOTE(review): later steps write rows subwarpId + step * NumSubWarp but keep this key.
        // The key equals row % 8 only when NumSubWarp % 8 == 0. Confirm the reader's
        // un-swizzle for configurations where it is not.
        int row_group_id = subwarpId % 8;
        int skew = sublaneId ^ row_group_id;

        // Row offsets are formed in 64-bit: row * k can exceed INT_MAX for large matrices.
        lhs_matrix_ptr = lhs_matrix + static_cast<long long>(m_offset + subwarpId) * k + sublaneId * AlignK;
        // Previous (rotation) skew, kept for reference:
        // lhs_fragment_ptr = smem + LHSFragOffset + subwarpId * TileK + (subwarpId + sublaneId) % SubWarpSize * AlignK;
        lhs_fragment_ptr = smem + LHSFragOffset + subwarpId * TileK + skew * AlignK;

        rhs_matrix_ptr = rhs_matrix + static_cast<long long>(n_offset + subwarpId) * k + sublaneId * AlignK;
        // rhs_fragment_ptr = smem + RHSFragOffset + subwarpId * TileK + (subwarpId + sublaneId) % SubWarpSize * AlignK;
        rhs_fragment_ptr = smem + RHSFragOffset + subwarpId * TileK + skew * AlignK;
    }

    // Issues one 16-byte asynchronous global->shared copy (cp.async.cg, SM80+).
    // The asm has no outputs, so `volatile` is required to stop the compiler deleting or
    // hoisting it. The "memory" clobber stops it reordering memory accesses across the copy.
    static __device__ __forceinline__ void cp_async_16B(float* smem_dst, const float* gmem_src)
    {
        unsigned smem_addr = cutlass::arch::cutlass_get_smem_pointer(smem_dst);
        asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"
                     :: "r"(smem_addr), "l"(gmem_src), "n"(16) : "memory");
    }

    // Issues the copies for K-slice batch_idx (columns [batch_idx * TileK, (batch_idx + 1) * TileK))
    // into shared stage batch_idx % STAGES.
    // Each thread issues TileM / NumSubWarp + TileN / NumSubWarp copies. They are neither
    // committed nor waited on here.
    __device__ __forceinline__ void Load_async(int batch_idx){
        int shared_idx = batch_idx % STAGES;
        // Global distance between the rows a subwarp handles on consecutive steps.
        const long long row_stride = static_cast<long long>(NumSubWarp) * k;

        // Get the pointers
        const float* lhs_matrix_t = lhs_matrix_ptr + batch_idx * TileK;
        const float* rhs_matrix_t = rhs_matrix_ptr + batch_idx * TileK;

        float* lhs_fragment_t = lhs_fragment_ptr + shared_idx * BatchOffset;
        float* rhs_fragment_t = rhs_fragment_ptr + shared_idx * BatchOffset;

        #pragma unroll
        for (int step_m = 0; step_m < TileM / NumSubWarp; step_m ++){
            cp_async_16B(lhs_fragment_t, lhs_matrix_t);
            lhs_matrix_t += row_stride;
            lhs_fragment_t += NumSubWarp * TileK;
        }
        #pragma unroll
        for (int step_n = 0; step_n < TileN / NumSubWarp; step_n ++){
            cp_async_16B(rhs_fragment_t, rhs_matrix_t);
            rhs_matrix_t += row_stride;
            rhs_fragment_t += NumSubWarp * TileK;
        }
    }
};



template <int TileM, int TileN, int TileK, int AlignK, int NumWarp, int STAGES>
struct SRCIteratorInterleaved
{
    // Cooperative global->shared loader for a pair of row-major float operand tiles.
    // It matches SRCIterator except that RHS rows are permuted on the way into shared memory.
    //
    // Each Load_async(batch_idx) call copies two pieces, both taken from columns
    // [batch_idx * TileK, (batch_idx + 1) * TileK):
    //   - TileM rows of the LHS matrix, starting at row m_offset (identity row mapping);
    //   - TileN rows of the RHS matrix, starting at n_offset. Shared RHS row r holds global
    //     row n_offset + InterleavedRow(r). The permutation presumably matches the reader's
    //     fragment layout; confirm against the consumer.
    // The data goes into shared-memory stage (batch_idx % STAGES).
    //
    // Thread mapping: consecutive threads form "subwarps" of SubWarpSize (= TileK / AlignK)
    // threads. Each subwarp copies one TileK-long row per step, one 16-byte cp.async per thread.
    //
    // Shared-memory layout of one stage (BatchOffset floats):
    //   [LHSFragOffset, RHSFragOffset) : TileM rows of TileK floats
    //   [RHSFragOffset, BatchOffset)   : TileN rows of TileK floats
    // Inside a row, the chunk of sublane l is stored at chunk l ^ (subwarpId % 8) (XOR swizzle).
    //
    // Preconditions (not checked at runtime):
    //   - SM80+ (cp.async).
    //   - Assumes blockDim.x == NumWarp * 32.
    //   - smem holds at least STAGES * BatchOffset floats.
    //   - Both matrices are row-major with leading dimension k.
    //   - k % AlignK == 0 and the base pointers are 16-byte aligned.
    //   - No bounds checking is performed.
    //   - Committing/waiting on the copies is presumably done by the caller.

    //
    //  Static Members
    //
    // static_assert(STAGES == 2, "Currently only support 2 stages");

    // Each subwarp will load a row with length TileK from the global memory
    static constexpr int SubWarpSize = TileK / AlignK;
    static_assert(SubWarpSize <= 32, "TileK / AlignK should be smaller than 32");
    static_assert(TileK % AlignK == 0, "TileK should be multiple of AlignK");
    static_assert(TileK == 32, "Currently only support TileK = 32");
    static_assert(AlignK == 4, "Currently only support float4");

    static constexpr int NumSubWarp = NumWarp * 32 / SubWarpSize;
    static_assert(NumSubWarp <= TileM, "Number of subwarp should be <= TileM");
    static_assert(NumSubWarp <= TileN, "Number of subwarp should be <= TileN");
    // Load_async walks TileM / NumSubWarp (resp. TileN / NumSubWarp) steps. A remainder
    // would silently leave the last rows of the tile unloaded.
    static_assert(TileM % NumSubWarp == 0, "TileM should be a multiple of the number of subwarps");
    static_assert(TileN % NumSubWarp == 0, "TileN should be a multiple of the number of subwarps");
    // InterleavedRow permutes indices only within complete groups of 16. With fewer subwarps
    // per step, RHS rows would be loaded twice, skipped, or read from outside the tile.
    static_assert(NumSubWarp % 16 == 0, "Interleaved RHS mapping requires the number of subwarps to be a multiple of 16");

    static constexpr int LHSFragOffset = 0;
    static constexpr int RHSFragOffset = TileM * TileK;
    static constexpr int BatchOffset = (TileM + TileN) * TileK;
    //
    //  Member variables
    //

    const float* lhs_matrix_ptr;   // this thread's first LHS source element (K-slice 0)
    const float* rhs_matrix_ptr;   // this thread's first RHS source element (K-slice 0)
    float* lhs_fragment_ptr;       // this thread's first LHS destination in stage 0
    float* rhs_fragment_ptr;       // this thread's first RHS destination in stage 0
    int k;                         // leading dimension (row length) of both matrices

    // Maps a subwarp (or shared RHS row) index s to the global RHS row offset it holds.
    // Within each aligned group of 16 the bits of s are permuted:
    //   bit0 -> bit0, bit3 -> bit1, bit1 -> bit2, bit2 -> bit3.
    // Higher bits are kept. This is equivalent to the earlier arithmetic form
    //   ((s / 2) % 4) * 4 + (s % 2) + ((s / 8) % 2) * 2 + (s / 16) * 16
    // for s >= 0.
    static __device__ __forceinline__ int InterleavedRow(int s)
    {
        return (s & ~15) | ((s & 6) << 1) | ((s & 8) >> 2) | (s & 1);
    }

    // Computes this thread's source/destination base pointers.
    //   lhs_matrix, rhs_matrix : row-major global matrices with leading dimension k_
    //   smem                   : base of the STAGES-deep shared-memory buffer
    //   m_offset, n_offset     : first LHS / RHS row of this tile
    //   k_                     : leading dimension of both matrices
    __device__ __forceinline__ SRCIteratorInterleaved(
        const float* lhs_matrix, const float* rhs_matrix, float* smem,
        int m_offset, int n_offset, int k_)
    {
        k = k_;
        int subwarpId = threadIdx.x / SubWarpSize;
        int sublaneId = threadIdx.x % SubWarpSize;

        // XOR swizzle of the chunk position within a shared row, keyed on the shared row's
        // subwarp index. Because NumSubWarp % 16 == 0, this equals row % 8 on every step.
        int row_group_id = subwarpId % 8;
        int skew = sublaneId ^ row_group_id;

        int global_col_id = InterleavedRow(subwarpId);

        // Row offsets are formed in 64-bit: row * k can exceed INT_MAX for large matrices.
        lhs_matrix_ptr = lhs_matrix + static_cast<long long>(m_offset + subwarpId) * k + sublaneId * AlignK;
        // Previous (rotation) skew, kept for reference:
        // lhs_fragment_ptr = smem + LHSFragOffset + subwarpId * TileK + (subwarpId + sublaneId) % SubWarpSize * AlignK;
        lhs_fragment_ptr = smem + LHSFragOffset + subwarpId * TileK + skew * AlignK;

        // The RHS source row is permuted; the shared destination row stays subwarpId.
        rhs_matrix_ptr = rhs_matrix + static_cast<long long>(n_offset + global_col_id) * k + sublaneId * AlignK;
        // rhs_fragment_ptr = smem + RHSFragOffset + subwarpId * TileK + (subwarpId + sublaneId) % SubWarpSize * AlignK;
        rhs_fragment_ptr = smem + RHSFragOffset + subwarpId * TileK + skew * AlignK;
    }

    // Issues one 16-byte asynchronous global->shared copy (cp.async.cg, SM80+).
    // The asm has no outputs, so `volatile` is required to stop the compiler deleting or
    // hoisting it. The "memory" clobber stops it reordering memory accesses across the copy.
    static __device__ __forceinline__ void cp_async_16B(float* smem_dst, const float* gmem_src)
    {
        unsigned smem_addr = cutlass::arch::cutlass_get_smem_pointer(smem_dst);
        asm volatile("cp.async.cg.shared.global [%0], [%1], %2;\n"
                     :: "r"(smem_addr), "l"(gmem_src), "n"(16) : "memory");
    }

    // Issues the copies for K-slice batch_idx (columns [batch_idx * TileK, (batch_idx + 1) * TileK))
    // into shared stage batch_idx % STAGES. The copies are neither committed nor waited on here.
    __device__ __forceinline__ void Load_async(int batch_idx){
        int shared_idx = batch_idx % STAGES;
        // Global distance between the rows a subwarp handles on consecutive steps.
        const long long row_stride = static_cast<long long>(NumSubWarp) * k;

        // Get the pointers
        const float* lhs_matrix_t = lhs_matrix_ptr + batch_idx * TileK;
        const float* rhs_matrix_t = rhs_matrix_ptr + batch_idx * TileK;

        float* lhs_fragment_t = lhs_fragment_ptr + shared_idx * BatchOffset;
        float* rhs_fragment_t = rhs_fragment_ptr + shared_idx * BatchOffset;

        #pragma unroll
        for (int step_m = 0; step_m < TileM / NumSubWarp; step_m ++){
            cp_async_16B(lhs_fragment_t, lhs_matrix_t);
            lhs_matrix_t += row_stride;
            lhs_fragment_t += NumSubWarp * TileK;
        }
        #pragma unroll
        for (int step_n = 0; step_n < TileN / NumSubWarp; step_n ++){
            cp_async_16B(rhs_fragment_t, rhs_matrix_t);
            rhs_matrix_t += row_stride;
            rhs_fragment_t += NumSubWarp * TileK;
        }
    }
};


#endif