#ifndef GEMM_COMPUTER_H
#define GEMM_COMPUTER_H

#include <mma.h>
using namespace nvcuda;
#include "cutlass/cutlass.h"
#include "cutlass/arch/memory_sm75.h"

// Warp-level tensor-core compute stage of a tiled, multi-stage GEMM.
//
// Each warp owns a (wTileM x wTileN) sub-tile of the (TileM x TileN) block
// tile. Warps are numbered MWarps x NWarps with the N index varying fastest
// (warpId / NWarps selects the M position, warpId % NWarps the N position).
// Operands are read from shared memory with `ldmatrix.x4` and multiplied with
// `mma.sync`; every 16x16 accumulator fragment is produced by two m16n8 MMAs.
//
// Shared-memory layout expected for pipeline stage s (offset s * BatchOffset
// floats from the buffer passed to the constructor):
//   [ LHS: TileM rows x TileK floats | RHS: TileN rows x TileK floats ]
// LHS rows are indexed by M and RHS rows by N, both with a row stride of
// TileK floats. Inside a row, AlignK-float chunks are XOR-swizzled by
// (row % 8). NOTE(review): the producer filling shared memory must use the
// same swizzle; it is not visible here.
//
// NOTE(review): results are written straight into the registers of
// wmma::accumulator fragments. This relies on the (officially unspecified)
// wmma 16x16x8 float accumulator layout matching two m16n8 mma.sync C
// fragments.
//
// Requires SM80+ (tf32 / bf16 mma.sync); ldmatrix itself requires SM75+.
template<int MWarps, int NWarps, int TileM, int TileN, int TileK, int STAGES, int wTileM, int wTileN, int AlignK, int M, int N, int K>
struct Computer{

    //
    //  Compile-time preconditions
    //

    // The inline PTX hard-codes the 16x16x8 fragment shape: one 16x16
    // accumulator = two m16n8 MMAs, and one k-step (K floats) = two chunks.
    static_assert(M == 16 && N == 16 && K == 8,
                  "Computer: inline PTX assumes 16x16x8 fragments");
    // The integer divisions below would otherwise silently drop the remainder
    // of a tile (unread K columns or uncomputed output rows/columns).
    static_assert(TileK % K == 0, "Computer: TileK must be a multiple of K");
    static_assert(wTileM % M == 0, "Computer: wTileM must be a multiple of M");
    static_assert(wTileN % N == 0, "Computer: wTileN must be a multiple of N");
    // ldmatrix row addresses advance in steps of AlignK floats and must stay
    // 16-byte aligned.
    static_assert((AlignK * sizeof(float)) % 16 == 0,
                  "Computer: AlignK floats must span a multiple of 16 bytes");

    //
    //  Static Members
    //

    // Number of AlignK-wide chunks per shared-memory row (not used in this struct).
    static constexpr int SubWarpSize = TileK / AlignK;
    // Number of K-steps needed to consume one TileK-wide slice.
    static constexpr int CHUNK_K = TileK / K;

    // Float offsets of the LHS / RHS operand tiles inside one pipeline stage.
    static constexpr int LHSFragOffset = 0;
    static constexpr int RHSFragOffset = TileM * TileK;
    // Floats occupied by one pipeline stage (LHS tile + RHS tile).
    static constexpr int BatchOffset = (TileM + TileN) * TileK;

    // Number of 16x16 fragments per warp tile along M and N.
    static constexpr int MWarpTiles = wTileM / M;
    static constexpr int NWarpTiles = wTileN / N;

    //
    //  Member variables
    //

    // Per-lane ldmatrix row addresses into stage 0 of the LHS / RHS tiles.
    const float* lhs_fragment_ptr_base;
    const float* rhs_fragment_ptr_base;

    // Per-lane swizzle key: (row % 8) XOR the lane's chunk offset (0 or 1).
    int skew;

