#include <algorithm>
#include <chrono>
#include <cstdio>
#include <numeric>
#include <random>
#include <vector>

#include "gtest/gtest.h"
#include "attention.h"

/// Decode-throughput smoke benchmark for GPT::Attention on CPU.
///
/// Setup registers `n_requests` prompts of `n_tokens` tokens each. Every prompt
/// is a different permutation of 0..n_tokens-1 and gets random K/V tensors.
/// The test then calls `forward` on one fixed query tensor `n_decode_steps`
/// times.
///
/// Only the decode loop is timed and reported. GoogleTest's own per-test
/// elapsed time also includes the prompt setup, so it cannot isolate decode
/// cost. Nothing is asserted: the test fails only if a call throws or crashes.
TEST(TestPerf, Basic) {
    // Uncomment to limit PyTorch to a single intra-op thread, which makes the
    // reported timing less sensitive to core count and machine load.
    //torch::set_num_threads(1);

    constexpr int n_tokens = 1024;      // prompt length for every request
    constexpr int n_requests = 32;
    constexpr int n_head = 32;
    constexpr int d_embed = 128;        // last dimension of the K/V/Q tensors below
    constexpr int chunk_size = 16;
    constexpr int n_decode_steps = 512;

    std::vector<int> tokens(n_tokens);
    std::iota(tokens.begin(), tokens.end(), 0);
    // Default-constructed engine uses the default seed, so the shuffles are
    // the same on every run.
    auto rng = std::default_random_engine{};

    // NOTE(review): what `2048` and `true` mean is not visible here (2048 may
    // be a capacity / max length) -- confirm against GPT::Attention's ctor.
    GPT::Attention attn(
      n_head, d_embed, chunk_size, 2048, true, torch::kFloat32, torch::Device(torch::kCPU));
    for (int i = 0; i < n_requests; ++i) {
        // Each request gets a fresh permutation of the same token ids.
        std::shuffle(tokens.begin(), tokens.end(), rng);
        torch::Tensor k = torch::rand({ n_head, n_tokens, d_embed });
        torch::Tensor v = torch::rand({ n_head, n_tokens, d_embed });
        attn.add_prompt(tokens, k, v);
    }

    // Inputs for the append_completions call that is disabled in the decode
    // loop below. They are sized by n_requests and stay unused while that
    // call is commented out.
    std::vector<int> new_tokens(n_requests);
    std::iota(new_tokens.begin(), new_tokens.end(), 0);
    torch::Tensor new_k = torch::rand({ n_head, n_requests, d_embed });
    torch::Tensor new_v = torch::rand({ n_head, n_requests, d_embed });

    // Query shaped {n_head, n_requests, d_embed}; reused unchanged every step.
    torch::Tensor q = torch::rand({ n_head, n_requests, d_embed });

    // Time only the decode loop, excluding the prompt setup above.
    const auto start = std::chrono::steady_clock::now();
    for (int step = 0; step < n_decode_steps; ++step) {
        attn.forward(q);
        // NOTE(review): disabled; new_tokens/new_k/new_v exist only for this.
        //attn.append_completions(new_tokens, new_k, new_v);
    }
    const std::chrono::duration<double, std::milli> elapsed =
        std::chrono::steady_clock::now() - start;

    std::printf("[   PERF   ] %d decode steps: %.3f ms total, %.4f ms/step\n",
                n_decode_steps, elapsed.count(), elapsed.count() / n_decode_steps);
}