#include "kernel_cuda.h"
#include "attention.h"
#include "gtest/gtest.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <numeric>
#include <random>
#include <vector>

// Checks GPUKernel::attn_chunks_first against a torch reference.
// For every KV chunk i, and for every (head, query row), the test expects:
//   score_maxs[i]  == max over keys of (q . k / sqrt(head_dim))
//   score_sums[i]  == sum over keys of exp(score - score_max)
//   qkv_results[i] == exp(score - score_max) @ V   (unnormalized partial output)
// Mismatches are reported through EXPECT_TRUE so the test actually fails.
TEST(TestGPUKernel, attn_chunks_first) {
    GPT::GPUKernel kernel;
    constexpr uint32_t head_dim = 128;
    constexpr uint32_t chunk_size = 64;
    constexpr uint32_t chunk_num = 5;
    constexpr uint32_t num_head = 16;
    constexpr uint32_t num_seqs = 128;
    const auto input_options = at::device(at::Device(c10::DeviceType::CUDA)).dtype(torch::kFloat16);
    const auto float_options = at::device(at::Device(c10::DeviceType::CUDA)).dtype(torch::kFloat32);
    const auto int_options = at::device(at::Device(c10::DeviceType::CUDA)).dtype(torch::kInt32);

    torch::Tensor query = torch::randn({num_head, num_seqs, head_dim}, input_options);
    std::vector<torch::Tensor> key;
    std::vector<torch::Tensor> value;
    std::vector<torch::Tensor> qkv_results;
    std::vector<torch::Tensor> score_maxs;
    std::vector<torch::Tensor> score_sums;
    key.reserve(chunk_num);
    value.reserve(chunk_num);
    qkv_results.reserve(chunk_num);
    score_maxs.reserve(chunk_num);
    score_sums.reserve(chunk_num);
    for (uint32_t i = 0; i < chunk_num; ++i) {
        key.push_back(torch::randn({num_head, chunk_size, head_dim}, input_options));
        value.push_back(torch::randn({num_head, chunk_size, head_dim}, input_options));
        // Output buffers, expected to be written by the kernel.
        qkv_results.push_back(torch::empty({num_head, num_seqs, head_dim}, input_options));
        score_maxs.push_back(torch::empty({num_head, num_seqs}, float_options));
        score_sums.push_back(torch::empty({num_head, num_seqs}, float_options));
    }

    // start/end presumably bound the query rows [start[i], end[i]) that attend
    // to chunk i; here every row attends to every chunk.
    torch::Tensor start = torch::zeros({chunk_num}, int_options);
    torch::Tensor end = torch::full({chunk_num}, static_cast<int64_t>(num_seqs), int_options);

    kernel.attn_chunks_first(query, key, value, qkv_results, start, end, score_maxs, score_sums);

    // Reference scores are computed in float32; the query conversion is loop-invariant.
    const torch::Tensor query_f32 = query.to(float_options);
    for (uint32_t i = 0; i < chunk_num; ++i) {
        const torch::Tensor score =
            torch::matmul(query_f32, key[i].transpose(1, 2).to(float_options)) / std::sqrt(head_dim);
        const torch::Tensor score_max = std::get<0>(score.max(2));
        const torch::Tensor score_exp = torch::exp(score - score_max.unsqueeze(2));
        const torch::Tensor score_sum = score_exp.sum(2);
        // Unnormalized partial output in fp16, matching qkv_results' dtype.
        const torch::Tensor output = torch::matmul(score_exp.to(torch::kFloat16), value[i]);

        EXPECT_TRUE(torch::allclose(score_max, score_maxs[i], 1e-3, 1e-3))
            << "score_max mismatch in chunk " << i;
        // The exp-sum accumulates chunk_size terms, hence the looser tolerance.
        EXPECT_TRUE(torch::allclose(score_sum, score_sums[i], 1e-1, 1e-1))
            << "score_sum mismatch in chunk " << i;
        EXPECT_TRUE(torch::allclose(output, qkv_results[i], 1e-2, 1e-2))
            << "output mismatch in chunk " << i;
    }
}

