#pragma once

#include "chunk.h"
#include "chunk_allocator.h"
#include "trace.h"
#include "task.h"
# ifdef USE_CUDA
#include "kernel_cuda.h"
# endif
#ifdef  USE_MKL
#include "kernel_cpu_mkl.h"
#endif

namespace GPT {

/// Attention module that stores per-sequence key/value data in `Chunk`
/// objects reachable from `root_` and evaluates attention for a query tensor.
///
/// Only the interface is declared here; the behaviour notes below are taken
/// from the declarations and the original inline comments, and anything that
/// depends on the implementation file is marked as an assumption.
///
/// Instances are neither copyable nor movable.
class Attention {
  public:
    /// @param n_head        number of attention heads
    /// @param d_embed       embedding size (presumably per head — TODO confirm)
    /// @param chunk_size    capacity of one chunk (presumably in tokens — TODO confirm)
    /// @param memory_mb     memory budget in MB (presumably handed to the
    ///                      chunk allocator — TODO confirm)
    /// @param share_prefix  presumably enables sharing of common prompt
    ///                      prefixes between sequences — TODO confirm
    /// @param dtype         element type for tensors (defaults to fp16)
    /// @param device        device for tensors (defaults to CUDA)
    Attention(int n_head,
              int d_embed,
              int chunk_size,
              int memory_mb = 2048,
              bool share_prefix = true,
              torch::Dtype dtype = torch::kFloat16,
              torch::Device device = torch::Device(torch::kCUDA));
    virtual ~Attention() = default;

    // Non-copyable. The assignment operator uses the canonical
    // `Attention&` return type even though it is deleted.
    Attention(const Attention&) = delete;
    Attention& operator=(const Attention&) = delete;
    // Non-movable. Moves were already implicitly suppressed by the
    // user-declared copy operations; deleting them makes that explicit.
    Attention(Attention&&) = delete;
    Attention& operator=(Attention&&) = delete;

    /// Remove the sequence identified by `seq_idx`.
    void remove(int seq_idx);
    /// Duplicate the sequence identified by `seq_idx`; `copies` is the
    /// requested count (exact semantics live in the implementation — TODO confirm).
    void duplicate(int seq_idx, int copies);

    /// Add a new prompt with its key/value tensors.
    /// @return an int, presumably the index of the new sequence — TODO confirm
    int add_prompt(std::vector<int>& tokens, torch::Tensor k, torch::Tensor v);
    /// Append tokens for all sequences; no beam changes.
    void append_completions(std::vector<int>& tokens, torch::Tensor k, torch::Tensor v);
    /// Beam-search variant: duplicates a sequence and removes a sequence
    /// while appending to `seq_idx`.
    /// @return an int, presumably the index of the resulting sequence — TODO confirm
    int append_completion(int seq_idx, std::vector<int>& tokens, torch::Tensor k, torch::Tensor v);

    /// Evaluate attention for query `q`.
    /// @param partation  partitioning strategy (parameter name spelling kept
    ///                   as in the original declaration):
    ///                   0 = auto (chunk_seq), 1 = chunk first, 2 = sequence first
    /// @param trace      optional trace object; may be nullptr
    torch::Tensor forward(torch::Tensor q, int partation = 0, Trace* trace = nullptr);

    /// Debug print starting at `root`; `level` is presumably the current
    /// depth used for indentation, and nullptr presumably means "start from
    /// the internal root" — TODO confirm.
    void print(Chunk* root = nullptr, int level = 0);
    /// Return the stored chunks as `Task` objects.
    std::vector<Task> get_chunks();
    /// Return the stored chunks as raw tuples of two tensors and three ints
    /// (field meaning is defined by the implementation — TODO confirm).
    std::vector<std::tuple<torch::Tensor, torch::Tensor, int, int, int>> get_chunks_raw();

  private:
    /// Look up the chunk associated with `seq_idx`.
    Chunk* at(int seq_idx);

    // Declaration order below fixes initialization and destruction order;
    // do not reorder without checking the constructor's initializer list.
    // NOTE(review): members are destroyed in reverse order, so
    // chunk_allocator_ goes away before root_ — confirm Chunk teardown does
    // not need the allocator.
    int n_head_;
    int d_embed_;
    int chunk_size_;
    bool share_prefix_;
    std::shared_ptr<Chunk> root_;
    torch::TensorOptions t_options_;
    ChunkAllocator chunk_allocator_;
#ifdef USE_CUDA
    GPUKernel gpu_kernel_;  // only present in CUDA builds
#endif
};

}