#include "kernel_cuda.h"
#include <ATen/cuda/CUDAContext.h>
#include <mma.h>

namespace GPT {
using namespace nvcuda::wmma;

// Verifies with TORCH_CHECK that `x.sizes()` equals the dimension list passed as the remaining
// arguments. On failure the message is the stringified name of `x` followed by
// " must have shape (<stringified dimension list>)".
#define CHECK_SHAPE(x, ...)                                                                        \
    TORCH_CHECK(x.sizes() == torch::IntArrayRef({ __VA_ARGS__ }),                                  \
                #x " must have shape (" #__VA_ARGS__ ")")

// Emits the PTX instruction `cp.async.commit_group`, which bundles every `cp.async` copy the
// calling thread has issued since its previous commit into one group. The group can then be
// waited on with wait_async_cp_group<n>().
__inline__ __device__ void commit_async_cp_group() {
    asm volatile("cp.async.commit_group;\n" ::);
}

// Emits the PTX instruction `cp.async.wait_group n`: the calling thread waits until at most `n`
// of its most recently committed cp.async groups are still pending (n == 0 waits for all of
// them). `n` is passed as an immediate, so it has to be a compile-time constant.
// This only orders the calling thread's own copies; callers in this file follow it with
// __syncthreads() / __syncwarp() before other threads read the destination.
template<int n>
__inline__ __device__ void wait_async_cp_group() {
    asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
}

// Issues one asynchronous copy of `n` bytes from `gmem_ptr` to `smem_ptr` using the PTX
// `cp.async` instruction. The copy is only guaranteed complete after it has been committed
// (commit_async_cp_group) and waited for (wait_async_cp_group).
//
// Template parameter:
//   n - number of bytes to copy; passed to the instruction as an immediate.
// Parameters:
//   smem_ptr - destination; converted with __cvta_generic_to_shared, so it must point into
//              shared memory.
//   gmem_ptr - source address in global memory.
//
// NOTE: despite the `cg` in the function name, the instruction emitted is `cp.async.ca`.
// The `.ca` form accepts copy sizes of 4, 8 or 16 bytes (the `.cg` form only accepts 16), and
// the tile loaders in this file may request fewer than 16 bytes. The static_assert below turns
// an unsupported size into a readable compile-time error instead of an assembler failure.
template<uint32_t n>
__inline__ __device__ void cp_async_cg_shared_global(void* __restrict__ smem_ptr,
                                                     const void* __restrict__ gmem_ptr) {
    static_assert(n == 4 || n == 8 || n == 16,
                  "cp.async.ca only supports copy sizes of 4, 8 or 16 bytes");
    asm volatile("cp.async.ca.shared.global [%0], [%1], %2;\n" ::"r"(
                   static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr))),
                 "l"(gmem_ptr),
                 "n"(n));
}

// Primary template of the per-element-type traits class. It is intentionally empty: the members
// used by the helpers in this file (to_float, float_to_scalar, scalar2_t, mul2, fma2, ...) are
// supplied only by the explicit specialisations, so using an unspecialised type fails to compile.
template<typename scalar_t>
class TypeTraits {};

// Debug helper: prints an N x M matrix from device code.
//
// Only the thread block whose (blockIdx.x, blockIdx.y) equals (block_x, block_y) does anything;
// every other block returns immediately. Inside the selected block, thread 0 prints a header
// line followed by one bracketed line per row, and then the whole block meets at a
// __syncthreads() barrier.
//
// Parameters:
//   name    - label printed in the header line.
//   block_x - blockIdx.x of the block that should print.
//   block_y - blockIdx.y of the block that should print.
//   matrix  - pointer to element (0, 0).
//   stride  - distance, in elements, between the starts of consecutive rows.
template<typename scalar, int N, int M>
__inline__ __device__ void print_matrix(const char* name,
                                        int block_x,
                                        int block_y,
                                        const scalar* matrix,
                                        int stride) {
    const bool is_selected_block = (blockIdx.x == block_x) && (blockIdx.y == block_y);
    if (!is_selected_block) {
        return;
    }

    // A single thread does all the printing so the output is not interleaved.
    if (threadIdx.x == 0) {
        printf("%s on block [%d, %d]\n", name, block_x, block_y);
        for (int r = 0; r < N; ++r) {
            const scalar* row_ptr = matrix + r * stride;
            printf("[");
            for (int c = 0; c < M; ++c) {
                printf("%f, ", TypeTraits<scalar>::to_float(row_ptr[c]));
            }
            printf("],\n");
        }
    }
    // The block-uniform early return above means every thread of the selected block gets here.
    __syncthreads();
}

// A 16-byte packet of `scalar_t` values (16 / sizeof(scalar_t) elements).
//
// dot() treats the packet as 16 / sizeof(scalar2_t) two-element pairs and relies on
// TypeTraits<scalar_t> for the paired arithmetic (mul2 / fma2) and for the final reduction of
// the two lanes to a float (add_two_to_float).
template<typename scalar_t>
struct ShortVec {
    using scalar2_t = typename TypeTraits<scalar_t>::scalar2_t;
    scalar_t elements[16 / sizeof(scalar_t)];

  public:
    // Returns the dot product of this packet with `other` as a float. The products are
    // accumulated in scalar2_t and only converted to float at the very end.
    __inline__ __device__ float dot(const ShortVec<scalar_t>& other) const {
        using traits = TypeTraits<scalar_t>;
        constexpr int pair_count = 16 / sizeof(scalar2_t);

        const scalar2_t* lhs = reinterpret_cast<const scalar2_t*>(elements);
        const scalar2_t* rhs = reinterpret_cast<const scalar2_t*>(other.elements);

        // The first pair seeds the accumulator; the remaining pairs are fused multiply-adds.
        scalar2_t acc = traits::mul2(lhs[0], rhs[0]);
#pragma unroll
        for (int p = 1; p < pair_count; ++p) {
            acc = traits::fma2(lhs[p], rhs[p], acc);
        }
        return traits::add_two_to_float(acc);
    }
};

// TypeTraits specialisation for `half`.
//
// Exposes the paired type (`half2`), the 16-byte packet type (`ShortVec<half>`) and thin wrappers
// around the half-precision intrinsics used by the helpers in this file.
template<>
class TypeTraits<half> {
  public:
    using scalar_t = half;
    using scalar2_t = half2;
    using vector_t = ShortVec<scalar_t>;

    // Lane-wise product of two half2 values.
    __inline__ __device__ static scalar2_t mul2(const scalar2_t& a, const scalar2_t& b) {
        return __hmul2(a, b);
    }

    // Lane-wise fused multiply-add: a * b + c.
    __inline__ __device__ static scalar2_t fma2(const scalar2_t& a,
                                                const scalar2_t& b,
                                                const scalar2_t& c) {
        return __hfma2(a, b, c);
    }

    // Sum of the two lanes of `a`, returned as float.
    // Each lane is widened to float before the addition. Adding in half precision first would
    // round the sum to half (and overflow to inf above 65504) even though the result is a float.
    __inline__ __device__ static float add_two_to_float(const scalar2_t& a) {
        return __low2float(a) + __high2float(a);
    }

    // float -> half conversion.
    __inline__ __device__ static scalar_t float_to_scalar(const float& a) {
        return __float2half(a);
    }

    // Broadcasts a float into both lanes of a half2 (round-to-nearest).
    __inline__ __device__ static scalar2_t float_to_scalar2(const float& a) {
        return __float2half2_rn(a);
    }

    // half -> float conversion.
    __inline__ __device__ static float to_float(const half& a) { return __half2float(a); }
};

// TypeTraits specialisation for `float`.
// Provides only the two identity conversions below; it defines no scalar2_t / vector_t and no
// paired arithmetic, so helpers that need those members cannot be instantiated with float.
template<>
class TypeTraits<float> {
  public:
    // float -> float: returns the argument unchanged.
    __inline__ __device__ static float to_float(const float& a) { return a; }
    // float -> scalar (float): returns the argument unchanged.
    __inline__ __device__ static float float_to_scalar(const float& a) { return a; }
};

// Strided dot product of two `length`-element vectors, returned as float.
//
// Template parameters:
//   scalar_t - element type used for the register copies and the paired arithmetic; it must have
//              a TypeTraits specialisation that provides scalar2_t, mul2, fma2,
//              add_two_to_float and to_float.
//   length   - number of elements in each vector.
//   a_stride - distance, in elements, between consecutive elements of A.
//   b_stride - distance, in elements, between consecutive elements of B.
// Parameters:
//   A, B - pointers to the first element of each vector.
//
// Elements are first gathered into registers, then multiplied two at a time in scalar2_t.
// For an even `length` the arithmetic is the same as before this revision; an odd trailing
// element (previously dropped without notice) is now added in float, and length == 1 no longer
// reads a pair that does not exist.
template<typename scalar_t, int length, int a_stride, int b_stride>
__inline__ __device__ float vector_dot(const half* __restrict__ A, const half* __restrict__ B) {
    using traits = TypeTraits<scalar_t>;
    using scalar2_t = typename traits::scalar2_t;
    constexpr int pair_count = length / 2;

    // The arrays are reinterpreted as scalar2_t below, so give them scalar2_t alignment.
    alignas(scalar2_t) scalar_t r_a[length];
    alignas(scalar2_t) scalar_t r_b[length];
#pragma unroll
    for (int i = 0; i < length; i++) {
        r_a[i] = A[i * a_stride];
        r_b[i] = B[i * b_stride];
    }

    float result = 0.0f;
    if constexpr (pair_count > 0) {
        const scalar2_t* r_a2 = reinterpret_cast<const scalar2_t*>(r_a);
        const scalar2_t* r_b2 = reinterpret_cast<const scalar2_t*>(r_b);
        scalar2_t c = traits::mul2(r_a2[0], r_b2[0]);
#pragma unroll
        for (int i = 1; i < pair_count; i++) {
            c = traits::fma2(r_a2[i], r_b2[i], c);
        }
        result = traits::add_two_to_float(c);
    }
    if constexpr (length % 2 != 0) {
        // Unpaired last element.
        result += traits::to_float(r_a[length - 1]) * traits::to_float(r_b[length - 1]);
    }
    return result;
}

// Asynchronously copies one column slab (columns
// [partition_idx * tile_col_size, (partition_idx + 1) * tile_col_size)) of a `row`-row matrix
// from `src` to `dst` with cp.async. Only partitioning along columns is supported.
//
// Template parameters:
//   scalar_t      - element type.
//   row           - number of matrix rows to copy.
//   s_stride      - row stride of `src`, in elements.
//   d_stride      - row stride of `dst`, in elements.
//   tile_col_size - slab width, in elements.
//   thread_num    - number of cooperating threads (threadIdx.x is reduced modulo this).
// Parameters:
//   src, dst      - pointers to element (0, 0) of the source / destination matrix; the slab is
//                   written to the same column range in `dst` that it occupies in `src`.
//   partition_idx - index of the slab to copy.
//
// Work split: with more threads than rows, thread_num / row threads share each row and each
// copies an equal slice of the slab; otherwise each thread handles row / thread_num whole rows.
// A thread's slice is moved in pieces of at most 16 bytes. The copies are only issued here; the
// caller commits and waits for them.
template<typename scalar_t,
         uint32_t row,
         uint32_t s_stride,
         uint32_t d_stride,
         uint32_t tile_col_size,
         uint32_t thread_num>
__inline__ __device__ void load_tile_col_partition(const scalar_t* __restrict__ src,
                                                   scalar_t* __restrict__ dst,
                                                   int partition_idx) {
    constexpr bool more_threads_than_rows = thread_num > row;
    constexpr uint32_t threads_per_row = more_threads_than_rows ? thread_num / row : 1;
    constexpr uint32_t rows_per_thread = more_threads_than_rows ? 1 : row / thread_num;
    constexpr uint32_t elems_per_thread = tile_col_size / threads_per_row;
    constexpr uint32_t bytes_per_thread = elems_per_thread * sizeof(scalar_t);
    constexpr uint32_t bytes_per_copy = bytes_per_thread > 16 ? 16 : bytes_per_thread;
    constexpr uint32_t copies_per_row = bytes_per_thread / bytes_per_copy;

    // Consecutive thread indices map to consecutive rows; the quotient picks the column slice.
    const uint32_t tid = threadIdx.x % thread_num;
    const uint32_t row_in_pass = tid % row;
    const uint32_t col_group = tid / row;

    // First column (in elements) this thread copies; identical for source and destination.
    const uint32_t col_base = partition_idx * tile_col_size + col_group * elems_per_thread;

#pragma unroll
    for (int pass = 0; pass < rows_per_thread; pass++) {
        const uint32_t row_idx = pass * thread_num + row_in_pass;
#pragma unroll
        for (int piece = 0; piece < copies_per_row; piece++) {
            const uint32_t col_idx = col_base + piece * bytes_per_copy / sizeof(scalar_t);
            cp_async_cg_shared_global<bytes_per_copy>(dst + (row_idx * d_stride + col_idx),
                                                      src + (row_idx * s_stride + col_idx));
        }
    }
}