// Checks GPUKernel::attn_seqs_first.
// Every sequence attends to shared_chunk_num chunks common to all sequences,
// plus unshared_chunk_num private chunks. The merged result in `output` is
// compared, per sequence, against a float32 softmax-attention reference over
// the concatenation of that sequence's chunks.
TEST(TestGPUKernel, attn_seq_first) {
    GPT::GPUKernel kernel;
    constexpr uint32_t head_dim = 128;
    constexpr uint32_t chunk_size = 128;
    constexpr uint32_t shared_chunk_num = 5;
    constexpr uint32_t unshared_chunk_num = 5;
    constexpr uint32_t num_head = 16;
    constexpr uint32_t num_seqs = 64;
    // Shared chunks come first, then unshared_chunk_num private chunks per sequence.
    constexpr uint32_t total_chunk_num = shared_chunk_num + num_seqs * unshared_chunk_num;
    const auto float16_options = at::device(at::Device(c10::DeviceType::CUDA)).dtype(torch::kFloat16);
    const auto float_options = at::device(at::Device(c10::DeviceType::CUDA)).dtype(torch::kFloat32);
    const auto int_options = at::device(at::Device(c10::DeviceType::CUDA)).dtype(torch::kInt32);

    torch::Tensor query = torch::randn({num_head, num_seqs, head_dim}, float16_options);
    std::vector<torch::Tensor> keys;
    std::vector<torch::Tensor> values;
    std::vector<torch::Tensor> qkv_results;
    std::vector<torch::Tensor> score_maxs;
    std::vector<torch::Tensor> score_sums;
    keys.reserve(total_chunk_num);
    values.reserve(total_chunk_num);
    qkv_results.reserve(total_chunk_num);
    score_maxs.reserve(total_chunk_num);
    score_sums.reserve(total_chunk_num);

    // Stage 1: partial attention results for the shared chunks.
    // Every sequence attends to them, so the row range is [0, num_seqs). This is
    // the same start/end convention as the attn_chunks_first test; an end of 0
    // would presumably select no rows at all.
    torch::Tensor start = torch::zeros({shared_chunk_num}, int_options);
    torch::Tensor end = torch::full({shared_chunk_num}, static_cast<int64_t>(num_seqs), int_options);
    for (uint32_t i = 0; i < shared_chunk_num; ++i) {
        keys.push_back(torch::randn({num_head, chunk_size, head_dim}, float16_options));
        values.push_back(torch::randn({num_head, chunk_size, head_dim}, float16_options));
        qkv_results.push_back(torch::empty({num_head, num_seqs, head_dim}, float16_options));
        score_maxs.push_back(torch::empty({num_head, num_seqs}, float_options));
        score_sums.push_back(torch::empty({num_head, num_seqs}, float_options));
    }
    kernel.attn_chunks_first(query, keys, values, qkv_results, start, end, score_maxs, score_sums);

    // Stage 2: append each sequence's private chunks and merge everything.
    torch::Tensor output = torch::empty({num_head, num_seqs, head_dim}, float16_options);
    std::vector<std::vector<int>> seq_chunk_mapping(num_seqs);
    std::vector<int> seq_length(num_seqs, 0);
    for (uint32_t i = 0; i < num_seqs; ++i) {
        std::vector<int>& mapping = seq_chunk_mapping[i];
        mapping.reserve(shared_chunk_num + unshared_chunk_num);
        for (uint32_t j = 0; j < shared_chunk_num; ++j) {
            mapping.push_back(static_cast<int>(j));
        }
        for (uint32_t j = 0; j < unshared_chunk_num; ++j) {
            keys.push_back(torch::randn({num_head, chunk_size, head_dim}, float16_options));
            values.push_back(torch::randn({num_head, chunk_size, head_dim}, float16_options));
            // Private chunks get no precomputed partial results. Empty tensors
            // keep these vectors index-aligned with keys/values.
            qkv_results.emplace_back();
            score_maxs.emplace_back();
            score_sums.emplace_back();
            mapping.push_back(static_cast<int>(keys.size() - 1));
        }
        seq_length[i] = static_cast<int>(chunk_size * (shared_chunk_num + unshared_chunk_num));
    }

    kernel.attn_seqs_first(query, output, keys, values, qkv_results, score_maxs, score_sums, seq_chunk_mapping, seq_length);

    // The reference is computed in float32 and only rounded to fp16 for the comparison.
    // An fp16 matmul/softmax reference would not be accurate enough to judge the kernel.
    const double scale = std::sqrt(static_cast<double>(head_dim));
    for (uint32_t i = 0; i < num_seqs; ++i) {
        std::vector<torch::Tensor> chunked_keys;
        std::vector<torch::Tensor> chunked_values;
        chunked_keys.reserve(seq_chunk_mapping[i].size());
        chunked_values.reserve(seq_chunk_mapping[i].size());
        for (const int chunk_id : seq_chunk_mapping[i]) {
            chunked_keys.push_back(keys[chunk_id]);
            chunked_values.push_back(values[chunk_id]);
        }
        const torch::Tensor key = torch::cat(chunked_keys, 1).to(float_options);
        const torch::Tensor value = torch::cat(chunked_values, 1).to(float_options);
        const torch::Tensor single_q = query.slice(1, i, i + 1).to(float_options);
        const torch::Tensor qk = torch::matmul(single_q, key.transpose(1, 2)) / scale;
        const torch::Tensor score = torch::softmax(qk, 2);
        const torch::Tensor expect_output = torch::matmul(score, value).to(torch::kFloat16);
        // Same fp16 output tolerance as the attn_chunks_first test.
        EXPECT_TRUE(torch::allclose(expect_output, output.slice(1, i, i + 1), 1e-2, 1e-2))
            << "output mismatch for sequence " << i;
    }
}

