#ifndef KVCACHE_HPP
#define KVCACHE_HPP

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <vector>

namespace py = pybind11;
using torch::indexing::Slice;

/// Key/value cache for attention layers. The cached data lives in two torch
/// tensors (`k_cache`, `v_cache`); slot bookkeeping is kept in
/// `free_indices` / `allocated_indices`.
///
/// NOTE(review): only declarations are visible here. Behavior described below
/// beyond what the signatures show is an assumption — confirm against the
/// definitions in the corresponding .cpp.
class KVCache {
public:
    /// Construct a cache.
    /// @param num_hidden_layers   number of layers (presumably one cache slice per layer — TODO confirm)
    /// @param num_key_value_heads number of key/value heads
    /// @param max_cache_len       maximum cache length (presumably in token positions — TODO confirm)
    /// @param head_dim            per-head feature size
    /// @param device              torch device (presumably where k_cache / v_cache are allocated)
    /// @param dtype               torch dtype (presumably the element type of k_cache / v_cache)
    /// @param num_seqs            number of sequences; defaults to 1
    KVCache(int num_hidden_layers,
            int num_key_value_heads,
            int max_cache_len,
            int head_dim,
            torch::Device device,
            torch::Dtype dtype,
            int num_seqs = 1);

    /// @return number of allocated slots (presumably allocated_indices.size() — TODO confirm).
    [[nodiscard]] int num_allocated() const;

    /// Reserve cache slots for a sequence of `sequence_length` entries.
    /// @return the reserved slot indices (presumably drawn from free_indices — TODO confirm).
    std::vector<int> allocate(int sequence_length);

    /// Release slots.
    /// NOTE(review): the role of `unfree_indices` (vs. `indices`) cannot be
    /// determined from this declaration — document once confirmed.
    void free(const std::vector<int>& indices, const std::vector<int>& unfree_indices);

    /// Write new key/value states for layer `layer_idx` into the cache.
    /// @param key_states   new keys (shape not visible here — TODO confirm)
    /// @param value_states new values (shape not visible here — TODO confirm)
    /// @param layer_idx    target layer
    /// @param indices      existing slot indices (presumably — TODO confirm)
    /// @param new_indices  slots receiving the new states (presumably — TODO confirm)
    /// @return a Python tuple — presumably (keys, values) for attention; TODO confirm.
    py::tuple update(torch::Tensor key_states,
                     torch::Tensor value_states,
                     int layer_idx,
                     const std::vector<int>& indices,
                     const std::vector<int>& new_indices);

    /// @return current sequence length (presumably valid_len — TODO confirm).
    [[nodiscard]] int get_seq_length() const;

    // Members are initialized in declaration order regardless of how the
    // constructor's mem-initializer list is written, so keep this order stable:
    // later members (e.g. the cache tensors) may be built from earlier ones.
    int num_hidden_layers;
    int num_key_value_heads;
    int max_cache_len;
    int head_dim;
    int num_seqs;
    torch::Device device;
    torch::Dtype dtype;
    torch::Tensor k_cache;
    torch::Tensor v_cache;
    // Not a constructor parameter: brace-initialized so it is never read while
    // indeterminate. An explicit initializer in the constructor overrides this.
    int valid_len{0};
    std::vector<int> free_indices;
    std::vector<int> allocated_indices;
};

#endif // KVCACHE_HPP