#include "token_module.hpp"
#include "kv_cache.hpp"
#include "tree_manager.hpp"

// Python bindings for the C++ Token, KVCache and TreeAttentionManager types.
//
// All three classes use std::shared_ptr holders. Token stores and accepts
// std::shared_ptr<KVCache> / std::shared_ptr<Token>, and pybind11 requires the
// holder declared here to match the smart pointer used on the C++ side.
PYBIND11_MODULE(token_module, m) {
    m.doc() = "Token class module written by C++";

    // KVCache is registered before Token on purpose. pybind11 renders the
    // argument/return types of each .def() into its docstring at definition
    // time. A class that is not registered yet therefore appears under its raw
    // C++ name in Token's constructor/property signatures (and in any stubs
    // generated from them).
    py::class_<KVCache, std::shared_ptr<KVCache>>(m, "KVCache")
        .def(py::init<int, int, int, int, torch::Device, torch::Dtype, int>(),
             py::arg("num_hidden_layers"),
             py::arg("num_key_value_heads"),
             py::arg("max_cache_len"),
             py::arg("head_dim"),
             py::arg("device"),
             py::arg("dtype"),
             py::arg("num_seqs") = 1)
        .def("num_allocated", &KVCache::num_allocated)
        .def("allocate", &KVCache::allocate)
        .def("free", &KVCache::free)
        .def("update", &KVCache::update,
             py::arg("key_states"),
             py::arg("value_states"),
             py::arg("layer_idx"),
             py::arg("indices"),
             py::arg("new_indices"))
        .def("get_seq_length", &KVCache::get_seq_length)
        // def_readonly only prevents rebinding these attributes from Python.
        // NOTE(review): the cache tensors presumably alias the C++ storage, so
        // in-place tensor ops from Python would still mutate the cache — confirm.
        .def_readonly("k_cache", &KVCache::k_cache)
        .def_readonly("v_cache", &KVCache::v_cache)
        .def_readonly("valid_len", &KVCache::valid_len)
        // NOTE(review): if these are STL containers, pybind11/stl.h returns a
        // fresh Python copy on every access — verify against kv_cache.hpp.
        .def_readonly("free_indices", &KVCache::free_indices)
        .def_readonly("allocated_indices", &KVCache::allocated_indices);

    py::class_<Token, std::shared_ptr<Token>>(m, "Token")
        .def(py::init<
                 std::shared_ptr<KVCache>,  // base_kv_cache
                 std::shared_ptr<KVCache>,  // draft_kv_cache
                 std::optional<int>,        // base_kv_cache_idx
                 std::optional<int>,        // draft_kv_cache_idx
                 int,                       // input_id
                 torch::Tensor,             // hidden_state
                 int,                       // position_id
                 std::shared_ptr<Token>,    // parent
                 bool,                      // is_fixed
                 double,                    // score
                 int,                       // tag
                 std::vector<int>           // draft_attention_mask
             >(),
             py::arg("base_kv_cache"),
             py::arg("draft_kv_cache"),
             py::arg("base_kv_cache_idx"),   // Python None -> std::nullopt
             py::arg("draft_kv_cache_idx"),  // Python None -> std::nullopt
             py::arg("input_id"),
             py::arg("hidden_state"),
             py::arg("position_id"),
             py::arg("parent") = py::none(),  // None -> empty shared_ptr
             py::arg("is_fixed") = true,
             py::arg("score") = 1.0,
             py::arg("tag") = 0,
             py::arg("draft_attention_mask") = std::vector<int>())
        // py::is_operator(): if the right-hand operand is not a Token, the
        // dispatcher returns NotImplemented instead of raising its own
        // "incompatible function arguments" error. Python's normal comparison
        // protocol (reflected operand, standard TypeError) then applies.
        .def("__lt__", &Token::operator<, py::is_operator())
        .def("__le__", &Token::operator<=, py::is_operator())
        .def_readwrite("base_kv_cache", &Token::base_kv_cache)
        .def_readwrite("draft_kv_cache", &Token::draft_kv_cache)
        .def_readwrite("base_kv_cache_idx", &Token::base_kv_cache_idx)
        .def_readwrite("draft_kv_cache_idx", &Token::draft_kv_cache_idx)
        .def_readwrite("input_id", &Token::input_id)
        .def_readwrite("hidden_state", &Token::hidden_state)
        .def_readwrite("position_id", &Token::position_id)
        .def_readwrite("parent", &Token::parent)
        .def_readwrite("is_fixed", &Token::is_fixed)
        .def_readwrite("score", &Token::score)
        .def_readwrite("tag", &Token::tag)
        // NOTE(review): assuming this member is std::vector<int> like the
        // constructor parameter, each read yields a new Python list. In-place
        // edits such as tok.draft_attention_mask.append(x) therefore do not
        // reach the C++ object; assign a whole list instead.
        .def_readwrite("draft_attention_mask", &Token::draft_attention_mask);

    py::class_<TreeAttentionManager, std::shared_ptr<TreeAttentionManager>>(m, "TreeAttentionManager")
        .def(py::init<py::object, py::object, py::object, py::object, double, double, int, int, int, int>(),
             py::arg("base_model"),
             py::arg("draft_model"),
             py::arg("lm_head"),
             py::arg("draft_lm_head"),
             py::arg("temperature") = 0.0,
             // NOTE(review): "threshold" is accepted here but not exposed as an
             // attribute below — confirm whether that is intentional.
             py::arg("threshold") = 0.0,
             py::arg("top_draft") = 4,
             py::arg("top_node") = 32,
             py::arg("depth") = 8,
             py::arg("top_base") = 16)
        .def("is_draft_candidates_pq_empty", &TreeAttentionManager::is_draft_candidates_pq_empty)
        .def("is_base_candidates_pq_empty", &TreeAttentionManager::is_base_candidates_pq_empty)
        .def("pop_base_candidates_pq", &TreeAttentionManager::pop_base_candidates_pq)
        .def("pop_draft_candidates_pq", &TreeAttentionManager::pop_draft_candidates_pq)
        .def("push_base_candidates_pq", &TreeAttentionManager::push_base_candidates_pq,
             py::arg("token"))
        .def("push_draft_candidates_pq", &TreeAttentionManager::push_draft_candidates_pq,
             py::arg("token"))
        // Keyword names keep their trailing underscores for caller compatibility.
        .def("initialize", &TreeAttentionManager::initialize,
             py::arg("input_ids_"),
             py::arg("attention_mask_"))
        .def("draft_create", &TreeAttentionManager::draft_create)
        .def("base_check", &TreeAttentionManager::base_check)
        .def("print_times", &TreeAttentionManager::print_times)
        .def_readwrite("base_model", &TreeAttentionManager::base_model)
        .def_readwrite("draft_model", &TreeAttentionManager::draft_model)
        .def_readwrite("lm_head", &TreeAttentionManager::lm_head)
        .def_readwrite("draft_lm_head", &TreeAttentionManager::draft_lm_head)
        .def_readwrite("temperature", &TreeAttentionManager::temperature)
        .def_readwrite("top_base", &TreeAttentionManager::top_base)
        .def_readwrite("top_draft", &TreeAttentionManager::top_draft)
        .def_readwrite("top_node", &TreeAttentionManager::top_node)
        .def_readwrite("depth", &TreeAttentionManager::depth)
        .def_readwrite("device", &TreeAttentionManager::device)
        .def_readwrite("base_dtype", &TreeAttentionManager::base_dtype)
        .def_readwrite("draft_dtype", &TreeAttentionManager::draft_dtype)
        .def_readwrite("lm_head_dtype", &TreeAttentionManager::lm_head_dtype)
        .def_readwrite("num_hidden_layers", &TreeAttentionManager::num_hidden_layers)
        .def_readwrite("draft_num_hidden_layers", &TreeAttentionManager::draft_num_hidden_layers)
        .def_readwrite("num_key_value_heads", &TreeAttentionManager::num_key_value_heads)
        .def_readwrite("head_dim", &TreeAttentionManager::head_dim)
        .def_readwrite("max_cache_len", &TreeAttentionManager::max_cache_len)
        .def_readwrite("base_kvcache_pool", &TreeAttentionManager::base_kvcache_pool)
        .def_readwrite("draft_kvcache_pool", &TreeAttentionManager::draft_kvcache_pool)
        .def_readwrite("base_candidates_pq", &TreeAttentionManager::base_candidates_pq)
        .def_readwrite("draft_candidates_pq", &TreeAttentionManager::draft_candidates_pq)
        .def_readwrite("output_buffer", &TreeAttentionManager::output_buffer)
        .def_readwrite("last_token", &TreeAttentionManager::last_token);
}