// Smoke test for GPT::Attention.
// It registers num_requests prompts, each a shuffled permutation of the token
// ids 0..num_tokens-1 with random fp16 K/V tensors. It then calls forward()
// num_decode_steps times with the same query tensor (presumably one query row
// per request).
// NOTE(review): nothing is asserted; the test only fails if a call throws.
TEST(TestGPUKernel, forward_test) {
    int num_tokens = 1024;
    int num_requests = 48;
    int num_heads = 32;
    int head_dim = 128;
    int kv_chunk_size = 64;
    int num_decode_steps = 512;

    const auto half_options = at::device(at::Device(c10::DeviceType::CUDA)).dtype(torch::kFloat16);

    // Token ids 0..num_tokens-1, reshuffled for every request with a default-seeded engine.
    std::vector<int> token_ids(num_tokens);
    std::iota(token_ids.begin(), token_ids.end(), 0);
    std::default_random_engine shuffle_rng{};

    GPT::Attention attn(num_heads, head_dim, kv_chunk_size, 2048, true, torch::kFloat16, torch::kCUDA);

    for (int request = 0; request < num_requests; ++request) {
        std::shuffle(token_ids.begin(), token_ids.end(), shuffle_rng);
        torch::Tensor prompt_k = torch::randn({num_heads, num_tokens, head_dim}, half_options);
        torch::Tensor prompt_v = torch::randn({num_heads, num_tokens, head_dim}, half_options);
        attn.add_prompt(token_ids, prompt_k, prompt_v);
    }

    torch::Tensor decode_q = torch::randn({num_heads, num_requests, head_dim}, half_options);
    for (int step = 0; step < num_decode_steps; ++step) {
        attn.forward(decode_q);
    }
}