#pragma once
#include "task.h"
#include "trace.h"

// Default values for GPUKernel's `max_chunk_num` / `max_batch_size`
// constructor arguments. Kept as macros (rather than constexpr) because other
// translation units may reference them from the preprocessor.
#define MAX_CHUNK_NUM 4096  // default for GPUKernel::GPUKernel(max_chunk_num)
#define MAX_BATCH_SIZE 128  // default for GPUKernel::GPUKernel(max_batch_size)

namespace GPT {

/// Host-side interface to the attention kernels.
///
/// The object keeps a set of raw buffers (see the private members) whose
/// capacity is chosen at construction time. NOTE(review): the constructor and
/// destructor are defined out of line; the destructor presumably releases
/// these buffers — confirm in the implementation file.
class GPUKernel {
  public:
    /// @param head_num        number of attention heads (default 32).
    /// @param max_chunk_num   upper bound on chunks; presumably sizes the
    ///                        per-chunk pointer lists — confirm.
    /// @param max_batch_size  upper bound on sequences per batch; presumably
    ///                        sizes the per-sequence buffers — confirm.
    GPUKernel(int head_num = 32,
              int max_chunk_num = MAX_CHUNK_NUM,
              int max_batch_size = MAX_BATCH_SIZE);
    ~GPUKernel();

    // This class holds raw pointer members and has a user-declared
    // destructor. An implicitly generated copy would duplicate those
    // pointers, leaving two objects sharing (and, if the destructor frees
    // them, double-freeing) the same buffers. Copying is therefore disabled.
    GPUKernel(const GPUKernel&) = delete;
    GPUKernel& operator=(const GPUKernel&) = delete;

    /// Attention entry point for query `q` over the work described by `tasks`.
    /// @param q           query tensor.
    /// @param tasks       work items; presumably describe the KV chunks per
    ///                    sequence — confirm against Task.
    /// @param output      result tensor (non-const; presumably written in place).
    /// @param kv_options  tensor options presumably used for KV-side tensors.
    /// @param partition   0: auto, 1: chunk first, 2: sequence first.
    /// @param trace       optional tracing sink; may be nullptr.
    void attention(torch::Tensor q,
                   std::vector<Task>& tasks,
                   torch::Tensor& output,
                   torch::TensorOptions& kv_options,
                   int partition = 0, // 0: auto, 1: chunk first, 2: sequence first
                   Trace* trace = nullptr);

    /// "Chunk first" attention path (presumably selected by partition == 1).
    /// Shapes below are as annotated by the original author.
    /// NOTE(review): `values` is annotated [chunk_size, num_head, head_dim]
    /// while `keys` is [num_head, chunk_size, head_dim] — confirm this layout
    /// difference is intentional.
    void attn_chunks_first(
      torch::Tensor& query,                    // [num_seqs, num_head, head_dim]
      std::vector<torch::Tensor>& keys,        // chunk_num<[num_head, chunk_size, head_dim]>
      std::vector<torch::Tensor>& values,      // chunk_num<[chunk_size,num_head, head_dim]>
      std::vector<torch::Tensor>& qkv_results, // chunk_num<[num_seqs,num_head, head_dim]>
      torch::Tensor& start,
      torch::Tensor& end,
      std::vector<torch::Tensor>& score_max,
      std::vector<torch::Tensor>& score_sum,
      Trace* trace = nullptr);

    /// "Sequence first" attention path (presumably selected by partition == 2).
    /// Shapes below are as annotated by the original author.
    /// @param seq_chunk_mapping  per-sequence list of chunk indices (presumed
    ///                           from the name — confirm).
    /// @param seq_length         per-sequence lengths (presumed — confirm).
    void attn_seqs_first(
      const torch::Tensor& query,                    // [num_head, num_seqs, head_dim]
      torch::Tensor& output,                         // [num_head, num_seqs, head_dim]
      const std::vector<torch::Tensor>& keys,        // chunk_num<[num_head, chunk_size, head_dim]>
      const std::vector<torch::Tensor>& values,      // chunk_num<[num_head, chunk_size, head_dim]>
      const std::vector<torch::Tensor>& qkv_results, // chunk_num<[num_head, num_seqs, head_dim]>
      const std::vector<torch::Tensor>& score_max,
      const std::vector<torch::Tensor>& score_sum,
      const std::vector<std::vector<int>>& seq_chunk_mapping,
      const std::vector<int>& seq_length,
      Trace* trace = nullptr);

  private:
    // Presumably gathers the data pointers of each tensor list into the
    // matching *_list buffers below — confirm in the implementation.
    void construct_tensor_ptr_list(const std::vector<torch::Tensor>& key_tensor_list,
                                   const std::vector<torch::Tensor>& value_tensor_list,
                                   const std::vector<torch::Tensor>& qkv_result_tensor_list,
                                   const std::vector<torch::Tensor>& score_max_tensor_list,
                                   const std::vector<torch::Tensor>& score_sum_tensor_list);

    // Presumably stages the mapping/lengths through the *_cpu_ buffers into
    // seq_chunk_mapping_ / seq_length_ — confirm in the implementation.
    void copy_seq_chunk_mapping(const std::vector<std::vector<int>>& seq_chunk_mapping,
                                const std::vector<int>& seq_length);

    // Configuration captured from the constructor arguments.
    int head_num_ = 0;
    int max_chunk_num_ = 0;
    int max_batch_size_ = 0;

    // Raw buffers. Default-initialized to nullptr so an object is never left
    // holding indeterminate pointers; the constructor's initialization (if
    // any) still takes precedence.
    void** key_list_ = nullptr;
    void** value_list_ = nullptr;
    void** qkv_result_list = nullptr;
    void** score_max_list = nullptr;
    void** score_sum_list = nullptr;
    void* score_max_ = nullptr;
    void* score_sum_ = nullptr;
    void* seq_chunk_mapping_ = nullptr;
    int* seq_chunk_mapping_cpu_ = nullptr;  // name suggests a host-side mirror — confirm
    void* seq_length_ = nullptr;
    int* seq_length_cpu_ = nullptr;         // name suggests a host-side mirror — confirm
};

}  // namespace GPT