// Asynchronously copies one row slab (rows
// [partition_idx * tile_row_size, (partition_idx + 1) * tile_row_size)) of a `col`-column matrix
// from `src` to `dst` with cp.async.
//
// Template parameters:
//   scalar_t      - element type.
//   col           - number of columns copied from every row.
//   s_stride      - row stride of `src`, in elements.
//   d_stride      - row stride of `dst`, in elements.
//   tile_row_size - slab height, in rows.
//   thread_num    - number of cooperating threads (threadIdx.x is reduced modulo this).
// Parameters:
//   src, dst      - pointers to element (0, 0) of the source / destination matrix; the slab is
//                   written to the same rows in `dst` that it occupies in `src`.
//   partition_idx - index of the slab to copy.
//
// Work split: every thread moves tile_row_size * col / thread_num elements in pieces of at most
// 16 bytes. Within one pass the threads are laid out along a row first (col / elems_per_copy
// threads per row), covering rows_per_pass rows; further passes advance by that many rows.
// The copies are only issued here; the caller commits and waits for them.
template<typename scalar_t,
         uint32_t col,
         uint32_t s_stride,
         uint32_t d_stride,
         uint32_t tile_row_size,
         uint32_t thread_num>
__inline__ __device__ void load_tile_row_partition(const scalar_t* __restrict__ src,
                                                   scalar_t* __restrict__ dst,
                                                   int partition_idx) {
    constexpr uint32_t bytes_per_thread = tile_row_size * col / thread_num * sizeof(scalar_t);
    constexpr uint32_t bytes_per_copy = bytes_per_thread > 16 ? 16 : bytes_per_thread;
    constexpr uint32_t elems_per_copy = bytes_per_copy / sizeof(scalar_t);
    constexpr uint32_t threads_per_row = col / elems_per_copy;
    constexpr uint32_t rows_per_pass = thread_num / threads_per_row;
    constexpr uint32_t pass_count = bytes_per_thread / bytes_per_copy;

    const uint32_t tid = threadIdx.x % thread_num;
    const uint32_t row_in_pass = tid / threads_per_row;
    // Column offset (in elements) of this thread's piece; identical for source and destination.
    const uint32_t col_offset = (tid % threads_per_row) * elems_per_copy;
    const uint32_t slab_first_row = partition_idx * tile_row_size;

#pragma unroll
    for (int pass = 0; pass < pass_count; pass++) {
        const uint32_t row_idx = slab_first_row + pass * rows_per_pass + row_in_pass;
        cp_async_cg_shared_global<bytes_per_copy>(dst + (row_idx * d_stride + col_offset),
                                                  src + (row_idx * s_stride + col_offset));
    }
}

// Compile-time choice of how four warps are arranged over an M x N output tile.
//
// get_n_warp_size() is the number of warps placed along N, get_m_warp_size() the number placed
// along M; their product is always 4.
//   - both M and N >= 64 : 2 x 2
//   - otherwise, M <  N  : 1 x 4 (all warps along N)
//   - otherwise, M == N  : 2 x 2
//   - otherwise, M >  N  : 4 x 1 (all warps along M)
template<int M, int N>
class PartitionCal {
  public:
    static constexpr int get_n_warp_size() {
        return (M >= 64 && N >= 64) ? 2 : (M < N ? 4 : (M == N ? 2 : 1));
    }
    static constexpr int get_m_warp_size() { return 4 / get_n_warp_size(); }
};

// Block-cooperative WMMA GEMM: sC[M x N] = A[M x K] * B[N x K]^T, accumulated in float.
//
// Parameters:
//   gA, gB    - source matrices; rows are read K elements apart (M rows for A, N rows for B).
//   sA, sB    - shared-memory staging buffers filled with cp.async; rows are
//               padding_k = K + 16 / sizeof(scalar_t) elements apart.
//   sC        - output, written row-major.
//   sC_stride - row stride of sC, in floats.
//
// B is staged row by row exactly like A and then consumed as a col_major matrix_b fragment,
// which is what makes the product A * B^T.
//
// The tile loaders are instantiated for 128 threads and PartitionCal arranges 4 warps, so this
// presumably must be called by all threads of a 128-thread block — TODO confirm at call sites.
// It contains __syncthreads(), so it must not be called from divergent code.
// `thread_num`, `warp_num` and `lane_idx` below are declared but not used.
template<typename scalar_t, int M, int N, int K>
__inline__ __device__ void matrix_multiply_gAB_sC(const scalar_t* __restrict__ gA,
                                                  scalar_t* __restrict__ sA,
                                                  const scalar_t* __restrict__ gB,
                                                  scalar_t* __restrict__ sB,
                                                  float* __restrict__ sC,
                                                  uint32_t sC_stride) {
    constexpr int thread_num = 128;
    constexpr int warp_size = 32;
    constexpr int warp_num = 128 / warp_size;
    // Warp grid over the output tile: m_warp_size warps along M, n_warp_size along N.
    constexpr int m_warp_size = PartitionCal<M, N>::get_m_warp_size();
    constexpr int n_warp_size = PartitionCal<M, N>::get_n_warp_size();
    const int thread_idx = threadIdx.x;
    const int warp_idx = thread_idx / warp_size;
    // For the square (2 x 2) grid the M coordinate is the quotient; for the 1 x 4 and 4 x 1
    // grids one of the two coordinates is always 0, so the modulo form is sufficient.
    const int warp_on_m =
      m_warp_size == n_warp_size ? warp_idx / m_warp_size : warp_idx % m_warp_size;
    const int warp_on_n = warp_idx % n_warp_size;
    const int lane_idx = thread_idx % warp_size;

    constexpr int mma_m = 16;
    constexpr int mma_n = 16;
    constexpr int mma_k = 16;
    // Row stride of the staging buffers: K elements plus 16 bytes of padding.
    constexpr int padding_k = K + 16 / sizeof(scalar_t);
    // Output sub-tile owned by one warp and the number of 16x16 fragments it contains.
    constexpr int M_Wrap = M / m_warp_size;
    constexpr int N_Wrap = N / n_warp_size;
    constexpr int m_mma_count = M_Wrap / mma_m;
    constexpr int n_mma_count = N_Wrap / mma_n;
    constexpr int k_mma_count = K / mma_k;
    // Number of K slabs requested ahead of the one being consumed.
    constexpr int pre_load = 1;

    using wmma_row_a = fragment<matrix_a, mma_m, mma_n, mma_k, scalar_t, row_major>;
    ;
    using wmma_row_b = fragment<matrix_b, mma_m, mma_n, mma_k, scalar_t, col_major>;
    using wmma_row_c = fragment<accumulator, mma_m, mma_n, mma_k, float>;

    wmma_row_a a_frag[m_mma_count];
    wmma_row_b b_frag[n_mma_count];
    wmma_row_c c_frag[m_mma_count * n_mma_count];

#pragma unroll
    for (int i = 0; i < m_mma_count * n_mma_count; i++) {
        fill_fragment(c_frag[i], 0);
    }

    // Walk over K in slabs of mma_k columns. Slab k + pre_load is requested while slab k is
    // being consumed; every slab lands in its own column range of sA / sB.
#pragma unroll
    for (int k = 0; k < k_mma_count; k++) {
        // pre-load tile to smem
        if (k == 0) {
#pragma unroll
            for (int i = 0; i < pre_load; i++) {
                load_tile_col_partition<scalar_t, M, K, padding_k, mma_k, 128>(gA, sA, i);
                load_tile_col_partition<scalar_t, N, K, padding_k, mma_k, 128>(gB, sB, i);
                commit_async_cp_group();
            }
        }
        if (k + pre_load < k_mma_count) {
            load_tile_col_partition<scalar_t, M, K, padding_k, mma_k, 128>(gA, sA, k + pre_load);
            load_tile_col_partition<scalar_t, N, K, padding_k, mma_k, 128>(gB, sB, k + pre_load);
            commit_async_cp_group();
        }

        // loading finishes
        // Allow the pre_load newest groups to stay in flight; on the last slab(s) wait for all.
        if (k + pre_load < k_mma_count) {
            wait_async_cp_group<pre_load>();
        } else {
            wait_async_cp_group<0>();
        }
        // Make every thread's completed copies of slab k visible to the whole block.
        __syncthreads();

        // A fragments: rows of this warp's M range, columns of slab k.
#pragma unroll
        for (int i = 0; i < m_mma_count; i++) {
            int a_offset = warp_on_m * M_Wrap * padding_k + i * mma_m * padding_k + k * mma_k;
            load_matrix_sync(a_frag[i], sA + a_offset, padding_k);
        }

        // B fragments: rows of this warp's N range in sB, read as col_major.
#pragma unroll
        for (int i = 0; i < n_mma_count; i++) {
            int b_offset = warp_on_n * N_Wrap * padding_k + i * mma_n * padding_k + k * mma_k;
            load_matrix_sync(b_frag[i], sB + b_offset, padding_k);
        }

        // Accumulate every (i, j) fragment pair of this warp's sub-tile.
#pragma unroll
        for (int i = 0; i < m_mma_count; i++) {
#pragma unroll
            for (int j = 0; j < n_mma_count; j++) {
                mma_sync(
                  c_frag[i * n_mma_count + j], a_frag[i], b_frag[j], c_frag[i * n_mma_count + j]);
            }
        }
    }

    // Write this warp's accumulators to its sub-tile of sC.
    int c_base_offset = warp_on_m * M_Wrap * sC_stride + warp_on_n * N_Wrap;
#pragma unroll
    for (int i = 0; i < m_mma_count; i++) {
#pragma unroll
        for (int j = 0; j < n_mma_count; j++) {
            int c_offset = c_base_offset + i * mma_m * sC_stride + j * mma_n;
            store_matrix_sync(sC + c_offset, c_frag[i * n_mma_count + j], sC_stride, mem_row_major);
        }
    }
    // After this barrier the complete sC tile is visible to every thread of the block.
    __syncthreads();
}