    // Computes this lane's base addresses into the staged operand buffer.
    // `smem` must be the start of that buffer. It is converted to a
    // shared-space address for ldmatrix, so it must point into shared memory.
    // Only threadIdx.x is used, so a 1-D thread block is assumed.
    __device__ __forceinline__ Computer(
        const float* smem)
    {
        int warpId = threadIdx.x / 32;
        int MwarpId = warpId / NWarps;
        int NwarpId = warpId % NWarps;
        int laneId = threadIdx.x % 32;

        // ldmatrix.x4: lanes 0-15 address rows 0-15 of the first chunk,
        // lanes 16-31 address the same rows of the second chunk.
        int m = MwarpId * wTileM + laneId % 16;
        int n = NwarpId * wTileN + laneId % 16;

        int row_group_id = laneId % 8;
        int col_offset = laneId / 16;
        // Physical chunk index = logical chunk index XOR (row % 8).
        skew = row_group_id ^ col_offset;

        lhs_fragment_ptr_base = smem + LHSFragOffset + m * TileK;
        rhs_fragment_ptr_base = smem + RHSFragOffset + n * TileK;
    }

private:
    // Shared body of compute_block_tile / compute_block_tile_bf16.
    // UseBf16 == false: round loaded words to tf32 and issue m16n8k8 tf32 MMAs.
    // UseBf16 == true : use loaded words as bf16 pairs with m16n8k16 bf16 MMAs.
    // UseBf16 is a compile-time constant, so the untaken branches fold away.
    template<bool UseBf16>
    __device__ __forceinline__ void compute_stage(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles], int batch_idx){
        // The tf32 fragment types serve only as 32-bit register containers;
        // their storage is reinterpreted as ints for the inline PTX.
        wmma::fragment<wmma::matrix_a, M, N, K, wmma::precision::tf32, wmma::row_major> a[MWarpTiles];
        wmma::fragment<wmma::matrix_b, M, N, K, wmma::precision::tf32, wmma::col_major> b[NWarpTiles];

        // Ring-buffer stage that holds this K tile.
        int shared_idx = batch_idx % STAGES;

        const float* lhs_fragment_ptr = lhs_fragment_ptr_base + shared_idx * BatchOffset;
        const float* rhs_fragment_ptr = rhs_fragment_ptr_base + shared_idx * BatchOffset;
        #pragma unroll
        for (int k_step = 0; k_step < CHUNK_K; k_step ++){
            // Each k-step consumes two consecutive logical chunks; apply the
            // same XOR swizzle to locate them physically.
            int k_step2 = k_step * 2;
            int skew_t = skew ^ k_step2;
            const float* lhs_fragment_t = lhs_fragment_ptr + skew_t * AlignK;
            const float* rhs_fragment_t = rhs_fragment_ptr + skew_t * AlignK;

            #pragma unroll
            for (int i = 0; i < MWarpTiles; i++){
                unsigned shared_lhs_offset_t = cutlass::arch::cutlass_get_smem_pointer(lhs_fragment_t);
                int* a_int = reinterpret_cast<int *>(a[i].x);
                asm volatile ("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(a_int[0]), "=r"(a_int[1]), "=r"(a_int[2]), "=r"(a_int[3]): "r"(shared_lhs_offset_t));
                if (!UseBf16){
                    // Round the fp32 operands to tf32 before the tf32 MMA.
                    #pragma unroll
                    for (int t = 0; t < a[i].num_elements; t++){
                        a[i].x[t] = wmma::__float_to_tf32(a[i].x[t]);
                    }
                }

                #pragma unroll
                for (int j = 0; j < NWarpTiles; j++){
                    int* b_int = reinterpret_cast<int *>(b[j].x);

                    // B fragments are loaded once (for the first M fragment)
                    // and reused for every i.
                    if (i == 0){
                        unsigned shared_rhs_offset_t = cutlass::arch::cutlass_get_smem_pointer(rhs_fragment_t);
                        // Destination order {0, 2, 1, 3}: b_int[0..1] receive RHS rows
                        // 0-7 and b_int[2..3] rows 8-15, one pair per m16n8 MMA below.
                        asm volatile ("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" : "=r"(b_int[0]), "=r"(b_int[2]), "=r"(b_int[1]), "=r"(b_int[3]): "r"(shared_rhs_offset_t));
                        if (!UseBf16){
                            #pragma unroll
                            for (int t = 0; t < b[j].num_elements; t++){
                                b[j].x[t] = wmma::__float_to_tf32(b[j].x[t]);
                            }
                        }
                        rhs_fragment_t += N * TileK;
                    }
                    float *c_float = c[i][j].x;

                    // c_float[0..3] accumulate output columns 0-7 and
                    // c_float[4..7] columns 8-15 of this 16x16 fragment.
                    if (UseBf16){
                        asm volatile ("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
                                        : "+f"(c_float[0]), "+f"(c_float[1]), "+f"(c_float[2]), "+f"(c_float[3])
                                        : "r"(a_int[0]), "r"(a_int[1]), "r"(a_int[2]), "r"(a_int[3]), "r"(b_int[0]), "r"(b_int[1]));

                        asm volatile ("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
                                        : "+f"(c_float[4]), "+f"(c_float[5]), "+f"(c_float[6]), "+f"(c_float[7])
                                        : "r"(a_int[0]), "r"(a_int[1]), "r"(a_int[2]), "r"(a_int[3]), "r"(b_int[2]), "r"(b_int[3]));
                    } else {
                        asm volatile ("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
                                        : "+f"(c_float[0]), "+f"(c_float[1]), "+f"(c_float[2]), "+f"(c_float[3])
                                        : "r"(a_int[0]), "r"(a_int[1]), "r"(a_int[2]), "r"(a_int[3]), "r"(b_int[0]), "r"(b_int[1]));

                        asm volatile ("mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
                                        : "+f"(c_float[4]), "+f"(c_float[5]), "+f"(c_float[6]), "+f"(c_float[7])
                                        : "r"(a_int[0]), "r"(a_int[1]), "r"(a_int[2]), "r"(a_int[3]), "r"(b_int[2]), "r"(b_int[3]));
                    }
                }

                lhs_fragment_t += M * TileK;
            }
        }
    }

public:
    // Accumulates one pipeline stage into `c` using TF32 tensor cores.
    //   c         : MWarpTiles x NWarpTiles accumulator fragments of this warp.
    //   batch_idx : K-tile index; operands come from stage batch_idx % STAGES.
    // Loaded fp32 operands are rounded to tf32 before mma.sync.
    __device__ __forceinline__ void compute_block_tile(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles], int batch_idx){
        compute_stage</*UseBf16=*/false>(c, batch_idx);
    }

    // Same traversal as compute_block_tile, but each 32-bit shared-memory word
    // is used as a bf16 pair and fed to m16n8k16 bf16 mma.sync (no conversion).
    // NOTE(review): this expects bf16 data packed two per float slot with the
    // same TileK-float row stride and swizzle — confirm against the producer.
    __device__ __forceinline__ void compute_block_tile_bf16(wmma::fragment<wmma::accumulator, M, N, K, float> c[][NWarpTiles], int batch_idx){
        compute_stage</*UseBf16=*/true>(c, batch_idx);
    }
};

#endif