template <typename Engine, typename Layout, typename Engine1, typename Layout1>
inline __device__ void apply_sparse_mask_causal(Tensor<Engine, Layout> &tensor, Tensor<Engine1, Layout1> &flashmask_downstart, const uint32_t col_idx_offset_,
                                         const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
                                         const uint32_t warp_row_stride, const uint32_t mask_col_idx_offset) {
    // Overwrites with -INFINITY every element of `tensor` that is either
    //   (a) outside the causal / key-length window: col >= min(max_seqlen_k, row + 1), or
    //   (b) at or below the per-column start row:   row >= flashmask_downstart(col - mask_col_idx_offset).
    //
    // `tensor` is rank 2 and is indexed as tensor((i, mi), (j, nj)): mode 0 walks rows,
    // mode 1 walks columns.
    //   row = row_idx_offset_ + mi * warp_row_stride + i * 8
    //   col = col_idx_offset_ + (lane % 4) * 2 + nj * 8 + j
    // NOTE(review): the strides of 8 and the (lane % 4) * 2 term presumably mirror the MMA
    // accumulator fragment layout -- confirm against the tiled MMA in use.
    // NOTE(review): no lane-dependent term is added to the row offset here; presumably the
    // caller already folds lane_id / 4 into row_idx_offset_ -- confirm.
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const uint32_t lane = threadIdx.x % 32;
    const uint32_t first_col = col_idx_offset_ + (lane % 4) * 2;
    #pragma unroll
    for (int col_blk = 0; col_blk < size<1, 1>(tensor); ++col_blk) {
        #pragma unroll
        for (int col_sub = 0; col_sub < size<1, 0>(tensor); ++col_sub) {
            const uint32_t col = first_col + col_blk * 8 + col_sub;
            // The mask value depends only on the column, so it is fetched once per column.
            const uint32_t masked_from_row = flashmask_downstart(col - mask_col_idx_offset);
            #pragma unroll
            for (int row_blk = 0; row_blk < size<0, 1>(tensor); ++row_blk) {
                #pragma unroll
                for (int row_sub = 0; row_sub < size<0, 0>(tensor); ++row_sub) {
                    const uint32_t row = row_idx_offset_ + row_blk * warp_row_stride + row_sub * 8;
                    const bool outside_causal = col >= std::min(max_seqlen_k, row + 1);
                    const bool below_start = row >= masked_from_row;
                    if (outside_causal || below_start) {
                        tensor(make_coord(row_sub, row_blk), make_coord(col_sub, col_blk)) = -INFINITY;
                    }
                }
            }
        }
    }
}

template <typename Engine, typename Layout, typename Engine1, typename Layout1>
inline __device__ void apply_sparse_mask_causal_withend(Tensor<Engine, Layout> &tensor, Tensor<Engine1, Layout1> &flashmask_downstart,  Tensor<Engine1, Layout1> &flashmask_downend, const uint32_t col_idx_offset_,
                                         const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
                                         const uint32_t warp_row_stride, const uint32_t mask_col_idx_offset) {
    // Overwrites with -INFINITY every element of `tensor` that is either
    //   (a) outside the causal / key-length window: col >= min(max_seqlen_k, row + 1), or
    //   (b) inside the per-column half-open row band
    //       [flashmask_downstart(c), flashmask_downend(c)), with c = col - mask_col_idx_offset.
    //
    // `tensor` is rank 2 and is indexed as tensor((i, mi), (j, nj)): mode 0 walks rows,
    // mode 1 walks columns.
    //   row = row_idx_offset_ + mi * warp_row_stride + i * 8
    //   col = col_idx_offset_ + (lane % 4) * 2 + nj * 8 + j
    // NOTE(review): the strides of 8 and the (lane % 4) * 2 term presumably mirror the MMA
    // accumulator fragment layout -- confirm against the tiled MMA in use.
    // NOTE(review): no lane-dependent term is added to the row offset here; presumably the
    // caller already folds lane_id / 4 into row_idx_offset_ -- confirm.
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const uint32_t lane = threadIdx.x % 32;
    const uint32_t first_col = col_idx_offset_ + (lane % 4) * 2;
    #pragma unroll
    for (int col_blk = 0; col_blk < size<1, 1>(tensor); ++col_blk) {
        #pragma unroll
        for (int col_sub = 0; col_sub < size<1, 0>(tensor); ++col_sub) {
            const uint32_t col = first_col + col_blk * 8 + col_sub;
            // Both band limits depend only on the column, so they are fetched once per column.
            const uint32_t band_begin = flashmask_downstart(col - mask_col_idx_offset);
            const uint32_t band_end = flashmask_downend(col - mask_col_idx_offset);
            #pragma unroll
            for (int row_blk = 0; row_blk < size<0, 1>(tensor); ++row_blk) {
                #pragma unroll
                for (int row_sub = 0; row_sub < size<0, 0>(tensor); ++row_sub) {
                    const uint32_t row = row_idx_offset_ + row_blk * warp_row_stride + row_sub * 8;
                    const bool outside_causal = col >= std::min(max_seqlen_k, row + 1);
                    const bool inside_band = row >= band_begin && row < band_end;
                    if (outside_causal || inside_band) {
                        tensor(make_coord(row_sub, row_blk), make_coord(col_sub, col_blk)) = -INFINITY;
                    }
                }
            }
        }
    }
}