// Block-cooperative WMMA GEMM: gC[M x N] = A[M x K] * B[K x N].
//
// Parameters:
//   sA        - A matrix, already resident in memory and read row-major with stride sA_stride
//               (the name suggests shared memory — TODO confirm at call sites).
//   gB        - B matrix source; rows are read N elements apart.
//   sB        - shared-memory staging buffer for B, filled with cp.async; rows are
//               padding_n = N + 16 / sizeof(scalar_t) elements apart.
//   gC        - output, written row-major.
//   sA_stride - row stride of sA, in elements.
//   sC_stride - row stride used when storing to gC, in elements (despite the `sC` in its name).
//
// Unlike matrix_multiply_gAB_sC the accumulator fragments have element type scalar_t, while gC
// is declared as half*; the store therefore presumably only compiles for scalar_t == half.
// The tile loader is instantiated for 128 threads and PartitionCal arranges 4 warps, so this
// presumably must be called by all threads of a 128-thread block — TODO confirm at call sites.
// It contains __syncthreads(), so it must not be called from divergent code.
template<typename scalar_t, int M, int N, int K>
__inline__ __device__ void matrix_multiply_sA_gBC(const scalar_t* __restrict__ sA,
                                                  const scalar_t* __restrict__ gB,
                                                  scalar_t* __restrict__ sB,
                                                  half* __restrict__ gC,
                                                  uint32_t sA_stride,
                                                  uint32_t sC_stride) {
    constexpr int warp_size = 32;
    // Warp grid over the output tile: m_warp_size warps along M, n_warp_size along N.
    constexpr int m_warp_size = PartitionCal<M, N>::get_m_warp_size();
    constexpr int n_warp_size = PartitionCal<M, N>::get_n_warp_size();
    const int thread_idx = threadIdx.x;
    const int warp_idx = thread_idx / warp_size;
    // For the square (2 x 2) grid the M coordinate is the quotient; for the 1 x 4 and 4 x 1
    // grids one of the two coordinates is always 0, so the modulo form is sufficient.
    const int warp_on_m =
      n_warp_size == m_warp_size ? warp_idx / m_warp_size : warp_idx % m_warp_size;
    const int warp_on_n = warp_idx % n_warp_size;

    constexpr int mma_m = 16;
    constexpr int mma_n = 16;
    constexpr int mma_k = 16;
    // Row stride of the B staging buffer: N elements plus 16 bytes of padding.
    constexpr int padding_n = N + 16 / sizeof(scalar_t);
    // Output sub-tile owned by one warp and the number of 16x16 fragments it contains.
    constexpr int M_Wrap = M / m_warp_size;
    constexpr int N_Wrap = N / n_warp_size;
    constexpr int m_mma_count = M_Wrap / mma_m;
    constexpr int n_mma_count = N_Wrap / mma_n;
    constexpr int k_mma_count = K / mma_k;
    // Number of K slabs requested ahead of the one being consumed.
    constexpr int pre_load = 1;

    using wmma_row_a = fragment<matrix_a, mma_m, mma_n, mma_k, scalar_t, row_major>;
    ;
    using wmma_row_b = fragment<matrix_b, mma_m, mma_n, mma_k, scalar_t, row_major>;
    using wmma_row_c = fragment<accumulator, mma_m, mma_n, mma_k, scalar_t>;
    wmma_row_a a_frag[m_mma_count];
    wmma_row_b b_frag[n_mma_count];
    wmma_row_c c_frag[m_mma_count * n_mma_count];

    for (int i = 0; i < m_mma_count * n_mma_count; i++) {
        fill_fragment(c_frag[i], 0);
    }

    // Walk over K in slabs of mma_k rows of B. Slab k + pre_load is requested while slab k is
    // being consumed; every slab lands in its own row range of sB.
#pragma unroll
    for (int k = 0; k < k_mma_count; k++) {
        // pre-load tile to smem
        if (k == 0) {
#pragma unroll
            for (int i = 0; i < pre_load; i++) {
                load_tile_row_partition<scalar_t, N, N, padding_n, mma_k, 128>(gB, sB, i);
                commit_async_cp_group();
            }
        }
        if (k + pre_load < k_mma_count) {
            load_tile_row_partition<scalar_t, N, N, padding_n, mma_k, 128>(gB, sB, k + pre_load);
            commit_async_cp_group();
        }

        // loading finishes
        // Allow the pre_load newest groups to stay in flight; on the last slab(s) wait for all.
        if (k + pre_load < k_mma_count) {
            wait_async_cp_group<pre_load>();
        } else {
            wait_async_cp_group<0>();
        }
        // Make every thread's completed copies of slab k visible to the whole block.
        __syncthreads();

        // A fragments: rows of this warp's M range, columns of slab k.
#pragma unroll
        for (int i = 0; i < m_mma_count; i++) {
            int a_offset = warp_on_m * M_Wrap * sA_stride + i * mma_m * sA_stride + k * mma_k;
            load_matrix_sync(a_frag[i], sA + a_offset, sA_stride);
        }

        // B fragments: rows of slab k, columns of this warp's N range.
#pragma unroll
        for (int i = 0; i < n_mma_count; i++) {
            int b_offset = warp_on_n * N_Wrap + k * mma_k * padding_n + i * mma_n;
            load_matrix_sync(b_frag[i], sB + b_offset, padding_n);
        }

        // load to registry and compute
#pragma unroll
        for (int i = 0; i < m_mma_count; i++) {
#pragma unroll
            for (int j = 0; j < n_mma_count; j++) {
                mma_sync(
                  c_frag[i * n_mma_count + j], a_frag[i], b_frag[j], c_frag[i * n_mma_count + j]);
            }
        }
    }

    // Write this warp's accumulators to its sub-tile of gC.
    int c_base_offset = warp_on_m * M_Wrap * sC_stride + warp_on_n * N_Wrap;
#pragma unroll
    for (int i = 0; i < m_mma_count; i++) {
#pragma unroll
        for (int j = 0; j < n_mma_count; j++) {
            int c_offset = c_base_offset + i * mma_m * sC_stride + j * mma_n;
            store_matrix_sync(gC + c_offset, c_frag[i * n_mma_count + j], sC_stride, mem_row_major);
        }
    }
    __syncthreads();
}

// Multiplies the first `length` floats at `val` by `scale`, in place, cooperatively across one
// warp: lane L handles elements L, L + 32, L + 64, ... and the warp is re-synchronised with
// __syncwarp() before returning.
//
// Template parameter:
//   length - number of elements to scale.
// Parameters:
//   val   - vector to scale in place.
//   scale - multiplier.
// Returns:
//   `scale`, unchanged. The value carries no extra information; it exists because the function
//   is declared to return float and previously fell off its end without returning anything,
//   which is undefined behaviour.
template<int length>
__inline__ __device__ float warp_vector_scale(float* __restrict__ val, float scale) {
    constexpr int warp_size = 32;
    int lane_id = threadIdx.x % warp_size;
#pragma unroll
    for (int i = lane_id; i < length; i += warp_size) {
        val[i] *= scale;
    }
    __syncwarp();
    return scale;
}

// Sums the first N floats at `val` cooperatively across one warp.
//
// Each lane first adds up the elements at indices lane, lane + 32, lane + 64, ...; the 32
// per-lane partial sums are then combined with an XOR-shuffle butterfly, so every lane returns
// the complete sum. All 32 lanes of the warp must call this together (full shuffle mask).
template<int N>
__inline__ __device__ float warp_vector_sum(const float* __restrict__ val) {
    constexpr int warp_size = 32;
    const int lane = threadIdx.x % warp_size;

    float total = 0;
#pragma unroll
    for (int idx = lane; idx < N; idx += warp_size) {
        total += val[idx];
    }

    // Butterfly reduction: exchange distances 16, 8, 4, 2, 1.
    int distance = warp_size / 2;
    while (distance >= 1) {
        total += __shfl_xor_sync(uint32_t(-1), total, distance);
        distance /= 2;
    }
    return total;
}

// Returns the maximum of the first `vector_length` floats at `val`, computed cooperatively by
// one warp.
//
// Each lane first scans the elements at indices lane, lane + 32, lane + 64, ... starting from
// -FLT_MAX; the 32 per-lane maxima are then combined with an XOR-shuffle butterfly, so every
// lane returns the overall maximum. All 32 lanes of the warp must call this together (full
// shuffle mask).
template<int vector_length>
__inline__ __device__ float warp_vector_max(const float* __restrict__ val) {
    constexpr int warp_size = 32;
    const int lane = threadIdx.x % warp_size;

    float best = -FLT_MAX;
#pragma unroll
    for (int idx = lane; idx < vector_length; idx += warp_size) {
        best = fmaxf(best, val[idx]);
    }

    // Butterfly reduction: exchange distances 16, 8, 4, 2, 1.
    int distance = warp_size / 2;
    while (distance >= 1) {
        const float peer = __shfl_xor_sync(uint32_t(-1), best, distance);
        best = fmaxf(best, peer);
        distance /= 2;
    }
    return best;
}

// Replaces each of the first N floats at `val` with exp(val[i] - max) (fast `__expf`) and
// stores the same value, converted to scalar_t, in `cpy[i]`. The work is split across one warp:
// lane L handles elements L, L + 32, L + 64, ... and the warp is re-synchronised with
// __syncwarp() before returning.
//
// Template parameters:
//   scalar_t - element type of the converted copy.
//   N        - number of elements to process.
// Parameters:
//   val - input values; overwritten with the exponentials.
//   cpy - receives the exponentials converted with TypeTraits<scalar_t>::float_to_scalar.
//   max - value subtracted from every element before exponentiation.
// Returns:
//   `max`, unchanged. The value carries no extra information; it exists because the function is
//   declared to return float and previously fell off its end without returning anything, which
//   is undefined behaviour.
template<typename scalar_t, int N>
__inline__ __device__ float warp_cal_exp(float* __restrict__ val,
                                         scalar_t* __restrict__ cpy,
                                         const float max) {
    constexpr int warp_size = 32;
    int thread_idx = threadIdx.x % warp_size;
    for (int i = thread_idx; i < N; i += warp_size) {
        float exp = __expf(val[i] - max);
        val[i] = exp;
        cpy[i] = TypeTraits<scalar_t>::float_to_scalar(exp);
    }
    __syncwarp();
    return max;
}

// In-place weighted merge of two N-element vectors, done cooperatively by one warp:
//     a[i] = b[i] * scale_b + a[i] * scale_a
//
// Elements are processed two at a time as scalar2_t pairs. Each lane handles (N / 2) / 32 pairs
// at pair indices lane, lane + 32, ...; pairs beyond 32 * ((N / 2) / 32) are not touched. The
// two scales are converted from float to scalar2_t before use, and the warp is re-synchronised
// with __syncwarp() before returning.
template<typename scalar_t, int N>
__inline__ __device__ void warp_vector_merge(scalar_t* a,
                                             scalar_t* b,
                                             float scale_a,
                                             float scale_b) {
    using traits = TypeTraits<scalar_t>;
    using scalar2_t = typename traits::scalar2_t;

    constexpr int warp_size = 32;
    constexpr int pairs_per_lane = (N / 2) / warp_size;
    const int lane = threadIdx.x % warp_size;

    // Broadcast each weight into both lanes of a pair.
    const scalar2_t weight_a = traits::float_to_scalar2(scale_a);
    const scalar2_t weight_b = traits::float_to_scalar2(scale_b);

    scalar2_t* a_pairs = reinterpret_cast<scalar2_t*>(a);
    scalar2_t* b_pairs = reinterpret_cast<scalar2_t*>(b);

#pragma unroll
    for (int step = 0; step < pairs_per_lane; ++step) {
        const int pair_idx = step * warp_size + lane;
        const scalar2_t scaled_a = traits::mul2(a_pairs[pair_idx], weight_a);
        a_pairs[pair_idx] = traits::fma2(b_pairs[pair_idx], weight_b, scaled_a);
    }
    __syncwarp();
}

// Warp-cooperative matrix-vector product: for every row r of the N x K matrix `mat`,
// accumulates dot(vec, mat[r, :]) and finally multiplies the accumulated value by `scale`.
//
// Parameters:
//   vec        - K-element vector, read in 16-byte packets (vector_t).
//   mat        - source matrix; rows are read K elements apart.
//   shared_mat - shared-memory staging buffer filled with cp.async; rows are
//                padding_k = K + 16 / sizeof(scalar_t) elements apart.
//   result     - per-lane array of N / 32 floats; result[j] belongs to matrix row
//                j * 32 + lane. Values are accumulated with `+=`, so the caller has to
//                initialise the array, and the final `scale` is applied to the whole accumulated
//                value, including whatever it contained on entry.
//   scale      - factor applied to every result entry at the end.
//
// Only warp-level synchronisation is used and the loader is instantiated for 32 threads, so each
// calling warp presumably needs its own `shared_mat` — TODO confirm at call sites.
// `vector_size` below is declared but not used.
template<typename scalar_t, int N, int K>
__inline__ __device__ void warp_vect_mul_raw_major_matrix(const scalar_t* vec,
                                                          const scalar_t* mat,
                                                          scalar_t* shared_mat,
                                                          float* result,
                                                          float scale) {
    using vector_t = typename TypeTraits<scalar_t>::vector_t;

    constexpr int warp_size = 32;
    // Width, in elements, of one column slab copied per pipeline step.
    constexpr int load_size = 64;
    constexpr int load_times = K / load_size;
    constexpr int padding_k = K + 16 / sizeof(scalar_t);
    // Number of slabs requested ahead of the one being consumed.
    constexpr int pre_load = 1;
    constexpr int vector_size = sizeof(vector_t) / sizeof(scalar_t);
    // Number of 16-byte packets per slab row.
    constexpr int vector_count = load_size / (16 / sizeof(scalar_t));
    // Matrix rows handled by each lane.
    constexpr int n_pre_thread = N / warp_size;
    const int lane_idx = threadIdx.x % warp_size;

    // Slab k + pre_load is requested while slab k is being consumed; every slab lands in its own
    // column range of shared_mat.
#pragma unroll
    for (int k = 0; k < load_times; k++) {
        // pre-load tile to smem
        if (k == 0) {

#pragma unroll
            for (int i = 0; i < pre_load; i++) {
                load_tile_col_partition<scalar_t, N, K, padding_k, load_size, 32>(
                  mat, shared_mat, i);
                commit_async_cp_group();
            }
        }
        if (k + pre_load < load_times) {
            load_tile_col_partition<scalar_t, N, K, padding_k, load_size, 32>(
              mat, shared_mat, k + pre_load);
            commit_async_cp_group();
        }

        // loading finishes
        // Allow the pre_load newest groups to stay in flight; on the last slab(s) wait for all.
        if (k + pre_load < load_times) {
            wait_async_cp_group<pre_load>();
        } else {
            wait_async_cp_group<0>();
        }

        // Make the completed copies of all lanes visible within the warp.
        __syncwarp();

        // For every packet of slab k: load the vector packet once, then dot it with the matching
        // packet of each matrix row owned by this lane.
#pragma unroll
        for (int i = 0; i < vector_count; i++) {
            vector_t vec_reg;
            vector_t mat_reg;
            vec_reg = reinterpret_cast<const vector_t*>(vec)[k * vector_count + i];
#pragma unroll
            for (int j = 0; j < n_pre_thread; j++) {
                mat_reg = reinterpret_cast<const vector_t*>(
                  shared_mat + (j * warp_size + lane_idx) * padding_k)[k * vector_count + i];
                result[j] += vec_reg.dot(mat_reg);
            }
        }
    }

    // Apply the scale once, after all slabs have been accumulated.
#pragma unroll
    for (int i = 0; i < n_pre_thread; i++) {
        result[i] = result[i] * scale;
    }
    __syncwarp();
}

