/*
 * Copyright (c) 2023 by FlashInfer team.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>

#include <decode/decode_page.cuh>
#include <type_traits>

#include "cpu_reference.h"
#include "cpu_utils.h"

using namespace flashinfer;

// Verifies AppendPagedKVCachePrefill against a CPU reference for one sequence
// (batch_size == 1) whose `seq_len` tokens are appended to a zero-initialized cache.
//
// Two paged buffers are checked:
//   * the KV cache, compared with the result of cpu_reference::append_paged_kv_cache;
//   * the "candidate" buffer, whose K slot is expected to hold the per-page
//     element-wise maximum of the appended keys and whose V slot the per-page
//     element-wise minimum (one entry per KV page, `page_size` entries per chunk).
// Elements are compared with atol = rtol = 1e-3; the test expects more than 99%
// of all elements (both buffers combined) to match and no NaN in the GPU output.
//
// Template parameters:
//   kv_layout  layout used by both paged buffers
//   T          element type of the cache, keys and values
//   head_dim   number of features per head
// Parameters:
//   page_size  number of entries per page (used for both buffers); must be non-zero
//   num_heads  number of heads
//   seq_len    number of appended tokens; must be non-zero because the page
//              bookkeeping below evaluates `seq_len - 1` on an unsigned type
template <QKVLayout kv_layout, typename T, size_t head_dim>
void _TestAppendPagedKVPrefillKernelCorrectness(size_t page_size,
												size_t num_heads,
												size_t seq_len) {
	constexpr size_t batch_size = 1;
	size_t max_num_pages = flashinfer::ceil_div(seq_len, page_size);

	// Host and device caches both start zero-filled.
	std::vector<T> kv_data_cpu(2 * max_num_pages * page_size * num_heads * head_dim);
	utils::vec_zero_(kv_data_cpu);
	thrust::device_vector<T> kv_data_gpu(kv_data_cpu);

	// Random page_indices: logical pages are mapped to shuffled physical pages.
	std::vector<int32_t> page_indices(max_num_pages);
	std::iota(page_indices.begin(), page_indices.end(), 0);
	std::shuffle(page_indices.begin(), page_indices.end(), std::mt19937(std::random_device()()));
	int32_t last_page_idx = page_indices.back();
	int32_t last_page_len = (seq_len - 1) % page_size + 1;

	std::vector<int32_t> page_indptr({0, static_cast<int32_t>(max_num_pages)});
	std::vector<int32_t> append_indptr({0, static_cast<int32_t>(seq_len)});
	std::vector<T> ki(seq_len * num_heads * head_dim);
	std::vector<T> vi(seq_len * num_heads * head_dim);
	utils::vec_normal_(ki);
	utils::vec_normal_(vi);

	// CPU Append
	std::vector<std::vector<T>> keys({ki});
	std::vector<std::vector<T>> values({vi});
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0, // page_budget, useless when appending
		last_page_len,
		last_page_idx,
		kv_data_cpu.data(),
		page_indices.data(),
		page_indptr.data());
	cpu_reference::append_paged_kv_cache<kv_layout>(paged_kv_cpu, keys, values, append_indptr);

	// CPU candidate chunk append: every KV page contributes one entry, and
	// `page_size` such entries form one chunk.
	size_t candidate_max_num_chunks = flashinfer::ceil_div(max_num_pages, page_size);
	size_t candidate_last_chunk_len = (max_num_pages - 1) % page_size + 1;
	std::vector<int32_t> candidate_chunk_indices(candidate_max_num_chunks);
	std::iota(candidate_chunk_indices.begin(), candidate_chunk_indices.end(), 0);
	std::shuffle(candidate_chunk_indices.begin(),
				 candidate_chunk_indices.end(),
				 std::mt19937(std::random_device()()));
	int32_t candidate_last_chunk_idx = candidate_chunk_indices.back();
	std::vector<int32_t> candidate_chunk_indptr(
		{0, static_cast<int32_t>(candidate_max_num_chunks)});

	std::vector<T> candidate_chunk_data_cpu(candidate_max_num_chunks * 2 * num_heads * head_dim *
											page_size);

	// Both copies of the candidate buffer start as zero, so entries that are never
	// written compare equal on CPU and GPU. Zero is not the true initial value of a
	// running min/max; per the original note, the kernel handles that initial value
	// itself.
	utils::vec_zero_(candidate_chunk_data_cpu);
	thrust::device_vector<T> candidate_chunk_data_gpu(candidate_chunk_data_cpu);

	// Running maximum starts at the lowest sentinel, running minimum at the highest.
	std::vector<T> candidate_max(max_num_pages * num_heads * head_dim);
	utils::vec_fill_(candidate_max, -CUDART_MAX_NORMAL_FP16);
	std::vector<T> candidate_min(max_num_pages * num_heads * head_dim);
	utils::vec_fill_(candidate_min, CUDART_MAX_NORMAL_FP16);

	// Reduction on K value. Layout is NHD for naturally generated by GEMM
	tensor_info_t<QKVLayout::kNHD, 1, head_dim> page_info(1, seq_len, num_heads);
	tensor_info_t<QKVLayout::kNHD, 1, head_dim> chunk_info(1, max_num_pages, num_heads);
	for(size_t head_idx = 0; head_idx < num_heads; ++head_idx) {
		for(size_t page_idx = 0; page_idx < max_num_pages; ++page_idx) {
			// Only the final page may be partially filled.
			size_t cur_page_len = (page_idx == max_num_pages - 1) ? last_page_len : page_size;
			for(size_t i = 0; i < cur_page_len; ++i) {
				int32_t entry_idx = page_idx * page_size + i;
				for(size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) {
					T local_val = ki[page_info.get_kv_elem_offset(entry_idx, head_idx, feat_idx)];
					// One offset addresses the same slot in both reduction buffers.
					const auto slot = chunk_info.get_kv_elem_offset(page_idx, head_idx, feat_idx);
					T& local_max = candidate_max[slot];
					T& local_min = candidate_min[slot];
					local_max = (local_val > local_max) ? local_val : local_max;
					local_min = (local_val < local_min) ? local_val : local_min;
				}
			}
		}
	}

	// The per-page maxima are appended as "keys" and the minima as "values".
	std::vector<std::vector<T>> chunk_keys({candidate_max});
	std::vector<std::vector<T>> chunk_values({candidate_min});
	std::vector<int32_t> chunk_append_indptr({0, static_cast<int32_t>(max_num_pages)});
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> candidate_kv_cpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0, // page_budget, useless when appending
		candidate_last_chunk_len,
		candidate_last_chunk_idx,
		candidate_chunk_data_cpu.data(),
		candidate_chunk_indices.data(),
		candidate_chunk_indptr.data());
	cpu_reference::append_paged_kv_cache<kv_layout>(
		candidate_kv_cpu, chunk_keys, chunk_values, chunk_append_indptr);

	// GPU Append
	thrust::device_vector<int32_t> indptr_gpu(page_indptr);
	thrust::device_vector<int32_t> indices_gpu(page_indices);
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_gpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0,
		last_page_len,
		last_page_idx,
		thrust::raw_pointer_cast(kv_data_gpu.data()),
		thrust::raw_pointer_cast(indices_gpu.data()),
		thrust::raw_pointer_cast(indptr_gpu.data()));

	thrust::device_vector<int32_t> candidate_chunk_indptr_gpu(candidate_chunk_indptr);
	thrust::device_vector<int32_t> candidate_chunk_indices_gpu(candidate_chunk_indices);
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> candidate_kv_gpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0,
		candidate_last_chunk_len,
		candidate_last_chunk_idx,
		thrust::raw_pointer_cast(candidate_chunk_data_gpu.data()),
		thrust::raw_pointer_cast(candidate_chunk_indices_gpu.data()),
		thrust::raw_pointer_cast(candidate_chunk_indptr_gpu.data()));

	thrust::device_vector<int32_t> append_indptr_gpu(append_indptr);
	thrust::device_vector<T> keys_gpu(ki);
	thrust::device_vector<T> values_gpu(vi);

	cudaError_t status =
		AppendPagedKVCachePrefill(paged_kv_gpu,
								  candidate_kv_gpu,
								  thrust::raw_pointer_cast(keys_gpu.data()),
								  thrust::raw_pointer_cast(values_gpu.data()),
								  thrust::raw_pointer_cast(append_indptr_gpu.data()));
	EXPECT_EQ(status, cudaSuccess)
		<< "AppendPagedKVCachePrefill kernel launch failed, error message: "
		<< cudaGetErrorString(status);

	// Copy the device results back to the host for comparison.
	thrust::host_vector<T> kv_data_gpu_h(kv_data_gpu);
	thrust::host_vector<T> candidate_chunk_data_gpu_h(candidate_chunk_data_gpu);

	size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0;
	bool nan_detected = false;
	for(size_t i = 0; i < kv_data_cpu.size(); ++i) {
		const float gpu_val = float(kv_data_gpu_h[i]);
		if(std::isnan(gpu_val)) {
			nan_detected = true;
		}
		if(!utils::isclose(float(kv_data_cpu[i]), gpu_val, 1e-3, 1e-3)) {
			++num_result_errors_atol_1e_3_rtol_1e_3;
		}
	}
	for(size_t i = 0; i < candidate_chunk_data_cpu.size(); ++i) {
		const float cpu_val = float(candidate_chunk_data_cpu[i]);
		const float gpu_val = float(candidate_chunk_data_gpu_h[i]);
		if(std::isnan(gpu_val)) {
			nan_detected = true;
		}
		// Evaluate the tolerance check once; it drives both the error count and
		// the per-element diagnostic.
		if(!utils::isclose(cpu_val, gpu_val, 1e-3, 1e-3)) {
			++num_result_errors_atol_1e_3_rtol_1e_3;
			std::cout << "i=" << i << ", cpu=" << cpu_val << ", gpu=" << gpu_val << std::endl;
		}
	}
	float result_accuracy = 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) /
									 float(candidate_chunk_data_cpu.size() + kv_data_cpu.size());
	std::cout << "kv_layout=" << QKVLayoutToString(kv_layout) << ", page_size=" << page_size
			  << ", seq_len=" << seq_len << ", batch_size=" << batch_size
			  << ", num_heads=" << num_heads << ", head_dim=" << head_dim
			  << ", result_accuracy=" << result_accuracy << std::endl;
	EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
	EXPECT_EQ(nan_detected, false) << "Nan detected in the result.";
}

template <QKVLayout kv_layout, typename T, size_t head_dim>
void _TestAppendPagedKVDecodeKernelCorrectness(size_t page_size, size_t num_heads, size_t seq_len) {
	constexpr size_t batch_size = 1;
	size_t seq_len_appened = seq_len + 1;
	size_t max_num_pages = flashinfer::ceil_div(seq_len_appened, page_size);

	std::vector<T> kv_data_cpu(2 * max_num_pages * page_size * num_heads * head_dim);
	utils::vec_normal_(kv_data_cpu);
	thrust::device_vector<T> kv_data_gpu(kv_data_cpu);

	// Random page_indices
	std::vector<int32_t> page_indices(max_num_pages);
	std::iota(page_indices.begin(), page_indices.end(), 0);
	std::shuffle(page_indices.begin(), page_indices.end(), std::mt19937(std::random_device()()));
	int32_t last_page_idx = page_indices.back();
	int32_t last_page_len = (seq_len_appened - 1) % page_size + 1;

	std::vector<int32_t> page_indptr({0, static_cast<int32_t>(max_num_pages)});
	std::vector<int32_t> append_indptr({0, 1});
	std::vector<T> ki(1 * num_heads * head_dim);
	std::vector<T> vi(1 * num_heads * head_dim);
	utils::vec_normal_(ki);
	utils::vec_normal_(vi);

	// CPU Append
	std::vector<std::vector<T>> keys({ki});
	std::vector<std::vector<T>> values({vi});
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0, // page_budget, useless when appending
		last_page_len,
		last_page_idx,
		kv_data_cpu.data(),
		page_indices.data(),
		page_indptr.data());
	cpu_reference::append_paged_kv_cache<kv_layout>(paged_kv_cpu, keys, values, append_indptr);

	// CPU candidate chunk append
	size_t candidate_max_num_chunks = flashinfer::ceil_div(max_num_pages, page_size);
	size_t candidate_last_chunk_len = (max_num_pages - 1) % page_size + 1;
	std::vector<int32_t> candidate_chunk_indices(candidate_max_num_chunks);
	std::iota(candidate_chunk_indices.begin(), candidate_chunk_indices.end(), 0);
	std::shuffle(candidate_chunk_indices.begin(),
				 candidate_chunk_indices.end(),
				 std::mt19937(std::random_device()()));
	int32_t candidate_last_chunk_idx = candidate_chunk_indices.back();
	std::vector<int32_t> candidate_chunk_indptr(
		{0, static_cast<int32_t>(candidate_max_num_chunks)});

	std::vector<T> candidate_chunk_data_cpu(candidate_max_num_chunks * 2 * num_heads * head_dim *
											page_size);

	// We maintain the init value [-inf, inf] in kernel-level.
	utils::vec_normal_(candidate_chunk_data_cpu);
	thrust::device_vector<T> candidate_chunk_data_gpu(candidate_chunk_data_cpu);

	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> candidate_kv_cpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0, // page_budget, useless when appending
		candidate_last_chunk_len,
		candidate_last_chunk_idx,
		candidate_chunk_data_cpu.data(),
		candidate_chunk_indices.data(),
		candidate_chunk_indptr.data());

	// Operate on CPU metadata page_kv_t
	// ki, vi is NHD shape
	// Check whether it is a new page
	int32_t entry_idx = seq_len % page_size;
	for(size_t head_idx = 0; head_idx < num_heads; ++head_idx) {
		for(size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) {
			T local_val = ki[head_idx * head_dim + feat_idx];
			T& local_max = candidate_chunk_data_cpu[candidate_kv_cpu.get_k_elem_offset(
				candidate_last_chunk_idx, head_idx, candidate_last_chunk_len - 1, feat_idx)];
			T& local_min = candidate_chunk_data_cpu[candidate_kv_cpu.get_v_elem_offset(
				candidate_last_chunk_idx, head_idx, candidate_last_chunk_len - 1, feat_idx)];
			if(entry_idx == 0) {
				local_max = -CUDART_MAX_NORMAL_FP16;
				local_min = CUDART_MAX_NORMAL_FP16;
			}
			local_max = (local_val > local_max) ? local_val : local_max;
			local_min = (local_val < local_min) ? local_val : local_min;
		}
	}

	// GPU Append
	thrust::device_vector<int32_t> indptr_gpu(page_indptr);
	thrust::device_vector<int32_t> indices_gpu(page_indices);
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_gpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0,
		last_page_len,
		last_page_idx,
		thrust::raw_pointer_cast(kv_data_gpu.data()),
		thrust::raw_pointer_cast(indices_gpu.data()),
		thrust::raw_pointer_cast(indptr_gpu.data()));

	thrust::device_vector<int32_t> candidate_chunk_indptr_gpu(candidate_chunk_indptr);
	thrust::device_vector<int32_t> candidate_chunk_indices_gpu(candidate_chunk_indices);
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> candidate_kv_gpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0,
		candidate_last_chunk_len,
		candidate_last_chunk_idx,
		thrust::raw_pointer_cast(candidate_chunk_data_gpu.data()),
		thrust::raw_pointer_cast(candidate_chunk_indices_gpu.data()),
		thrust::raw_pointer_cast(candidate_chunk_indptr_gpu.data()));
	thrust::device_vector<T> keys_gpu(ki);
	thrust::device_vector<T> values_gpu(vi);

	cudaError_t status = AppendPagedKVCacheDecode(paged_kv_gpu,
												  candidate_kv_gpu,
												  thrust::raw_pointer_cast(keys_gpu.data()),
												  thrust::raw_pointer_cast(values_gpu.data()));
	EXPECT_EQ(status, cudaSuccess)
		<< "AppendPagedKVCachePrefill kernel launch failed, error message: "
		<< cudaGetErrorString(status);

	thrust::host_vector<T> kv_data_gpu_h(kv_data_gpu);
	thrust::host_vector<T> candidate_chunk_data_gpu_h(candidate_chunk_data_gpu);

	size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0;
	bool nan_detected = false;
	for(size_t i = 0; i < kv_data_cpu.size(); ++i) {
		if(std::isnan(float(kv_data_gpu_h[i]))) {
			nan_detected = true;
		}
		num_result_errors_atol_1e_3_rtol_1e_3 +=
			(!utils::isclose(float(kv_data_cpu[i]), float(kv_data_gpu_h[i]), 1e-3, 1e-3));
	}
	for(size_t i = 0; i < candidate_chunk_data_cpu.size(); ++i) {
		if(std::isnan(float(candidate_chunk_data_gpu_h[i]))) {
			nan_detected = true;
		}
		num_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose(
			float(candidate_chunk_data_cpu[i]), float(candidate_chunk_data_gpu_h[i]), 1e-3, 1e-3));

		if(!utils::isclose(float(candidate_chunk_data_cpu[i]),
						   float(candidate_chunk_data_gpu_h[i]),
						   1e-3,
						   1e-3)) {
			std::cout << "i=" << i << ", cpu=" << float(candidate_chunk_data_cpu[i])
					  << ", gpu=" << float(candidate_chunk_data_gpu_h[i]) << std::endl;
		}
	}
	float result_accuracy = 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) /
									 float(candidate_chunk_data_cpu.size() + kv_data_cpu.size());
	std::cout << "kv_layout=" << QKVLayoutToString(kv_layout) << ", page_size=" << page_size
			  << ", seq_len=" << seq_len << ", batch_size=" << batch_size
			  << ", num_heads=" << num_heads << ", head_dim=" << head_dim
			  << ", result_accuracy=" << result_accuracy << std::endl;
	EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
	EXPECT_EQ(nan_detected, false) << "Nan detected in the result.";
}

// Multi-round Q&A test.
//
// Simulates `num_conv_rounds` conversation rounds for one sequence
// (batch_size == 1). Even-numbered iterations append a randomly sized prompt
// with AppendPagedKVCachePrefill; odd-numbered iterations append
// `max_decode_len` token(s) with AppendPagedKVCacheDecode. Pages are handed out
// sequentially (physical page index == allocation order).
//
// The KV cache is mirrored on the CPU in every iteration via
// cpu_reference::append_paged_kv_cache. The chunk (min/max metadata) buffer is
// only updated on the GPU during the rounds; after the last round the CPU
// recomputes the per-page key maximum/minimum from the final CPU cache and
// appends them in one go. Both buffers are then compared with
// atol = rtol = 1e-3; the test expects more than 99% of all elements to match
// and no NaN in the GPU output.
//
// Template parameters:
//   kv_layout  layout used by both paged buffers
//   T          element type of the cache, keys and values
//   head_dim   number of features per head
// Parameters:
//   page_size  number of entries per page (used for both buffers); must be non-zero
//   num_heads  number of heads
template <QKVLayout kv_layout, typename T, size_t head_dim>
void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t num_heads) {
	constexpr size_t batch_size = 1;
	// number of conversation rounds
	size_t num_conv_rounds = 5;
	size_t max_decode_len = 1;
	size_t max_prefill_len = 256;
	// Upper bound on the pages needed when every round uses its maximum lengths.
	size_t max_num_pages =
		num_conv_rounds * batch_size * ((max_decode_len + max_prefill_len) / page_size + 1);

	std::vector<T> kv_data_cpu(2 * max_num_pages * page_size * num_heads * head_dim);
	utils::vec_zero_(kv_data_cpu);
	thrust::device_vector<T> kv_data_gpu(kv_data_cpu);

	std::vector<int32_t> seq_len(batch_size);
	utils::vec_fill_(seq_len, 0);
	std::vector<std::vector<int32_t>> page_indices(batch_size);
	std::vector<int32_t> last_page_len(batch_size);
	utils::vec_fill_(last_page_len, 0);
	size_t page_counter = 0;

	// calculate the configuration of metadata page
	size_t num_chunks = flashinfer::ceil_div(max_num_pages, page_size);
	std::vector<T> chunk_data_cpu(2 * num_chunks * page_size * num_heads * head_dim);
	utils::vec_zero_(chunk_data_cpu);
	thrust::device_vector<T> chunk_data_gpu(chunk_data_cpu);

	// Chunks are used in order: chunk i lives at physical index i.
	std::vector<int32_t> chunk_indices(num_chunks);
	std::iota(chunk_indices.begin(), chunk_indices.end(), 0);
	for(size_t round = 0; round < 2 * num_conv_rounds; ++round) {
		std::vector<int32_t> append_len(batch_size);
		std::vector<int32_t> append_indptr{0};
		std::vector<std::vector<T>> keys;
		std::vector<std::vector<T>> values;
		if(round % 2 == 0) {
			// Prefill iteration: random prompt length.
			utils::vec_randint_(append_len, 1, max_prefill_len + 1);
		} else {
			// Decode iteration: fixed number of new tokens.
			utils::vec_fill_<int32_t>(append_len, max_decode_len);
		}
		for(size_t i = 0; i < batch_size; ++i) {
			append_indptr.push_back(append_indptr.back() + append_len[i]);
			seq_len[i] += append_len[i];
			// Allocate a fresh page whenever the current last page is full (or absent).
			for(int32_t j = 0; j < append_len[i]; ++j) {
				if(last_page_len[i] % page_size == 0) {
					page_indices[i].push_back(page_counter++);
					last_page_len[i] = 1;
				} else {
					last_page_len[i] += 1;
				}
			}
			std::vector<T> ki(append_len[i] * num_heads * head_dim),
				vi(append_len[i] * num_heads * head_dim);
			utils::vec_normal_(ki);
			utils::vec_normal_(vi);
			keys.push_back(ki);
			values.push_back(vi);
		}

		// Flatten the per-request page lists into indices + indptr arrays.
		std::vector<int32_t> indptr_cpu{0};
		std::vector<int32_t> indices_cpu;
		for(size_t i = 0; i < batch_size; ++i) {
			for(size_t j = 0; j < page_indices[i].size(); ++j) {
				indices_cpu.push_back(page_indices[i][j]);
			}
			indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size());
		}

		paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
			num_heads,
			page_size,
			head_dim,
			batch_size,
			0,
			last_page_len[0],
			page_indices[0].back(),
			kv_data_cpu.data(),
			indices_cpu.data(),
			indptr_cpu.data());
		cpu_reference::append_paged_kv_cache<kv_layout>(paged_kv_cpu, keys, values, append_indptr);

		thrust::device_vector<int32_t> indptr_gpu(indptr_cpu);
		thrust::device_vector<int32_t> indices_gpu(indices_cpu);
		paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_gpu(
			num_heads,
			page_size,
			head_dim,
			batch_size,
			0,
			last_page_len[0],
			page_indices[0].back(),
			thrust::raw_pointer_cast(kv_data_gpu.data()),
			thrust::raw_pointer_cast(indices_gpu.data()),
			thrust::raw_pointer_cast(indptr_gpu.data()));

		// Pack all requests' keys/values contiguously on the device.
		thrust::device_vector<int32_t> append_indptr_gpu(append_indptr);
		thrust::device_vector<T> keys_gpu(append_indptr.back() * num_heads * head_dim);
		thrust::device_vector<T> values_gpu(append_indptr.back() * num_heads * head_dim);
		for(size_t i = 0; i < batch_size; ++i) {
			thrust::device_vector<T> ki(keys[i]);
			thrust::device_vector<T> vi(values[i]);
			thrust::copy(
				ki.begin(), ki.end(), keys_gpu.begin() + append_indptr[i] * num_heads * head_dim);
			thrust::copy(
				vi.begin(), vi.end(), values_gpu.begin() + append_indptr[i] * num_heads * head_dim);
		}

		// collect information for appending metadata
		// The invariants are checked with gtest so that they are enforced in every
		// build type and reported as a test failure instead of aborting the process.
		size_t cur_page_size = page_indices[0].size();
		ASSERT_EQ(cur_page_size, page_counter);
		size_t cur_chunk_size = flashinfer::ceil_div(cur_page_size, page_size);
		std::vector<int32_t> chunk_indptr({0, static_cast<int32_t>(cur_chunk_size)});
		ASSERT_LE(cur_chunk_size, num_chunks);
		size_t last_chunk_idx = chunk_indices[chunk_indptr[1] - 1];
		ASSERT_EQ(last_chunk_idx, cur_chunk_size - 1);
		size_t last_chunk_len = (cur_page_size - 1) % page_size + 1;

		thrust::device_vector<int32_t> chunk_indptr_gpu(chunk_indptr);
		thrust::device_vector<int32_t> chunk_indices_gpu(chunk_indices);
		paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> chunk_kv_gpu(
			num_heads,
			page_size,
			head_dim,
			batch_size,
			0,
			last_chunk_len,
			last_chunk_idx,
			thrust::raw_pointer_cast(chunk_data_gpu.data()),
			thrust::raw_pointer_cast(chunk_indices_gpu.data()),
			thrust::raw_pointer_cast(chunk_indptr_gpu.data()));

		if(round % 2 == 0) {
			// call prefill kernel
			cudaError_t status =
				AppendPagedKVCachePrefill(paged_kv_gpu,
										  chunk_kv_gpu,
										  thrust::raw_pointer_cast(keys_gpu.data()),
										  thrust::raw_pointer_cast(values_gpu.data()),
										  thrust::raw_pointer_cast(append_indptr_gpu.data()));
			EXPECT_EQ(status, cudaSuccess)
				<< "AppendPagedKVCachePrefill kernel launch failed, error message: "
				<< cudaGetErrorString(status);
		} else {
			// call decode kernel
			cudaError_t status =
				AppendPagedKVCacheDecode(paged_kv_gpu,
										 chunk_kv_gpu,
										 thrust::raw_pointer_cast(keys_gpu.data()),
										 thrust::raw_pointer_cast(values_gpu.data()));
			EXPECT_EQ(status, cudaSuccess)
				<< "AppendPagedKVCacheDecode kernel launch failed, error message: "
				<< cudaGetErrorString(status);
		}
	}
	// CPU Append for metadata
	// page_counter: appended page in total
	// last_page_len[0]: last page length
	std::vector<T> candidate_max(page_counter * num_heads * head_dim);
	utils::vec_fill_(candidate_max, -CUDART_MAX_NORMAL_FP16);
	std::vector<T> candidate_min(page_counter * num_heads * head_dim);
	utils::vec_fill_(candidate_min, CUDART_MAX_NORMAL_FP16);
	// Reduction on K value. Layout is NHD for naturally generated by GEMM
	tensor_info_t<QKVLayout::kNHD, 1, head_dim> chunk_info(1, page_counter, num_heads);
	// Built with null buffers: used only to compute element offsets into kv_data_cpu.
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> paged_kv_cpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0,
		0,
		0,
		nullptr,
		nullptr,
		nullptr); // just for index

	for(size_t head_idx = 0; head_idx < num_heads; ++head_idx) {
		for(size_t page_idx = 0; page_idx < page_counter; ++page_idx) {
			// Only the final page may be partially filled.
			size_t cur_page_len = (page_idx == page_counter - 1) ? last_page_len[0] : page_size;
			for(size_t i = 0; i < cur_page_len; ++i) {
				for(size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) {
					T local_val = kv_data_cpu[paged_kv_cpu.get_k_elem_offset(
						page_idx, head_idx, i, feat_idx)];
					// One offset addresses the same slot in both reduction buffers.
					const auto slot = chunk_info.get_kv_elem_offset(page_idx, head_idx, feat_idx);
					T& local_max = candidate_max[slot];
					T& local_min = candidate_min[slot];
					local_max = (local_val > local_max) ? local_val : local_max;
					local_min = (local_val < local_min) ? local_val : local_min;
				}
			}
		}
	}

	// The per-page maxima are appended as "keys" and the minima as "values".
	std::vector<std::vector<T>> chunk_keys({candidate_max});
	std::vector<std::vector<T>> chunk_values({candidate_min});
	std::vector<int32_t> chunk_append_indptr({0, static_cast<int32_t>(page_counter)});
	size_t candidate_last_chunk_len = (page_counter - 1) % page_size + 1;
	size_t cur_chunk_size = flashinfer::ceil_div(page_counter, page_size);
	std::vector<int32_t> chunk_indptr({0, static_cast<int32_t>(cur_chunk_size)});
	paged_kv_t<PageStorage::kIndices, kv_layout, T, int32_t> candidate_kv_cpu(
		num_heads,
		page_size,
		head_dim,
		batch_size,
		0, // page_budget, useless when appending
		candidate_last_chunk_len,
		cur_chunk_size - 1,
		chunk_data_cpu.data(),
		chunk_indices.data(),
		chunk_indptr.data());
	cpu_reference::append_paged_kv_cache<kv_layout>(
		candidate_kv_cpu, chunk_keys, chunk_values, chunk_append_indptr);

	// Copy the device results back to the host for comparison.
	thrust::host_vector<T> kv_data_gpu_h(kv_data_gpu);
	thrust::host_vector<T> chunk_data_gpu_h(chunk_data_gpu);
	size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0;
	bool nan_detected = false;
	for(size_t i = 0; i < kv_data_cpu.size(); ++i) {
		if(std::isnan(float(kv_data_gpu_h[i]))) {
			nan_detected = true;
		}
		num_result_errors_atol_1e_3_rtol_1e_3 +=
			(!utils::isclose(float(kv_data_cpu[i]), float(kv_data_gpu_h[i]), 1e-3, 1e-3));
	}
	for(size_t i = 0; i < chunk_data_cpu.size(); ++i) {
		if(std::isnan(float(chunk_data_gpu_h[i]))) {
			nan_detected = true;
		}
		num_result_errors_atol_1e_3_rtol_1e_3 +=
			(!utils::isclose(float(chunk_data_cpu[i]), float(chunk_data_gpu_h[i]), 1e-3, 1e-3));
	}
	float result_accuracy = 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) /
									 float(kv_data_cpu.size() + chunk_data_cpu.size());
	std::cout << "kv_layout=" << QKVLayoutToString(kv_layout) << ", page_size=" << page_size
			  << ", batch_size=" << batch_size << ", num_heads=" << num_heads
			  << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl;
	EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed.";
	EXPECT_EQ(nan_detected, false) << "Nan detected in the result.";
}

template <typename T>
void TestAppendPagedKVPrefillKernelCorrectness() {
	for(size_t page_size : {1, 3, 7, 16, 31}) {
		for(size_t num_heads : {32}) {
			for(size_t seq_len : {17, 31, 71, 111, 330, 512, 1110, 4100}) {
				for(QKVLayout kv_layout : {QKVLayout::kNHD, QKVLayout::kHND}) {
					for(size_t head_dim : {64, 128, 256}) {
						SWITCH_HEAD_DIM(
							head_dim, HEAD_DIM, {SWITCH_LAYOUT(kv_layout, KV_LAYOUT, {
								_TestAppendPagedKVPrefillKernelCorrectness<KV_LAYOUT, T, HEAD_DIM>(
									page_size, num_heads, seq_len);
							})})
					}
				}
			}
		}
	}
}

template <typename T>
void TestAppendPagedKVDecodeKernelCorrectness() {
	for(size_t page_size : {1, 3, 7, 16, 31}) {
		for(size_t num_heads : {32}) {
			for(size_t seq_len : {17, 31, 71, 111, 330, 512, 1110, 4100}) {
				for(QKVLayout kv_layout : {QKVLayout::kNHD, QKVLayout::kHND}) {
					for(size_t head_dim : {64, 128, 256}) {
						SWITCH_HEAD_DIM(
							head_dim, HEAD_DIM, {SWITCH_LAYOUT(kv_layout, KV_LAYOUT, {
								_TestAppendPagedKVDecodeKernelCorrectness<KV_LAYOUT, T, HEAD_DIM>(
									page_size, num_heads, seq_len);
							})})
					}
				}
			}
		}
	}
}

template <typename T>
void TestAppendPagedKVKernelCorrectness() {
	for(size_t page_size : {1, 3, 7, 16, 31}) {
		for(size_t num_heads : {32}) {
			for(QKVLayout kv_layout : {QKVLayout::kNHD}) {
				for(size_t head_dim : {64, 128, 256}) {
					SWITCH_HEAD_DIM(head_dim, HEAD_DIM, {SWITCH_LAYOUT(kv_layout, KV_LAYOUT, {
										_TestAppendPagedKVKernelCorrectness<KV_LAYOUT, T, HEAD_DIM>(
											page_size, num_heads);
									})})
				}
			}
		}
	}
}

// Runs the prefill-append parameter sweep with `half` as the element type.
TEST(FlashInferCorrectnessTest, AppendPagedKVPrefillKernelCorrectnessTestFP16) {
	TestAppendPagedKVPrefillKernelCorrectness<half>();
}

// Runs the decode-append parameter sweep with `half` as the element type.
TEST(FlashInferCorrectnessTest, AppendPagedKVDecodeKernelCorrectnessTestFP16) {
	TestAppendPagedKVDecodeKernelCorrectness<half>();
}

// Runs the multi-round prefill/decode parameter sweep with `half` as the element type.
TEST(FlashInferCorrectnessTest, AppendPagedKVKernelCorrectnessTestFP16) {
	TestAppendPagedKVKernelCorrectness<half>();
}