template <typename Engine, typename Layout, typename Engine1, typename Layout1>
__forceinline__ __device__ void apply_sparse_mask(Tensor<Engine, Layout> &tensor,
                                         Tensor<Engine1, Layout1> &attn_mask_startrow_indices,
                                         Tensor<Engine1, Layout1> &attn_mask_endrow_indices,
                                         const uint32_t col_idx_offset_,
                                         const uint32_t max_seqlen_k,
                                         const uint32_t row_idx_offset_,
                                         const uint32_t warp_row_stride,
                                         const uint32_t mask_col_idx_offset,
                                         const bool pairwise) {
    // Non-causal sparse mask. Overwrites with -INFINITY every element of `tensor` whose
    // column is past the key length (col >= max_seqlen_k), or whose row is selected by the
    // per-column limits s = attn_mask_startrow_indices(c) and e = attn_mask_endrow_indices(c),
    // with c = col - mask_col_idx_offset:
    //   pairwise == true  : rows in the half-open band [s, e)
    //   pairwise == false : rows with row >= s, or rows with row < e
    //
    // `tensor` is rank 2 and is indexed as tensor((i, mi), (j, nj)): mode 0 walks rows,
    // mode 1 walks columns.
    //   row = row_idx_offset_ + mi * warp_row_stride + i * 8
    //   col = col_idx_offset_ + (lane % 4) * 2 + nj * 8 + j
    // NOTE(review): the strides of 8 and the (lane % 4) * 2 term presumably mirror the MMA
    // accumulator fragment layout -- confirm against the tiled MMA in use.
    // NOTE(review): no lane-dependent term is added to the row offset here; presumably the
    // caller already folds lane_id / 4 into row_idx_offset_ -- confirm.
    static_assert(Layout::rank == 2, "Only support 2D Tensor");
    const uint32_t lane = threadIdx.x % 32;
    const uint32_t first_col = col_idx_offset_ + (lane % 4) * 2;
    #pragma unroll
    for (int col_blk = 0; col_blk < size<1, 1>(tensor); ++col_blk) {
        #pragma unroll
        for (int col_sub = 0; col_sub < size<1, 0>(tensor); ++col_sub) {
            const uint32_t col = first_col + col_blk * 8 + col_sub;
            // Both limits depend only on the column, so they are fetched once per column.
            const uint32_t row_lo = attn_mask_startrow_indices(col - mask_col_idx_offset);
            const uint32_t row_hi = attn_mask_endrow_indices(col - mask_col_idx_offset);
            const bool past_keys = col >= max_seqlen_k;

            #pragma unroll
            for (int row_blk = 0; row_blk < size<0, 1>(tensor); ++row_blk) {
                #pragma unroll
                for (int row_sub = 0; row_sub < size<0, 0>(tensor); ++row_sub) {
                    const uint32_t row = row_idx_offset_ + row_blk * warp_row_stride + row_sub * 8;
                    // pairwise: row must satisfy both limits; otherwise either limit suffices.
                    const bool row_selected = pairwise ? (row >= row_lo && row < row_hi)
                                                       : (row >= row_lo || row < row_hi);
                    if (past_keys || row_selected) {
                        tensor(make_coord(row_sub, row_blk), make_coord(col_sub, col_blk)) = -INFINITY;
                    }
                }
            }
        }
    }
}