// Variant of warp_vect_mul_raw_major_matrix that stages the N x K matrix in smaller blocks:
// each pipeline step copies a block of 32 rows by k_load_elems (= 128 bytes worth of) columns.
//
// Parameters:
//   vec        - K-element vector, read in 16-byte packets (vector_t).
//   mat        - source matrix; rows are read K elements apart.
//   shared_mat - shared-memory staging buffer filled with cp.async; rows are
//                padding_k = K + 16 / sizeof(scalar_t) elements apart.
//   result     - per-lane array of N / 32 floats; result[n] belongs to matrix row
//                n * 32 + lane. Values are accumulated with `+=`, so the caller has to
//                initialise the array, and the final `scale` is applied to the whole accumulated
//                value, including whatever it contained on entry.
//   scale      - factor applied to every result entry at the end.
//
// Blocks are visited with the row-block index `n` varying fastest and the column-block index
// `k` slowest (l = k * n_load_times + n).
// Only warp-level synchronisation is used and the loader is instantiated for 32 threads, so each
// calling warp presumably needs its own `shared_mat` — TODO confirm at call sites.
template<typename scalar_t, int N, int K>
__inline__ __device__ void warp_vect_mul_raw_major_matrix_v2(const scalar_t* vec,
                                                          const scalar_t* mat,
                                                          scalar_t* shared_mat,
                                                          float* result,
                                                          float scale) {
    using vector_t = typename TypeTraits<scalar_t>::vector_t;

    constexpr int warp_size = 32;
    // Rows per staged block and number of row blocks.
    constexpr int n_load_elems = warp_size;
    constexpr int n_load_times = N  / warp_size;
    // Bytes / elements per staged block row and number of column blocks.
    constexpr int k_load_size = 128;
    constexpr int k_load_elems = k_load_size / sizeof(scalar_t);
    constexpr int k_load_times = K / k_load_elems;
    constexpr int load_times = n_load_times * k_load_times;
    constexpr int vector_size = sizeof(vector_t) / sizeof(scalar_t);
    // Number of 16-byte packets per staged block row.
    constexpr int vector_count = k_load_elems / vector_size;


    constexpr int padding_k = K + 16 / sizeof(scalar_t);
    // Number of blocks requested ahead of the one being consumed.
    constexpr int pre_load = 1;
    const int lane_idx = threadIdx.x % warp_size;

#pragma unroll
    for (int l = 0; l < load_times; l++) {
        // pre-load tile to smem
        // (k, n) = column block / row block consumed in this step.
        int k = l / n_load_times;
        int n = l % n_load_times;

        if (l == 0) {

#pragma unroll
            for (int i = 0; i < pre_load; i++) {
                load_tile_row_partition<scalar_t, k_load_elems, K, padding_k, n_load_elems, 32>(mat + k * k_load_elems, shared_mat + k * k_load_elems, n);
                commit_async_cp_group();
            }
        }
        // Request the block that will be consumed pre_load steps from now.
        if (l + pre_load < load_times) {
            int k_next = (l + pre_load) / n_load_times;
            int n_next = (l + pre_load) % n_load_times;
            load_tile_row_partition<scalar_t, k_load_elems, K, padding_k, n_load_elems, 32>(mat + k_next * k_load_elems, shared_mat + k_next * k_load_elems, n_next);
            commit_async_cp_group();
        }

        // loading finishes
        // Allow the pre_load newest groups to stay in flight; on the last block(s) wait for all.
        if (l + pre_load < load_times) {
            wait_async_cp_group<pre_load>();
        } else {
            wait_async_cp_group<0>();
        }

        // Make the completed copies of all lanes visible within the warp.
        __syncwarp();

        // Each lane owns one row of the block (row n * 32 + lane) and dots it, packet by packet,
        // with the matching part of the vector.
#pragma unroll
        for (int i = 0; i < vector_count; i++) {
            vector_t vec_reg;
            vector_t mat_reg;
            vec_reg = reinterpret_cast<const vector_t*>(vec + k * k_load_elems)[i];
            mat_reg = reinterpret_cast<const vector_t*>(
              shared_mat + (n * warp_size + lane_idx) * padding_k + k * k_load_elems)[i];
            result[n] += vec_reg.dot(mat_reg);
        }
    }

    // Apply the scale once, after all blocks have been accumulated.
#pragma unroll
    for (int i = 0; i < n_load_times; i++) {
        result[i] = result[i] * scale;
    }
    __syncwarp();
}

// Warp-cooperative vector-matrix product with a K x N matrix, merged into `result`:
//     result[c] = result[c] * scale + sum_k vec[k] * mat[k][c]      for the columns c handled here
//
// Parameters:
//   vec        - K-element vector, read in 16-byte packets (vector_t).
//   mat        - source matrix; rows are read N elements apart.
//   shared_mat - shared-memory staging buffer filled with cp.async; rows are
//                padding_n = N + 16 / sizeof(scalar_t) elements apart.
//   result     - N-element vector shared by the warp; lane L updates columns L, L + 32, ...
//                (N / 32 columns per lane).
//   scale      - factor applied to the previous content of `result` before the new sum is added.
//
// The per-column sums are accumulated in float registers and converted to scalar_t only when
// they are merged into `result`.
// Only warp-level synchronisation is used and the loader is instantiated for 32 threads, so each
// calling warp presumably needs its own `shared_mat` — TODO confirm at call sites.
template<typename scalar_t, int N, int K>
__inline__ __device__ void warp_vect_mul_col_major_matrix(const scalar_t *vec, const scalar_t *mat, scalar_t *shared_mat, scalar_t *result, scalar_t scale)
{
    using vector_t = typename TypeTraits<scalar_t>::vector_t;

    constexpr int warp_size = 32;
    // Number of matrix rows copied per pipeline step.
    constexpr int load_size = 8;
    constexpr int load_times = K / load_size;
    constexpr int padding_n = N + 16 / sizeof(scalar_t);
    // Number of row slabs requested ahead of the one being consumed.
    constexpr int pre_load = 1;
    // Elements per 16-byte packet and packets per slab.
    constexpr int vector_length = sizeof(vector_t) / sizeof(scalar_t);
    constexpr int vector_count = load_size / (16 / sizeof(scalar_t));
    // Output columns handled by each lane.
    constexpr int n_pre_thread = N / warp_size;
    const int lane_idx = threadIdx.x % warp_size;

    float reg_result[n_pre_thread] = {0};

    // Slab k + pre_load is requested while slab k is being consumed; every slab lands in its own
    // row range of shared_mat.
#pragma unroll
    for (int k = 0; k < load_times; k++) {
        // pre-load tile to smem
        if (k == 0) {
#pragma unroll
            for (int i = 0; i < pre_load; i++) {
                load_tile_row_partition<scalar_t, N, N, padding_n, load_size, 32>(
                  mat, shared_mat, i);
                commit_async_cp_group();
            }
        }
        if (k + pre_load < load_times) {
            load_tile_row_partition<scalar_t, N, N, padding_n, load_size, 32>(
              mat, shared_mat, k + pre_load);
            commit_async_cp_group();
        }

        // loading finishes
        // Allow the pre_load newest groups to stay in flight; on the last slab(s) wait for all.
        if (k + pre_load < load_times) {
            wait_async_cp_group<pre_load>();
        } else {
            wait_async_cp_group<0>();
        }

        // Make the completed copies of all lanes visible within the warp.
        __syncwarp();

        // For every packet of the vector that belongs to slab k: gather the matching column
        // segment of the matrix element by element (it is strided in shared_mat) and dot the two.
#pragma unroll
        for (int i = 0; i < vector_count; i++) {
            vector_t vec_reg;
            vector_t mat_reg;
            vec_reg = reinterpret_cast<const vector_t*>(vec)[k * vector_count + i];
#pragma unroll
            for (int j = 0; j < n_pre_thread; j++) {
#pragma unroll
                for(int l = 0; l < vector_length; l++) {
                    mat_reg.elements[l] = shared_mat[(k * load_size + i * vector_length + l) * padding_n + j * warp_size + lane_idx];
                }
                reg_result[j] += vec_reg.dot(mat_reg);
            }
        }
    }
    // Merge the new sums into the scaled previous content of `result`.
#pragma unroll
    for (int i = 0; i < n_pre_thread; i++)
    {
        result[i * warp_size + lane_idx] = result[i * warp_size + lane_idx] * scale + TypeTraits<scalar_t>::float_to_scalar(reg_result[i]);
    }
    __syncwarp();

    //    if (blockIdx.y == 0) {
    //        print_matrix<scalar_t, 1, K>(0, vec, K);
    //        print_matrix<scalar_t, K, N>(0, shared_mat, padding_n);
    //        print_matrix<scalar_t, K, N>(0, mat, N);
    //        print_matrix<scalar_t, 1, N>(0, result, padding_n);
    //    }
}

// Attention over one (head, chunk) pair per thread block. For the selected head and chunk it
//   1. computes the seq_length x chunk_size score matrix Q * K^T (float, in shared memory),
//   2. per query row: scales the scores by dim_scale, takes the row maximum, replaces every
//      score by exp(score - max), and writes the row maximum and the sum of the exponentials to
//      score_max_list[chunk] / score_sum_list[chunk],
//   3. multiplies the (un-normalised) exponentials by V and writes the seq_length x head_dim
//      result to qkv_results[chunk].
//
// Launch layout as read from the code:
//   grid  - blockIdx.x = head index, blockIdx.y = chunk index.
//   block - the helpers called here are written for 128 threads (4 warps); nothing in the kernel
//           checks blockDim — TODO confirm at the launch site.
//   smem  - dynamic shared memory laid out as
//             shared_q     : seq_length * padded_head_dim elements of scalar_t
//             shared_k     : chunk_size * padded_head_dim elements of scalar_t
//             shared_score : floats, row stride padded_chunk_size (presumably seq_length rows,
//                            see the commented-out static declarations below)
//
// Parameters not described by the inline comments:
//   starts, ends   - starts[chunk] is read into `start`, which is not used afterwards; `ends`
//                    is not read at all in this kernel.
//   q_head_stride  - element offset between consecutive heads in `query` and in each
//                    qkv_results buffer.
//   kv_head_stride - element offset between consecutive heads in each key / value buffer.
//   dim_scale      - factor applied to the raw scores before the exponentials are taken.
//
// The padded sizes below use sizeof(half) while the GEMM helpers pad with sizeof(scalar_t); the
// two agree only when scalar_t is a 2-byte type.
template<typename scalar_t, int seq_length, int chunk_size, int head_dim>
__global__ void attn_chunk_first_kernel(
  const scalar_t* __restrict__ query, // [num_head, num_seqs, head_size]
  void** __restrict__ keys,           // chunk_num<[num_head, chunk_size, head_dim]>
  void** __restrict__ values,         // chunk_num<[num_head, chunk_size, head_dim]>
  void** __restrict__ qkv_results,    // chunk_num<[num_head, num_seqs , head_dim]>
  const int* __restrict__ starts,
  const int* __restrict__ ends,
  void** __restrict__ score_max_list, // chunk_num<[num_head, num_seqs]>
  void** __restrict__ score_sum_list, // chunk_num<[num_head, num_seqs]>
  uint32_t q_head_stride,
  uint32_t kv_head_stride,
  float dim_scale) {

    constexpr uint32_t padded_chunk_size = chunk_size + 16 / sizeof(half);
    constexpr uint32_t padded_head_dim = head_dim + 16 / sizeof(half);
    constexpr uint32_t thread_num = 128;
    constexpr uint32_t warp_size = 32;
    constexpr uint32_t wrap_num = thread_num / warp_size;

    const uint32_t head_idx = blockIdx.x;
    const uint32_t chunk_idx = blockIdx.y;

    const uint32_t thread_idx = threadIdx.x;

    const uint32_t warp_idx = thread_idx / warp_size;
    const uint32_t lane_idx = thread_idx % warp_size;

    const int start = starts[chunk_idx];

    // Element offsets of this head inside the per-tensor buffers.
    const uint32_t q_row_offset = head_idx * q_head_stride;
    const uint32_t k_row_offset = head_idx * kv_head_stride;
    const uint32_t v_row_offset = head_idx * kv_head_stride;
    const uint32_t score_max_row_offset = head_idx * seq_length;

    // Per-chunk buffers are looked up through the pointer tables, then advanced to this head.
    const scalar_t* q = query + q_row_offset;
    auto* __restrict__ k = reinterpret_cast<scalar_t*>(keys[chunk_idx]) + k_row_offset;
    auto* __restrict__ v = reinterpret_cast<scalar_t*>(values[chunk_idx]) + v_row_offset;
    auto* __restrict__ qkv_result =
      reinterpret_cast<scalar_t*>(qkv_results[chunk_idx]) + q_row_offset;
    auto* score_max = static_cast<float*>(score_max_list[chunk_idx]) + score_max_row_offset;
    auto* score_sum = static_cast<float*>(score_sum_list[chunk_idx]) + score_max_row_offset;

    extern __shared__ char smem[];
    scalar_t* shared_q = reinterpret_cast<scalar_t*>(smem);
    scalar_t* shared_k = shared_q + seq_length * padded_head_dim;
    float* shared_score = reinterpret_cast<float*>(shared_k + chunk_size * padded_head_dim);
    //    __shared__ scalar_t shared_q[seq_length_ * padded_head_dim];
    //    __shared__ scalar_t shared_k[chunk_size * padded_head_dim];
    //    __shared__ float shared_score[seq_length_ * padded_chunk_size];
    // Buffers reused once their first content is no longer needed: V is staged where K was, and
    // the scalar_t copy of the exponentials is written where Q was. `shared_output` is declared
    // but not used in this kernel.
    // NOTE(review): shared_half_score rows are padded_chunk_size elements apart, so the copy
    // only stays inside the shared_q region when chunk_size <= head_dim; otherwise it would
    // extend into shared_k, which is later overwritten with V — confirm the instantiations.
    scalar_t* shared_v = shared_k;
    scalar_t* shared_half_score = shared_q;
    float* shared_output = reinterpret_cast<float*>(smem);
    // share v with k
    // Step 1: shared_score = Q * K^T (ends with a block-wide barrier).
    matrix_multiply_gAB_sC<scalar_t, seq_length, chunk_size, head_dim>(
      q, shared_q, k, shared_k, shared_score, padded_chunk_size);

    //        print_matrix<half , seq_length, head_dim>("q",5, 4, q, head_dim);
    //        print_matrix<half , chunk_size, head_dim>("k",5, 4, k, head_dim);
    //        print_matrix<float, seq_length, chunk_size>("score",5, 4, shared_score,
    //        padded_chunk_size);

    // compute
    // Step 2: each warp processes query rows warp_idx, warp_idx + 4, ...
#pragma unroll
    for (int i = warp_idx; i < seq_length; i += wrap_num) {
        warp_vector_scale<chunk_size>(shared_score + i * padded_chunk_size, dim_scale);
        float seq_score_max = warp_vector_max<chunk_size>(shared_score + i * padded_chunk_size);
        warp_cal_exp<scalar_t, chunk_size>(shared_score + i * padded_chunk_size,
                                           shared_half_score + i * padded_chunk_size,
                                           seq_score_max);
        float seq_score_sum = warp_vector_sum<chunk_size>(shared_score + i * padded_chunk_size);
        // One lane per warp publishes the row statistics.
        if (lane_idx == 0) {
            score_max[i] = seq_score_max;
            score_sum[i] = seq_score_sum;
        }
    }

    // All rows of shared_half_score must be complete before any warp uses them as the A matrix.
    __syncthreads();

    // Step 3: qkv_result = exp(scores) * V, written with a row stride of head_dim.
    matrix_multiply_sA_gBC<scalar_t, seq_length, head_dim, chunk_size>(
      shared_half_score, v, shared_v, qkv_result, padded_chunk_size, head_dim);
}

// Sequence-first attention kernel: one thread block produces the attention output row for
// one (head, sequence) pair.
//
// Launch layout (as read from the indexing below):
//   - blockIdx.x selects the head, blockIdx.y selects the sequence, gridDim.y is the number of
//     sequences.
//   - The block is assumed to have thread_num (128) threads, i.e. warp_num (4) warps; all
//     strided loops below hard-code that value instead of reading blockDim.
//   - Dynamic shared memory is carved into one region of chunk_size * padding_head_dim
//     scalar_t elements per warp (shared_kv). How the warp_vect_mul_* helpers use that region
//     is not visible here — TODO confirm the exact size requirement against them.
//
// Algorithm:
//   1. The query row is copied to shared memory and each warp's output accumulator row is
//      zeroed.
//   2. Warp w walks chunks w, w + warp_num, ... of this sequence (looked up through
//      seq_mapping). For every chunk it either merges a cached partial result
//      (qkv_results[chunk_idx] != nullptr) or computes scores against the chunk's keys,
//      keeping a running maximum and a running sum of exponentials that are rescaled whenever
//      the maximum changes.
//   3. The per-warp (max, sum, output) partials are combined across warps and the normalized
//      result is written to `output`.
//
// NOTE(review): __float2half is used when storing scores, so this presumably only works for
// scalar_t = half (the only instantiation made by attn_seqs_first).
// NOTE(review): load_vec_length, vector_t and scalar2_t are only referenced by commented-out
// code in this kernel.
template<typename scalar_t, int chunk_size, int head_dim>
__global__ void attn_seq_first_kernel(
  const scalar_t* __restrict__ query, // [num_head, num_seqs, head_dim]
  scalar_t* __restrict__ output,      //  [num_head, num_seqs, head_dim]
  void** __restrict__ keys,           // chunk_num<[num_head, chunk_size, head_dim]>
  void** __restrict__ values,         // chunk_num<[num_head, chunk_size, head_dim]>
  void** __restrict__ qkv_results,    // chunk_num<[num_head, num_seqs, head_dim]>
  void** __restrict__ score_max_list, // chunk_num<[num_head, num_seqs]>
  void** __restrict__ score_sum_list, // chunk_num<[num_head, num_seqs]>
  const int* __restrict__ seqs_chunk_mapping,
  const int* __restrict__ seqs_length,
  uint32_t q_head_stride,
  uint32_t kv_head_stride,
  uint32_t chunk_mapping_stride,
  float dim_scale) {
    constexpr uint32_t thread_num = 128;
    constexpr uint32_t warp_size = 32;
    constexpr uint32_t warp_num = thread_num / warp_size;
    // Number of chunk tokens / head dimensions each lane of a warp is responsible for.
    constexpr uint32_t tokens_per_thread = chunk_size / warp_size;
    constexpr uint32_t dim_per_thread = head_dim / warp_size;
    constexpr uint32_t load_vec_length = 16 / sizeof(half);
    // Row pitch (in elements) of the padded shared-memory rows: head_dim plus 16 bytes.
    constexpr uint32_t padding_head_dim = head_dim + 16 / sizeof(scalar_t);

    static_assert(chunk_size % warp_size == 0, "chunk_size must be divided by warp_size");
    static_assert(head_dim % warp_size == 0, "head_dim must be divided by warp_size");

    const uint32_t head_idx = blockIdx.x;
    const uint32_t seq_idx = blockIdx.y;
    const uint32_t seq_num = gridDim.y;

    const uint32_t thread_idx = threadIdx.x;
    const uint32_t wrap_idx = thread_idx / warp_size;
    const uint32_t lane_idx = thread_idx % warp_size;

    // Number of chunks covering this sequence (ceil division).
    const uint32_t seq_length = seqs_length[seq_idx];
    const uint32_t chunk_num = (seq_length + chunk_size - 1) / chunk_size;

    // Element offsets of this block's row inside the query/output tensors, the per-chunk K/V
    // tensors, the per-chunk [num_head, num_seqs] statistics and the chunk mapping table.
    const uint32_t q_row_offset = head_idx * q_head_stride + seq_idx * head_dim;
    const uint32_t kv_row_offset = head_idx * kv_head_stride;
    const uint32_t score_max_offset = head_idx * seq_num + seq_idx;
    const uint32_t seq_mapping_seq_offset = seq_idx * chunk_mapping_stride;

    const scalar_t* q = query + q_row_offset;
    scalar_t* output_seq = output + q_row_offset;
    const int* seq_mapping = seqs_chunk_mapping + seq_mapping_seq_offset;
    using vector_t = typename TypeTraits<scalar_t>::vector_t;
    using scalar2_t = typename TypeTraits<scalar_t>::scalar2_t;

    // Statically sized shared buffers: the query row, one padded output accumulator row per
    // warp, one score row per warp and the per-warp softmax statistics.
    __shared__ scalar_t shared_q[head_dim];
    __shared__ scalar_t shared_output[warp_num * padding_head_dim];
    __shared__ scalar_t shared_score[warp_num * chunk_size];
    __shared__ float shared_score_max[warp_num];
    __shared__ float shared_score_sum[warp_num];
    extern __shared__ char smem[];
    // Per-warp slice of the dynamic shared memory, handed to the warp_vect_mul_* helpers.
    scalar_t* shared_kv =
      reinterpret_cast<scalar_t*>(smem) + wrap_idx * chunk_size * padding_head_dim;
    // load shared q
#pragma unroll
    for (int i = thread_idx; i < head_dim; i += thread_num) {
        shared_q[i] = q[i];
    }
    // Zero this warp's output accumulator row; each lane clears dim_per_thread elements.
#pragma unroll
    for (int i = 0; i < dim_per_thread; i++) {
        shared_output[wrap_idx * padding_head_dim + lane_idx + i * warp_size] = 0;
    }
    // Make shared_q (written by all threads) visible before any warp reads it.
    __syncthreads();

    // each warp compute one chunk
    // warp 0 -> chunk: 0 4 8 12...
    // warp 1 -> chunk: 1 5 9 13...
    // Running softmax statistics of this warp over the chunks it has processed so far.
    float score_max = -FLT_MAX;
    float score_sum = 0;
    // The loop bounds depend only on wrap_idx, so all lanes of a warp iterate together; the
    // full-mask __shfl_sync calls inside rely on that.
    for (int i = wrap_idx; i < chunk_num; i += warp_num) {
        const int chunk_idx = seq_mapping[i];
        scalar_t* shared_output_chunk = shared_output + wrap_idx * padding_head_dim;
        scalar_t* shared_chunk_score = shared_score + wrap_idx * chunk_size;

        // merge existing result
        // A non-null qkv_results entry means a partial result is already cached for this
        // chunk; it is read at this block's (head, sequence) position and folded into the
        // running statistics instead of recomputing the chunk.
        if (qkv_results[chunk_idx] != nullptr) {
            float cached_max = static_cast<float*>(score_max_list[chunk_idx])[score_max_offset];
            float cached_score_sum =
              static_cast<float*>(score_sum_list[chunk_idx])[score_max_offset];
            scalar_t* cached_qkv_result =
              static_cast<scalar_t*>(qkv_results[chunk_idx]) + q_row_offset;

            float new_score_max = fmax(score_max, cached_max);
            // Both rescale factors are evaluated on lane 0 only and broadcast to the warp.
            //            float cached_scale = expf(cached_max - new_score_max);
            float cached_scale =
              __shfl_sync(0xffffffff, lane_idx == 0 ? expf(cached_max - new_score_max) : 0, 0);
            //            float scale = expf(score_max_ - new_score_max);
            float scale =
              __shfl_sync(0xffffffff, lane_idx == 0 ? expf(score_max - new_score_max) : 0, 0);
            score_max = new_score_max;
            score_sum = cached_score_sum * cached_scale + score_sum * scale;
            //            print_matrix<scalar_t, 1, head_dim>(0, cached_qkv_result, head_dim);
            warp_vector_merge<scalar_t, head_dim>(
              shared_output_chunk, cached_qkv_result, scale, cached_scale);
            //            print_matrix<scalar_t, 1, head_dim>(0, shared_output_chunk, head_dim);
            continue;
        }

        const scalar_t* __restrict__ g_k =
          reinterpret_cast<scalar_t*>(keys[chunk_idx]) + kv_row_offset; // [chunk_size, head_dim]
        const scalar_t* __restrict__ g_v =
          reinterpret_cast<scalar_t*>(values[chunk_idx]) + kv_row_offset; // [chunk_size, head_dim]

        // Per-lane scores: entry j belongs to chunk token j * warp_size + lane_idx (see the
        // store into shared_chunk_score below).
        float chunk_score[tokens_per_thread] = { 0 };
        // #pragma unroll
        //         for (int k = 0; k < head_dim / load_vec_length; k += 1) {
        ////            k = (k + lane_idx) & ((head_dim / load_vec_length) - 1);
        //            vector_t q_vec = reinterpret_cast<const vector_t*>(shared_q)[k];
        //            // compute two token each thread
        // #pragma unroll
        //            for (int j = 0; j < tokens_per_thread; j++) {
        //                chunk_score[j] += q_vec.dot(reinterpret_cast<const vector_t*>(
        //                                    g_k + (j * warp_size + lane_idx) * head_dim)[k]) *
        //                                  dim_scale;
        //            }
        //        }
//        warp_vect_mul_raw_major_matrix<scalar_t, chunk_size, head_dim>(
//          shared_q, g_k, shared_kv, chunk_score, dim_scale);
        // Presumably fills chunk_score with the scaled q.k products for this chunk, using
        // shared_kv as staging space — helper not visible here, TODO confirm.
        warp_vect_mul_raw_major_matrix_v2<scalar_t, chunk_size, head_dim>(
          shared_q, g_k, shared_kv, chunk_score, dim_scale);


        // Start from the running maximum so the result is the maximum over all chunks seen
        // so far, including this one.
        float chunk_score_max = score_max;
#pragma unroll
        for (int j = 0; j < tokens_per_thread; j++) {
            chunk_score_max = fmaxf(chunk_score_max, chunk_score[j]);
        }
        //        __syncwarp();

        // warp reduce max
#pragma unroll
        for (int mask = warp_size / 2; mask >= 1; mask /= 2) {
            chunk_score_max =
              fmaxf(chunk_score_max, __shfl_xor_sync(uint32_t(-1), chunk_score_max, mask));
        }

        // compute score and store to smem
        float chunk_score_sum = 0;
#pragma unroll
        for (int j = 0; j < tokens_per_thread; j++) {
            chunk_score[j] = expf(chunk_score[j] - chunk_score_max);
            // NOTE(review): __float2half is used regardless of scalar_t.
            shared_chunk_score[j * warp_size + lane_idx] = __float2half(chunk_score[j]);
            chunk_score_sum += chunk_score[j];
        }
        // Warp-level barrier after the lanes have written their scores to shared memory.
        __syncwarp();

        // warp reduce sum
#pragma unroll
        for (int mask = warp_size / 2; mask >= 1; mask /= 2) {
            chunk_score_sum += __shfl_xor_sync(uint32_t(-1), chunk_score_sum, mask);
        }

        // Factor that rescales everything accumulated under the previous maximum; evaluated
        // on lane 0 and broadcast.
        float score_scale =
          __shfl_sync(0xffffffff, lane_idx == 0 ? expf(score_max - chunk_score_max) : 0, 0);
        score_max = chunk_score_max;
        score_sum = score_sum * score_scale + chunk_score_sum;
        // compute v
        // #pragma unroll
        //         for (int j = 0; j < dim_per_thread; j++) {
        //             float chunk_output_dim = 0;
        // #pragma unroll
        //             for (int k = 0; k < chunk_size; k += load_vec_length) {
        //                 chunk_output_dim += vector_dot<half, load_vec_length, 1, head_dim>(
        //                   shared_chunk_score + k, g_v + k * head_dim + j * warp_size + lane_idx);
        //             }
        //             shared_output_chunk[j * warp_size + lane_idx] =
        //               __float2half(chunk_output_dim) +
        //               __float2half(score_scale) * shared_output_chunk[j * warp_size + lane_idx];
        //         }
        // Presumably rescales the accumulator row by score_scale and adds scores x V, as the
        // commented-out reference loop above does — helper not visible here, TODO confirm.
        warp_vect_mul_col_major_matrix<scalar_t, head_dim, chunk_size>(
          shared_chunk_score, g_v, shared_kv, shared_output_chunk, score_scale);
        //        __syncwarp();
    }
    // Publish this warp's statistics so they can be combined across warps.
    if (lane_idx == 0) {
        shared_score_max[wrap_idx] = score_max;
        shared_score_sum[wrap_idx] = score_sum;
    }
    // Block barrier: the statistics and every warp's shared_output row are read by other
    // warps below.
    __syncthreads();

    // Per-warp rescale factors and the reciprocal of the combined sum. Only lane 0 fills
    // them in; the other lanes receive the values through the shuffles further down.
    float scale[warp_num];
    float div = 0;

    // only one thread compute the scale in the warp
    // (lane 0 of every warp performs the same computation)
    if (lane_idx == 0) {
        score_max = shared_score_max[0];
        score_sum = 0;
#pragma unroll
        for (int i = 1; i < warp_num; i++) {
            score_max = fmaxf(score_max, shared_score_max[i]);
        }
#pragma unroll
        for (int i = 0; i < warp_num; i++) {
            scale[i] = expf(shared_score_max[i] - score_max);
            score_sum += shared_score_sum[i] * scale[i];
        }
        // The 1e-6f term keeps the divisor non-zero when score_sum is 0.
        div = __fdividef(1.f, score_sum + 1e-6f);
    }
    __syncwarp();

    // sync with other threads
#pragma unroll
    for (int i = 0; i < warp_num; i++) {
        scale[i] = __shfl_sync(0xffffffff, scale[i], 0);
    }
    div = __shfl_sync(0xffffffff, div, 0);

    // compute output
    // Each thread combines the warp_num partial accumulators for its output dimensions,
    // weights them by the per-warp scale and normalizes by the combined sum.
#pragma unroll
    for (int i = thread_idx; i < head_dim; i += thread_num) {
        float output_dim = 0;
#pragma unroll
        for (int j = 0; j < warp_num; j++) {
            output_dim +=
              TypeTraits<scalar_t>::to_float(shared_output[j * padding_head_dim + i]) * scale[j];
        }
        output_seq[i] = TypeTraits<scalar_t>::float_to_scalar(output_dim * div);
    }
}

/**
 * Launches attn_chunk_first_kernel over a (num_heads, chunk_num) grid with 128 threads per
 * block, where chunk_num is the number of entries in `values`.
 *
 * The per-chunk tensor pointers are first uploaded to the device pointer tables via
 * construct_tensor_ptr_list(); `start` / `end` are passed to the kernel as raw int pointers.
 * Only half-precision queries with head_dim == 128 and one of the explicitly listed
 * (num_seqs, chunk_size) combinations are supported.
 *
 * @param query       [num_head, num_seqs, head_dim]
 * @param keys        chunk_num x [num_head, chunk_size, head_dim]
 * @param values      chunk_num x [num_head, chunk_size, head_dim]
 * @param qkv_results chunk_num result tensors (kernel output — presumably written per chunk)
 * @param start,end   int32 tensors handed to the kernel as-is
 * @param score_max   chunk_num per-chunk score maxima
 * @param score_sum   chunk_num per-chunk score sums
 * @param trace       optional; when trace->record_kernel_t is set the kernel is timed with
 *                    device-wide synchronization and the duration (microseconds) is stored in
 *                    trace->chunk_kernel_t
 *
 * @throws c10::Error on violated preconditions, unsupported dtypes other than
 *         float/bfloat16, or CUDA failures; std::runtime_error for float/bfloat16 queries and
 *         for unsupported (num_seqs, chunk_size) combinations.
 */
__host__ void GPUKernel::attn_chunks_first(
  torch::Tensor& query,                    // [num_head, num_seqs, head_dim]
  std::vector<torch::Tensor>& keys,        // chunk_num<[num_head, chunk_size, head_dim]>
  std::vector<torch::Tensor>& values,      // chunk_num<[num_head, chunk_size, head_dim]>
  std::vector<torch::Tensor>& qkv_results, // chunk_num<[num_seqs,num_head, head_dim]>
  torch::Tensor& start,
  torch::Tensor& end,
  std::vector<torch::Tensor>& score_max,
  std::vector<torch::Tensor>& score_sum,
  Trace* trace) {
    // keys[0] is dereferenced below and all five lists are indexed in lock-step when the
    // pointer tables are built, so reject empty or mismatched lists up front.
    TORCH_CHECK(!keys.empty() && !values.empty() && !qkv_results.empty(),
                "keys, values and qkv_results must not be empty");
    TORCH_CHECK(keys.size() == values.size() && keys.size() == qkv_results.size(),
                "keys, values and qkv_results must have same size");
    TORCH_CHECK(score_max.size() == keys.size() && score_sum.size() == keys.size(),
                "score_max and score_sum must have the same size as keys");
    TORCH_CHECK(query.dim() == 3, "query must have shape [num_head, num_seqs, head_dim]");

    uint32_t num_heads = query.size(0);
    uint32_t num_seqs = query.size(1);
    uint32_t head_dim = query.size(2);
    uint32_t chunk_num = values.size();

    // Every kernel instantiation below hard-codes head_dim = 128.
    TORCH_CHECK(head_dim == 128, "head dim must be 128");

    uint32_t chunk_size = keys[0].size(1);
    uint32_t q_head_stride = query.stride(0);
    uint32_t kv_head_stride = keys[0].stride(0);

    float scale = 1.f / sqrtf(static_cast<float>(head_dim));

    // Dynamic shared memory: padded query tile + padded key tile + float score tile.
    size_t query_shared_mem_size = query.nbytes() / num_heads + 16 * num_seqs; // 16 is for padding
    size_t keys_shared_mem_size = keys[0].nbytes() / num_heads + 16 * chunk_size;
    size_t score_shared_mem_size = (chunk_size + 16) * num_seqs * sizeof(float);
    size_t shared_mem_size = query_shared_mem_size + keys_shared_mem_size + score_shared_mem_size;

    dim3 grid(num_heads, chunk_num);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    construct_tensor_ptr_list(keys, values, qkv_results, score_max, score_sum);

    // score_sum_ holds max_batch_size_ * head_num_ floats; never clear past its end.
    const size_t score_sum_elems = static_cast<size_t>(num_heads) * num_seqs;
    TORCH_CHECK(score_sum_elems <= static_cast<size_t>(max_batch_size_) * head_num_,
                "num_heads * num_seqs exceeds the preallocated score buffer");
    C10_CUDA_CHECK(cudaMemset(score_sum_, 0, sizeof(float) * score_sum_elems));

    // Opts in to more than 48 KiB of dynamic shared memory when needed, then launches the
    // kernel instantiation for the given compile-time sizes.
#define CALL_ATTN_CHUNK_KERNEL_FUNCTION(scalar_t, batch_size, chunk_size, head_dim)                \
    {                                                                                              \
        if (shared_mem_size >= 48 * 1024) {                                                        \
            C10_CUDA_CHECK(cudaFuncSetAttribute(                                                   \
              attn_chunk_first_kernel<scalar_t, batch_size, chunk_size, head_dim>,                 \
              cudaFuncAttributeMaxDynamicSharedMemorySize,                                         \
              shared_mem_size));                                                                   \
        }                                                                                          \
        attn_chunk_first_kernel<scalar_t, batch_size, chunk_size, head_dim>                        \
          <<<grid, 128, shared_mem_size, stream>>>(query_ptr,                                      \
                                                   keys_ptr,                                       \
                                                   values_ptr,                                     \
                                                   qkv_results_ptr,                                \
                                                   start_ptr,                                      \
                                                   end_ptr,                                        \
                                                   score_max_ptr,                                  \
                                                   score_sum_ptr,                                  \
                                                   q_head_stride,                                  \
                                                   kv_head_stride,                                 \
                                                   scale);                                         \
    }

    const bool record_time = trace != nullptr && trace->record_kernel_t;
    std::chrono::high_resolution_clock::time_point start_t, end_t;
    if (record_time) {
        // Drain pending work so the measurement covers only this kernel.
        C10_CUDA_CHECK(cudaDeviceSynchronize());
        start_t = std::chrono::high_resolution_clock::now();
    }
    if (query.dtype() == at::ScalarType::Float) {
        throw std::runtime_error("Unsupported data type: float");
    } else if (query.dtype() == at::ScalarType::Half) {
        auto query_ptr = reinterpret_cast<half*>(query.data_ptr());
        auto keys_ptr = this->key_list_;
        auto values_ptr = this->value_list_;
        auto qkv_results_ptr = this->qkv_result_list;
        auto start_ptr = reinterpret_cast<int*>(start.data_ptr());
        auto end_ptr = reinterpret_cast<int*>(end.data_ptr());
        auto score_max_ptr = this->score_max_list;
        auto score_sum_ptr = this->score_sum_list;

        if (num_seqs == 16 && chunk_size == 64)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 16, 64, 128)
        else if (num_seqs == 32 && chunk_size == 32)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 32, 32, 128)
        else if (num_seqs == 32 && chunk_size == 64)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 32, 64, 128)
        else if (num_seqs == 48 && chunk_size == 64)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 48, 64, 128)
        else if (num_seqs == 64 && chunk_size == 64)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 64, 64, 128)
        else if (num_seqs == 96 && chunk_size == 64)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 96, 64, 128)
        else if (num_seqs == 128 && chunk_size == 64)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 128, 64, 128)
        else if (num_seqs == 64 && chunk_size == 32)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 64, 32, 128)
        else if (num_seqs == 64 && chunk_size == 128)
            CALL_ATTN_CHUNK_KERNEL_FUNCTION(half, 64, 128, 128)
        else
            throw std::runtime_error("Unsupported chunk_size and batch_size");
    } else if (query.dtype() == at::ScalarType::BFloat16) {
        throw std::runtime_error("Unsupported data type: bfloat16");
    } else {
        TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
    // Surface launch-configuration errors immediately after the launch.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    if (record_time) {
        C10_CUDA_CHECK(cudaDeviceSynchronize());
        end_t = std::chrono::high_resolution_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_t - start_t);
        trace->chunk_kernel_t = duration.count();
    }
}

/**
 * Launches attn_seq_first_kernel over a (num_heads, num_seqs) grid with 128 threads per
 * block; each block produces one output row for one (head, sequence) pair.
 *
 * The per-chunk tensor pointers and the per-sequence chunk mapping / lengths are uploaded to
 * the device first (construct_tensor_ptr_list, copy_seq_chunk_mapping). Undefined entries of
 * `qkv_results` become null pointers, which the kernel treats as "no cached result".
 * Only half-precision queries with head_dim == 128 and chunk_size in {32, 64, 128} are
 * supported.
 *
 * @param query             [num_head, num_seqs, head_dim]
 * @param output            [num_head, num_seqs, head_dim], written by the kernel
 * @param keys, values      chunk_num x [num_head, chunk_size, head_dim]
 * @param qkv_results       chunk_num x [num_head, num_seqs, head_dim] cached partial results
 * @param score_max/sum     chunk_num x [num_head, num_seqs] cached softmax statistics
 * @param seq_chunk_mapping per sequence, the indices of the chunks it attends to
 * @param seq_length        per sequence, its length in tokens
 * @param trace             optional; when trace->record_kernel_t is set the kernel is timed
 *                          with device-wide synchronization and the duration (microseconds)
 *                          is stored in trace->seq_kernel_t
 *
 * @throws c10::Error on violated preconditions, unsupported dtypes other than
 *         float/bfloat16, or CUDA failures; std::runtime_error for float/bfloat16 queries and
 *         for unsupported chunk sizes.
 */
void GPUKernel::attn_seqs_first(
  const torch::Tensor& query,                    // [num_head, num_seqs, head_dim]
  torch::Tensor& output,                         // [num_head, num_seqs, head_dim]
  const std::vector<torch::Tensor>& keys,        // chunk_num<[num_head, chunk_size, head_dim]>
  const std::vector<torch::Tensor>& values,      // chunk_num<[num_head, chunk_size, head_dim]>
  const std::vector<torch::Tensor>& qkv_results, // chunk_num<[num_head, num_seqs, head_dim]>
  const std::vector<torch::Tensor>& score_max,   // chunk_num<[num_head, num_seqs]>
  const std::vector<torch::Tensor>& score_sum,   // chunk_num<[num_head, num_seqs]>
  const std::vector<std::vector<int>>& seq_chunk_mapping,
  const std::vector<int>& seq_length,
  Trace* trace) {
    // keys[0] is dereferenced below and all five lists are indexed in lock-step when the
    // pointer tables are built, so reject empty or mismatched lists up front.
    TORCH_CHECK(!keys.empty() && !values.empty() && !qkv_results.empty(),
                "keys, values and qkv_results must not be empty");
    TORCH_CHECK(keys.size() == values.size() && keys.size() == qkv_results.size(),
                "keys, values and qkv_results must have same size");
    TORCH_CHECK(score_max.size() == keys.size() && score_sum.size() == keys.size(),
                "score_max and score_sum must have the same size as keys");
    TORCH_CHECK(query.dim() == 3, "query must have shape [num_head, num_seqs, head_dim]");

    uint32_t num_heads = query.size(0);
    uint32_t num_seqs = query.size(1);
    uint32_t head_dim = query.size(2);

    // Every kernel instantiation below hard-codes head_dim = 128.
    TORCH_CHECK(head_dim == 128, "head dim must be 128");

    // The kernel reads one length and one mapping row per grid row (sequence); the staging
    // buffers that back them hold at most max_batch_size_ sequences.
    TORCH_CHECK(seq_length.size() >= num_seqs && seq_chunk_mapping.size() >= num_seqs,
                "seq_length and seq_chunk_mapping must cover every sequence of query");
    TORCH_CHECK(seq_length.size() <= static_cast<size_t>(max_batch_size_) &&
                  seq_chunk_mapping.size() <= static_cast<size_t>(max_batch_size_),
                "number of sequences exceeds max_batch_size");

    uint32_t chunk_size = keys[0].size(1);
    uint32_t q_head_stride = query.stride(0);
    uint32_t kv_head_stride = keys[0].stride(0);

    float dim_scale = 1.f / sqrtf(static_cast<float>(head_dim));

    // One padded K/V staging tile per warp: chunk_size rows of head_dim halves plus 16 bytes
    // of padding per row, times the 4 warps of a 128-thread block.
    constexpr int warps_per_block = 4;
    int shared_mem_size =
      warps_per_block * (head_dim * chunk_size * sizeof(half) + chunk_size * 16);
    dim3 grid(num_heads, num_seqs);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    construct_tensor_ptr_list(keys, values, qkv_results, score_max, score_sum);
    copy_seq_chunk_mapping(seq_chunk_mapping, seq_length);

    // Opts in to more than 48 KiB of dynamic shared memory when needed, then launches the
    // kernel instantiation for the given compile-time sizes. max_chunk_num_ is the row pitch
    // of the uploaded chunk mapping table.
#define CALL_ATTN_SEQ_KERNEL_FUNCTION(scalar_t, chunk_size, head_dim)                              \
    {                                                                                              \
        if (shared_mem_size >= 48 * 1024) {                                                        \
            C10_CUDA_CHECK(                                                                        \
              cudaFuncSetAttribute(attn_seq_first_kernel<scalar_t, chunk_size, head_dim>,          \
                                   cudaFuncAttributeMaxDynamicSharedMemorySize,                    \
                                   shared_mem_size));                                              \
        }                                                                                          \
        attn_seq_first_kernel<scalar_t, chunk_size, head_dim>                                      \
          <<<grid, 128, shared_mem_size, stream>>>(query_ptr,                                      \
                                                   output_ptr,                                     \
                                                   keys_ptr,                                       \
                                                   values_ptr,                                     \
                                                   qkv_results_ptr,                                \
                                                   score_max_ptr,                                  \
                                                   score_sum_ptr,                                  \
                                                   seq_chunk_mapping_ptr,                          \
                                                   seq_length_ptr,                                 \
                                                   q_head_stride,                                  \
                                                   kv_head_stride,                                 \
                                                   max_chunk_num_,                                 \
                                                   dim_scale);                                     \
    }

    const bool record_time = trace != nullptr && trace->record_kernel_t;
    std::chrono::high_resolution_clock::time_point start_t, end_t;
    if (record_time) {
        // Drain pending work so the measurement covers only this kernel.
        C10_CUDA_CHECK(cudaDeviceSynchronize());
        start_t = std::chrono::high_resolution_clock::now();
    }
    if (query.dtype() == at::ScalarType::Float) {
        throw std::runtime_error("Unsupported data type: float");
    } else if (query.dtype() == at::ScalarType::Half) {
        auto query_ptr = reinterpret_cast<half*>(query.data_ptr());
        auto output_ptr = reinterpret_cast<half*>(output.data_ptr());
        auto keys_ptr = this->key_list_;
        auto values_ptr = this->value_list_;
        auto qkv_results_ptr = this->qkv_result_list;
        auto score_max_ptr = this->score_max_list;
        auto score_sum_ptr = this->score_sum_list;
        auto seq_chunk_mapping_ptr = reinterpret_cast<int*>(seq_chunk_mapping_);
        auto seq_length_ptr = reinterpret_cast<int*>(seq_length_);
        if (chunk_size == 32)
            CALL_ATTN_SEQ_KERNEL_FUNCTION(half, 32, 128)
        else if (chunk_size == 64)
            CALL_ATTN_SEQ_KERNEL_FUNCTION(half, 64, 128)
        else if (chunk_size == 128)
            CALL_ATTN_SEQ_KERNEL_FUNCTION(half, 128, 128)
        else
            throw std::runtime_error("Unsupported chunk_size");

    } else if (query.dtype() == at::ScalarType::BFloat16) {
        throw std::runtime_error("Unsupported data type: bfloat16");
    } else {
        TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
    // Surface launch-configuration errors immediately after the launch.
    C10_CUDA_KERNEL_LAUNCH_CHECK();

    if (record_time) {
        C10_CUDA_CHECK(cudaDeviceSynchronize());
        end_t = std::chrono::high_resolution_clock::now();
        auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end_t - start_t);
        trace->seq_kernel_t = duration.count();
    }
}
// Releases the host staging buffers and every device allocation made by the constructor.
// cudaFree results are deliberately ignored: a destructor must not throw, and cudaFree on a
// null pointer is a no-op.
GPUKernel::~GPUKernel() {
    // The host buffers are allocated with new[] in the constructor, so they must be released
    // with delete[] (releasing them with free() is undefined behaviour).
    delete[] seq_chunk_mapping_cpu_;
    delete[] seq_length_cpu_;
    cudaFree(key_list_);
    cudaFree(value_list_);
    cudaFree(qkv_result_list);
    cudaFree(score_sum_list);
    cudaFree(score_max_list);
    cudaFree(score_sum_);
    cudaFree(score_max_);
    cudaFree(seq_chunk_mapping_);
    cudaFree(seq_length_);
}
// Allocates the host staging buffers and the device-side tables used by the attention
// launchers:
//   - five pointer tables with max_chunk_num entries each (keys, values, cached results,
//     score sums, score maxima),
//   - two float buffers of max_batch_size * head_num elements,
//   - the int chunk-mapping table (max_batch_size rows of max_chunk_num entries) and the int
//     sequence-length buffer (max_batch_size entries), each with a host mirror.
// Throws if any allocation fails; everything allocated up to that point is released first,
// because the destructor does not run for a partially constructed object.
GPUKernel::GPUKernel(int head_num, int max_chunk_num, int max_batch_size)
  : head_num_(head_num)
  , max_chunk_num_(max_chunk_num)
  , max_batch_size_(max_batch_size) {
    // Null everything first so the cleanup path below can release unconditionally.
    key_list_ = nullptr;
    value_list_ = nullptr;
    qkv_result_list = nullptr;
    score_sum_list = nullptr;
    score_max_list = nullptr;
    score_max_ = nullptr;
    score_sum_ = nullptr;
    seq_chunk_mapping_ = nullptr;
    seq_length_ = nullptr;
    seq_chunk_mapping_cpu_ = nullptr;
    seq_length_cpu_ = nullptr;

    const size_t chunk_slots = static_cast<size_t>(max_chunk_num_);
    const size_t batch_slots = static_cast<size_t>(max_batch_size_);
    const size_t head_slots = static_cast<size_t>(head_num_);

    try {
        seq_chunk_mapping_cpu_ = new int[batch_slots * chunk_slots];
        seq_length_cpu_ = new int[batch_slots];

        // Each table is allocated exactly once (score_sum_list used to be allocated twice,
        // leaking the first allocation).
        C10_CUDA_CHECK(cudaMalloc(&key_list_, sizeof(void*) * chunk_slots));
        C10_CUDA_CHECK(cudaMalloc(&value_list_, sizeof(void*) * chunk_slots));
        C10_CUDA_CHECK(cudaMalloc(&qkv_result_list, sizeof(void*) * chunk_slots));
        C10_CUDA_CHECK(cudaMalloc(&score_sum_list, sizeof(void*) * chunk_slots));
        C10_CUDA_CHECK(cudaMalloc(&score_max_list, sizeof(void*) * chunk_slots));
        C10_CUDA_CHECK(cudaMalloc(&score_max_, sizeof(float) * batch_slots * head_slots));
        C10_CUDA_CHECK(cudaMalloc(&score_sum_, sizeof(float) * batch_slots * head_slots));
        // These two buffers hold ints (they mirror the int host buffers above), so they are
        // sized with sizeof(int) rather than sizeof(void*).
        C10_CUDA_CHECK(cudaMalloc(&seq_chunk_mapping_, sizeof(int) * batch_slots * chunk_slots));
        C10_CUDA_CHECK(cudaMalloc(&seq_length_, sizeof(int) * batch_slots));
    } catch (...) {
        delete[] seq_chunk_mapping_cpu_;
        delete[] seq_length_cpu_;
        cudaFree(key_list_);
        cudaFree(value_list_);
        cudaFree(qkv_result_list);
        cudaFree(score_sum_list);
        cudaFree(score_max_list);
        cudaFree(score_sum_);
        cudaFree(score_max_);
        cudaFree(seq_chunk_mapping_);
        cudaFree(seq_length_);
        throw;
    }
}

/**
 * Uploads the raw data pointers of the per-chunk tensors into the device-side pointer tables
 * (key_list_, value_list_, qkv_result_list, score_sum_list, score_max_list).
 *
 * Entry i of every table corresponds to chunk i; an undefined tensor is stored as nullptr.
 * All five lists must have the same length, and that length must fit both the host staging
 * arrays (MAX_CHUNK_NUM) and the device tables (max_chunk_num_ entries).
 *
 * The uploads use blocking cudaMemcpy calls.
 *
 * @throws c10::Error if the list sizes disagree, exceed the table capacity, or a copy fails.
 */
__host__ void GPUKernel::construct_tensor_ptr_list(
  const std::vector<torch::Tensor>& key_tensor_list,
  const std::vector<torch::Tensor>& value_tensor_list,
  const std::vector<torch::Tensor>& qkv_result_tensor_list,
  const std::vector<torch::Tensor>& score_max_tensor_list,
  const std::vector<torch::Tensor>& score_sum_tensor_list) {
    const size_t chunk_count = key_tensor_list.size();

    // The lists are indexed in lock-step below; a shorter list would be read out of bounds
    // and a longer one would upload uninitialized stack contents.
    TORCH_CHECK(value_tensor_list.size() == chunk_count &&
                  qkv_result_tensor_list.size() == chunk_count &&
                  score_max_tensor_list.size() == chunk_count &&
                  score_sum_tensor_list.size() == chunk_count,
                "key, value, qkv_result, score_max and score_sum lists must have same size");
    TORCH_CHECK(chunk_count <= static_cast<size_t>(MAX_CHUNK_NUM) &&
                  chunk_count <= static_cast<size_t>(max_chunk_num_),
                "number of chunks (",
                chunk_count,
                ") exceeds the capacity of the pointer tables");

    void* tmp_key_list[MAX_CHUNK_NUM];
    void* tmp_value_list[MAX_CHUNK_NUM];
    void* tmp_qkv_result_list[MAX_CHUNK_NUM];
    void* tmp_score_sum_list[MAX_CHUNK_NUM];
    void* tmp_score_max_list[MAX_CHUNK_NUM];

    // Undefined tensors have no storage; represent them as null entries.
    const auto raw_ptr = [](const torch::Tensor& tensor) -> void* {
        return tensor.defined() ? tensor.data_ptr() : nullptr;
    };
    for (size_t i = 0; i < chunk_count; i++) {
        tmp_key_list[i] = raw_ptr(key_tensor_list[i]);
        tmp_value_list[i] = raw_ptr(value_tensor_list[i]);
        tmp_qkv_result_list[i] = raw_ptr(qkv_result_tensor_list[i]);
        tmp_score_sum_list[i] = raw_ptr(score_sum_tensor_list[i]);
        tmp_score_max_list[i] = raw_ptr(score_max_tensor_list[i]);
    }

    const size_t table_bytes = sizeof(void*) * chunk_count;
    const auto upload = [table_bytes](void* device_table, const void* host_table) {
        C10_CUDA_CHECK(cudaMemcpy(device_table, host_table, table_bytes, cudaMemcpyHostToDevice));
    };
    upload(this->key_list_, tmp_key_list);
    upload(this->value_list_, tmp_value_list);
    upload(this->qkv_result_list, tmp_qkv_result_list);
    upload(this->score_sum_list, tmp_score_sum_list);
    upload(this->score_max_list, tmp_score_max_list);
}
/**
 * Stages the per-sequence chunk mapping and sequence lengths in the host mirror buffers and
 * uploads them to seq_chunk_mapping_ / seq_length_ on the device.
 *
 * The mapping is stored as a dense table with a row pitch of max_chunk_num_ ints: row i holds
 * seq_chunk_mapping[i] followed by whatever the host buffer previously contained (rows are
 * not cleared).
 *
 * The uploads use blocking cudaMemcpy calls.
 *
 * @param seq_chunk_mapping at most max_batch_size_ rows of at most max_chunk_num_ entries
 * @param seq_length        at most max_batch_size_ entries
 * @throws c10::Error if an input exceeds the buffer capacity or a copy fails.
 */
void GPUKernel::copy_seq_chunk_mapping(const std::vector<std::vector<int>>& seq_chunk_mapping,
                                       const std::vector<int>& seq_length) {
    const size_t seq_count = seq_chunk_mapping.size();
    const size_t row_pitch = static_cast<size_t>(max_chunk_num_);

    // Both host mirrors and both device buffers are sized for max_batch_size_ sequences.
    TORCH_CHECK(seq_count <= static_cast<size_t>(max_batch_size_),
                "seq_chunk_mapping has ",
                seq_count,
                " sequences, more than max_batch_size");
    TORCH_CHECK(seq_length.size() <= static_cast<size_t>(max_batch_size_),
                "seq_length has ",
                seq_length.size(),
                " entries, more than max_batch_size");

    for (size_t i = 0; i < seq_count; i++) {
        const std::vector<int>& row = seq_chunk_mapping[i];
        // A longer row would spill into the next sequence's row (or past the buffer end).
        TORCH_CHECK(row.size() <= row_pitch,
                    "sequence ",
                    i,
                    " maps to ",
                    row.size(),
                    " chunks, more than max_chunk_num");
        if (!row.empty()) {
            memcpy(seq_chunk_mapping_cpu_ + i * row_pitch, row.data(), sizeof(int) * row.size());
        }
    }
    if (!seq_length.empty()) {
        memcpy(seq_length_cpu_, seq_length.data(), sizeof(int) * seq_length.size());
    }

    C10_CUDA_CHECK(cudaMemcpy(seq_chunk_mapping_,
                              seq_chunk_mapping_cpu_,
                              sizeof(int) * seq_count * row_pitch,
                              cudaMemcpyHostToDevice));
    C10_CUDA_CHECK(cudaMemcpy(
      seq_length_, seq_length_cpu_, sizeof(int) * seq_length.size(), cudaMemcpyHostToDevice));
}

// Runs attention for query `q` over the chunks referenced by `tasks` in two phases.
//
// Phase 1 (chunk first) runs only when `partition` is 0 or 1. Every task whose sequence range
// spans exactly q.size(1) sequences gets fresh score_max / score_sum / qkv_result tensors, and
// those tasks are handed to attn_chunks_first in one call.
//
// Phase 2 (sequence first) always runs. For each task i, index i is appended to the chunk list
// of every sequence in [seq_idx_begin, seq_idx_end) and the chunk's token count is added to
// that sequence's length. All tasks' key/value tensors and their current score_max / score_sum /
// qkv_result members are then passed to attn_seqs_first together with `output`.
//
// Parameters:
//   q          - query tensor; size(1) is taken as the number of sequences.
//   tasks      - must be non-empty; tasks[0].chunk->key().size(0) is taken as the head count.
//                Phase 1 assigns score_max, score_sum and qkv_result on qualifying tasks.
//   output     - forwarded to attn_seqs_first.
//   kv_options - base options for the scratch tensors (qkv_result uses them as is, the score
//                buffers use them with dtype float32, the start/end index tensors with int32).
//   partition  - phase 1 runs only for the values 0 and 1.
//   trace      - optional; when non-null its chunk_kernel_t and seq_kernel_t are reset to 0
//                before the kernels run, and it is forwarded to both phases.
//
// Raises (via TORCH_CHECK) when `tasks` is empty or a task's non-empty sequence range does not
// lie inside [0, q.size(1)].
//
// NOTE(review): tasks that do not qualify for phase 1 keep whatever score_max / score_sum /
// qkv_result they held on entry; whether callers reset them between calls is not visible here.
void GPUKernel::attention(torch::Tensor q,
                          std::vector<Task>& tasks,
                          torch::Tensor& output,
                          torch::TensorOptions& kv_options,
                          int partition,
                          Trace* trace) {
    // tasks[0] is dereferenced below; an empty list would be undefined behaviour.
    TORCH_CHECK(!tasks.empty(), "attention requires at least one task");

    const int num_heads = tasks[0].chunk->key().size(0);
    const int num_seqs = q.size(1);

    // chunk first, then seq first
    std::vector<torch::Tensor> keys;
    std::vector<torch::Tensor> values;
    std::vector<torch::Tensor> score_maxs;
    std::vector<torch::Tensor> score_sums;
    std::vector<torch::Tensor> qkv_results;
    std::vector<int> starts;
    std::vector<int> ends;

    torch::TensorOptions int_options = kv_options.dtype(torch::kInt32);
    torch::TensorOptions float_options = kv_options.dtype(torch::kFloat32);

    if (trace != nullptr) {
        trace->chunk_kernel_t = 0.0;
        trace->seq_kernel_t = 0.0;
    }

    if (partition == 0 || partition == 1) {
        for (auto& task : tasks) {
            // Only chunks whose sequence range covers num_seqs sequences are handled here.
            if (task.seq_idx_end - task.seq_idx_begin == num_seqs) {
                keys.push_back(task.chunk->key());
                values.push_back(task.chunk->value());
                starts.push_back(task.seq_idx_begin);
                ends.push_back(task.seq_idx_end);
                task.score_max = torch::empty({ num_heads, num_seqs }, float_options);
                score_maxs.push_back(task.score_max);
                task.score_sum = torch::empty({ num_heads, num_seqs }, float_options);
                score_sums.push_back(task.score_sum);
                task.qkv_result = torch::empty(q.sizes(), kv_options);
                qkv_results.push_back(task.qkv_result);
            }
        }

        if (!keys.empty()) {
            // Wrap the host index vectors without copying, then convert to kv_options' placement.
            torch::Tensor starts_tensor =
              at::from_blob(starts.data(), { static_cast<int64_t>(starts.size()) }, torch::kInt32)
                .to(int_options);
            torch::Tensor ends_tensor =
              at::from_blob(ends.data(), { static_cast<int64_t>(ends.size()) }, torch::kInt32)
                .to(int_options);
            attn_chunks_first(q,
                              keys,
                              values,
                              qkv_results,
                              starts_tensor,
                              ends_tensor,
                              score_maxs,
                              score_sums,
                              trace);
        }
    }

    // Phase 2 rebuilds the lists over *all* tasks, in task order.
    keys.clear();
    values.clear();
    score_maxs.clear();
    score_sums.clear();
    qkv_results.clear();

    std::vector<std::vector<int>> seq_chunk_mapping(num_seqs, std::vector<int>());
    std::vector<int> seq_length(num_seqs, 0);

    for (size_t i = 0; i < tasks.size(); i++) {
        const Task& task = tasks[i];
        // The range indexes the per-sequence vectors directly, so it has to be in bounds.
        TORCH_CHECK(task.seq_idx_begin >= task.seq_idx_end ||
                      (task.seq_idx_begin >= 0 && task.seq_idx_end <= num_seqs),
                    "task ",
                    i,
                    " covers sequences [",
                    task.seq_idx_begin,
                    ", ",
                    task.seq_idx_end,
                    ") but the query only has ",
                    num_seqs);
        for (int j = task.seq_idx_begin; j < task.seq_idx_end; j++) {
            seq_chunk_mapping[j].push_back(static_cast<int>(i));
            seq_length[j] += task.chunk->n_tokens();
        }
        keys.push_back(task.chunk->key());
        values.push_back(task.chunk->value());
        score_maxs.push_back(task.score_max);
        score_sums.push_back(task.score_sum);
        qkv_results.push_back(task.qkv_result);
    }

    attn_seqs_first(q,
                    output,
                    keys,
                    values,
                    qkv_results,
                    score_maxs,
                    score_sums,
                    seq_chunk_mapping,
                    seq_length,
                    trace);
}
}
