#include "tree_manager.hpp"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <unordered_set>
#include <vector>

TreeAttentionManager::TreeAttentionManager(
                        py::object base_model,
                        py::object draft_model,
                        py::object lm_head,
                        py::object draft_lm_head,
                        double temperature,
                        double threshold,
                        int top_draft,
                        int top_node,
                        int depth,
                        int top_base)
    : base_model(base_model), draft_model(draft_model), lm_head(lm_head), draft_lm_head(draft_lm_head),
        temperature(temperature), threshold(threshold), top_draft(top_draft), top_node(top_node), depth(depth), top_base(top_base),
        device(torch::kCPU), base_dtype(torch::kFloat32), draft_dtype(torch::kFloat32), lm_head_dtype(torch::kFloat32)
{
    namespace pyd = torch::python::detail;

    // The init-list values are placeholders: the real placement / precision
    // are read from the Python-side modules here.
    device        = pyd::py_object_to_device(base_model.attr("device"));
    base_dtype    = pyd::py_object_to_dtype(base_model.attr("dtype"));
    draft_dtype   = pyd::py_object_to_dtype(draft_model.attr("dtype"));
    lm_head_dtype = pyd::py_object_to_dtype(lm_head.attr("weight").attr("dtype"));

    const py::object base_config = base_model.attr("config");
    const py::object draft_config = draft_model.attr("config");
    const auto config_int = [](const py::object& cfg, const char* key) {
        return cfg.attr(key).cast<int>();
    };
    num_hidden_layers       = config_int(base_config, "num_hidden_layers");
    draft_num_hidden_layers = config_int(draft_config, "draft_num_hidden_layers");
    num_key_value_heads     = config_int(base_config, "num_key_value_heads");
    head_dim                = config_int(base_config, "head_dim");
    max_cache_len = 8192;

    // Both pools share the base model's KV-head count and head_dim; only the
    // layer count and dtype differ.
    // NOTE(review): assumes the draft model uses the same num_key_value_heads /
    // head_dim as the base model — confirm.
    const auto make_pool = [this](auto layers, auto dtype) {
        return std::make_shared<KVCache>(
            layers,
            num_key_value_heads,
            max_cache_len,
            head_dim,
            device,
            dtype
        );
    };
    base_kvcache_pool  = make_pool(num_hidden_layers, base_dtype);
    draft_kvcache_pool = make_pool(draft_num_hidden_layers, draft_dtype);
}

/// True when draft_create() has no queued token left to expand.
bool TreeAttentionManager::is_draft_candidates_pq_empty() {
    return draft_candidates_pq.size() == 0;
}

/// True when no token is queued for base-model verification.
bool TreeAttentionManager::is_base_candidates_pq_empty() {
    return base_candidates_pq.size() == 0;
}

/// Remove and return the top token of the base queue, or nullptr when empty.
std::shared_ptr<Token> TreeAttentionManager::pop_base_candidates_pq() {
    std::shared_ptr<Token> popped;  // stays null if the queue is empty
    if (!base_candidates_pq.empty()) {
        popped = base_candidates_pq.top().token;
        base_candidates_pq.pop();
    }
    return popped;
}

/// Remove and return the top token of the draft queue, or nullptr when empty.
std::shared_ptr<Token> TreeAttentionManager::pop_draft_candidates_pq() {
    std::shared_ptr<Token> popped;  // stays null if the queue is empty
    if (!draft_candidates_pq.empty()) {
        popped = draft_candidates_pq.top().token;
        draft_candidates_pq.pop();
    }
    return popped;
}

/// Queue `token` for base-model verification. The entry is a
/// Candidate{score, position_id, token}; queue ordering is defined by
/// Candidate's comparator (declared elsewhere).
/// @throws std::invalid_argument if `token` is null (e.g. None passed from
///         Python), which would otherwise be dereferenced.
void TreeAttentionManager::push_base_candidates_pq(std::shared_ptr<Token> token) {
    if (!token) {
        throw std::invalid_argument("push_base_candidates_pq: token must not be null");
    }
    // Braced-init evaluates left to right, so score/position are read before the move.
    base_candidates_pq.push(Candidate{token->score, token->position_id, std::move(token)});
}

/// Queue `token` for expansion by draft_create(). The entry is a
/// Candidate{score, position_id, token}; queue ordering is defined by
/// Candidate's comparator (declared elsewhere).
/// @throws std::invalid_argument if `token` is null (e.g. None passed from
///         Python), which would otherwise be dereferenced.
void TreeAttentionManager::push_draft_candidates_pq(std::shared_ptr<Token> token) {
    if (!token) {
        throw std::invalid_argument("push_draft_candidates_pq: token must not be null");
    }
    // Braced-init evaluates left to right, so score/position are read before the move.
    draft_candidates_pq.push(Candidate{token->score, token->position_id, std::move(token)});
}

/**
 * Prefill the base model with a single prompt and seed the candidate queues.
 *
 * Runs the base model over the whole prompt (writing its KV entries into
 * freshly allocated base-cache slots), records every prompt token as a fixed
 * Token chained to its predecessor, then samples (temperature > 0) or argmaxes
 * the first generated token from the last hidden state. That token becomes
 * `last_token`, is pushed into both candidate queues and appended to
 * `output_buffer`.
 *
 * @param input_ids_      prompt ids, shape (1, seq_len); moved to `device`.
 * @param attention_mask_ mask passed through to the base model; moved to `device`.
 * @throws std::runtime_error if input_ids_ is not 2-D, has a batch size other
 *         than 1, or is empty.
 */
void TreeAttentionManager::initialize(
    torch::Tensor input_ids_,
    torch::Tensor attention_mask_)
{
    TIMEIT_INIT;
    if (input_ids_.dim() != 2) {
        throw std::runtime_error("initialize: input_ids must have shape (1, seq_len)");
    }
    if (input_ids_.size(0) != 1) {
        throw std::runtime_error("Only one input sequence is supported");
    }
    if (input_ids_.size(1) == 0) {
        // An empty prompt would otherwise fail later at hidden_states[:, -1].
        throw std::runtime_error("initialize: input_ids must contain at least one token");
    }
    torch::Tensor input_ids = input_ids_.to(device);
    torch::Tensor attention_mask = attention_mask_.to(device);
    const int64_t seq_len = input_ids.size(1);
    // position_ids: shape (1, seq_len)
    torch::Tensor position_ids = torch::arange(seq_len, input_ids.options()).unsqueeze(0);

    std::vector<int> base_kv_cache_indices = base_kvcache_pool->allocate(seq_len);

    py::dict kwargs;
    kwargs["input_ids"] = input_ids;
    kwargs["attention_mask"] = attention_mask;
    kwargs["position_ids"] = position_ids;
    kwargs["past_key_values"] = base_kvcache_pool;
    kwargs["past_key_value_indices"] = base_kv_cache_indices;
    kwargs["new_past_key_value_indices"] = std::vector<int>{};
    kwargs["use_cache"] = true;
    py::object result = base_model(**kwargs);
    torch::Tensor hidden_states = result.attr("last_hidden_state").cast<torch::Tensor>();

    // Copy the ids to the host once instead of one device->host sync per token.
    torch::Tensor ids_cpu = input_ids.to(torch::kCPU, torch::kLong);
    auto ids_a = ids_cpu.accessor<int64_t, 2>();
    const int hidden_size = base_model.attr("config").attr("hidden_size").cast<int>();

    std::shared_ptr<Token> parent_token = nullptr;
    for (int64_t i = 0; i < seq_len; i++) {
        std::optional<int> base_idx = base_kv_cache_indices[i];
        // Feature shift: token i carries the hidden state produced at position
        // i-1; the first token gets a zero vector.
        torch::Tensor token_hidden_state = (i > 0)
            ? hidden_states.index({torch::indexing::Slice(), i - 1, torch::indexing::Slice()})
            : torch::zeros({1, hidden_size},
                           torch::TensorOptions().device(device).dtype(base_dtype));
        auto new_token = std::make_shared<Token>(
            base_kvcache_pool,
            draft_kvcache_pool,
            base_idx,
            std::nullopt, // draft_kv_cache_idx = None
            static_cast<int>(ids_a[0][i]),
            token_hidden_state,
            i,  // position_id
            parent_token,
            true,  // is_fixed
            0.0,    // score
            -1, // tag
            std::vector<int>() // draft_attention_mask
        );
        parent_token = new_token;
        input_tokens.push_back(new_token);
    }

    // Hidden state at the last prompt position, in the base dtype.
    torch::Tensor final_hidden = hidden_states.index({torch::indexing::Slice(), -1, torch::indexing::Slice()});
    torch::Tensor logits = lm_head(final_hidden.to(lm_head_dtype)).cast<torch::Tensor>();

    torch::Tensor output_ids;
    if (temperature > 0) {
        logits = logits.div(temperature);
        torch::Tensor probs = torch::softmax(logits, -1);
        output_ids = torch::multinomial(probs, 1).flatten();
    } else {
        output_ids = torch::argmax(logits, -1).flatten();
    }

    // First generated token: not yet in any KV cache and not fixed.
    last_token = std::make_shared<Token>(
        base_kvcache_pool,
        draft_kvcache_pool,
        std::nullopt,
        std::nullopt,
        output_ids[0].item<int>(),
        final_hidden,
        seq_len, // position_id = input_ids.size(1)
        parent_token,
        false, // is_fixed
        0.0,    // score
        -1, // tag
        std::vector<int>() // draft_attention_mask
    );
    base_candidates_pq.push(Candidate{last_token->score, last_token->position_id, last_token});
    draft_candidates_pq.push(Candidate{last_token->score, last_token->position_id, last_token});
    input_tokens.push_back(last_token);

    output_buffer.push_back(last_token->input_id);
}

void TreeAttentionManager::draft_create() {
    TIMEIT(draft_create);
    std::vector<int> kv_cache_indices = draft_kvcache_pool->allocated_indices;
    std::vector<int> del_kv_cache_indices;
    auto valid_len = static_cast<int>(draft_kvcache_pool->num_allocated());
    int tag_id = 0;

    auto input_attention_masks = torch::full(
        {0, max_cache_len},
        -1000.0,
        torch::TensorOptions().device(device).dtype(draft_dtype)
    );
    auto input_features = torch::empty(
        {0, head_dim},
        torch::TensorOptions().device(device).dtype(draft_dtype)
    );

    for (int depth_idx = 0; depth_idx < depth; ++depth_idx) {
        // std::cout << "depth_idx: " << depth_idx << std::endl;
        // Pop top draft candidates
        std::vector<std::shared_ptr<Token>> draft_inputs;
        TIMEIT(draft_create_pq);
        int nc = (depth_idx == 0 ? 1 : top_draft);
        while (!draft_candidates_pq.empty() && draft_inputs.size() < static_cast<size_t>(nc)) {
            auto cand = draft_candidates_pq.top();
            draft_candidates_pq.pop();
            draft_inputs.push_back(cand.token);
        }
        TIMEIT_END(draft_create_pq);

        nc = static_cast<int>(draft_inputs.size());
        if (nc == 0) break;

        if (threshold > 0.0) {
            // Check if the sum of scores is below the threshold
            auto score_sum = 0.0f;
            for (auto& tok : draft_inputs) {
                score_sum += exp(tok->score);
            }
            if (score_sum < threshold) {
                break;
            }
        }

        TIMEIT(draft_create_preprocess);
        if (depth_idx == 0) {
            auto tok = draft_inputs[0];
            auto parent = tok->parent;
            std::vector<torch::Tensor> features;
            features.insert(features.begin(), tok->hidden_state.to(draft_dtype));
            while (parent) {
                if (parent->draft_kv_cache_idx.has_value()) {
                    break;
                }
                draft_inputs.insert(draft_inputs.begin(), parent);
                features.insert(features.begin(), parent->hidden_state.to(draft_dtype));
                parent = parent->parent;
            }

            if (draft_inputs.size() != features.size()) {
                std::cout << "draft_inputs.size(): " << draft_inputs.size() << std::endl;
                std::cout << "features.size(): " << features.size() << std::endl;
                throw std::runtime_error("draft_create: draft_inputs and features size mismatch");
            }

            input_attention_masks = torch::full(
                {static_cast<int>(draft_inputs.size()), max_cache_len},
                -1000.0,
                torch::TensorOptions().device(device).dtype(draft_dtype)
            );
            input_attention_masks.triu_(valid_len + 1);
            input_features = torch::cat(features, 0);

            for (int i = 0; i < static_cast<int>(draft_inputs.size()); i++) {
                auto cur = draft_inputs[i];
                cur->tag = tag_id++;
            }
        } else {
            std::vector<int> parent_tags;
            for (int i = 0; i < static_cast<int>(draft_inputs.size()); i++) {
                auto cur = draft_inputs[i];
                cur->tag = tag_id++;
                parent_tags.push_back(cur->parent->tag);
            }
            auto parent_tags_t = torch::tensor(parent_tags, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64));
            auto new_attention_masks = input_attention_masks.index({parent_tags_t.to(torch::kLong), torch::indexing::Slice()});
            new_attention_masks.diagonal(valid_len + input_attention_masks.size(0), 0, 1).fill_(0.0);
            input_attention_masks = torch::cat({input_attention_masks, new_attention_masks}, 0);
        }
        TIMEIT_END(draft_create_preprocess);

        if (input_attention_masks.size(0) != tag_id) {
            std::cout << "input_attention_masks.size(0): " << input_attention_masks.size(0) << std::endl;
            std::cout << "tag_id: " << tag_id << std::endl;
            throw std::runtime_error("draft_create: input_attention_masks size mismatch");
        }

        TIMEIT(draft_create_make_inputs);
        // Prepare model inputs
        std::vector<int64_t> ids, tags, pos;
        ids.reserve(static_cast<int>(draft_inputs.size()));
        pos.reserve(static_cast<int>(draft_inputs.size()));
        tags.reserve(static_cast<int>(draft_inputs.size()));
        std::vector<torch::Tensor> features_list;

        for (auto& tok : draft_inputs) {
            ids.push_back(tok->input_id);
            pos.push_back(tok->position_id);
            tags.push_back(tok->tag);
            features_list.push_back(tok->hidden_state.to(draft_dtype));
        }
        auto input_ids = torch::tensor(ids, torch::TensorOptions().device(device).dtype(torch::kInt64)).unsqueeze(0);
        auto tags_t = torch::tensor(tags, torch::TensorOptions().device(torch::kCPU).dtype(torch::kInt64));
        auto input_feats = torch::cat(features_list, 0).unsqueeze(0);
        auto attn_masks = input_attention_masks.index({tags_t, torch::indexing::Slice()});
        auto position_ids = torch::tensor(pos, torch::TensorOptions().device(device).dtype(torch::kInt64)).unsqueeze(0);

        auto new_indices = draft_kvcache_pool->allocate(draft_inputs.size());

        for (int i = 0; i < static_cast<int>(draft_inputs.size()); i++) {
            draft_inputs[i]->draft_kv_cache_idx = new_indices[i];
        }

        if (depth_idx != 0) {
            del_kv_cache_indices.insert(
                del_kv_cache_indices.end(),
                new_indices.begin(),
                new_indices.end()
            );
        }
        TIMEIT_END(draft_create_make_inputs);

        // Call draft model via pybind11
        TIMEIT(draft_create_model);
        py::dict kwargs;
        kwargs["input_ids"] = input_ids;
        kwargs["input_features"] = input_feats;
        kwargs["attention_mask"] = attn_masks.unsqueeze(0).unsqueeze(0);
        // kwargs["attention_mask"] = attn_mask;
        kwargs["position_ids"] = position_ids;
        kwargs["past_key_values"] = draft_kvcache_pool;
        kwargs["past_key_value_indices"] = draft_kvcache_pool->allocated_indices;
        kwargs["new_past_key_value_indices"] = std::vector<int>{};
        kwargs["use_cache"] = true;
        kwargs["shift_tokens"] = false;
        kwargs["cut_last_token"] = false;
        auto result = draft_model(**kwargs);
        auto hidden = result.attr("last_hidden_state").cast<torch::Tensor>();
        hidden = hidden.slice(1, hidden.size(1) - nc, hidden.size(1));
        TIMEIT_END(draft_create_model);

        // Compute logits and combine with previous scores
        TIMEIT(draft_create_logits);
        auto logits = torch::log_softmax(
            draft_lm_head(hidden.to(lm_head_dtype)).cast<torch::Tensor>(),
            /*dim=*/-1
        );
        TIMEIT_END(draft_create_logits);

        TIMEIT(draft_create_score);
        int64_t vocab_size = logits.size(-1);
        std::vector<float> scores;
        scores.reserve(nc);
        for (int i = static_cast<int>(draft_inputs.size()) - nc; i < static_cast<int>(draft_inputs.size()); i++)
            scores.push_back(draft_inputs[i]->score);
        auto prev_scores = torch::tensor(scores, torch::TensorOptions().device(device).dtype(torch::kHalf));
        logits = logits + prev_scores.unsqueeze(-1);
        TIMEIT_END(draft_create_score);

        // Flatten and select top candidates
        TIMEIT(draft_create_sort2);
        auto flat = logits.flatten();
        auto top = torch::topk(flat, top_base);
        auto top_probs = std::get<0>(top);
        auto top_idxs = std::get<1>(top);
        TIMEIT_END(draft_create_sort2);

        TIMEIT(draft_create_extra);
        auto idxs = top_idxs / vocab_size;
        auto ids2 = top_idxs % vocab_size;
        top_probs = top_probs.to(torch::kFloat32).cpu();
        idxs = idxs.cpu();
        ids2 = ids2.cpu();
        std::vector<torch::Tensor> splitted_hidden = torch::split(hidden.index(
            {0, idxs.to(torch::kLong), torch::indexing::Slice()}
        ), 1, 0);

        // Push new candidates
#pragma omp parallel for
        for (int i = 0; i < top_probs.size(0); ++i) {
            float prob = top_probs[i].item<float>();
            int idx2 = idxs[i].item<int>();
            int token_id = ids2[i].item<int>();
            auto parent_tok = draft_inputs[static_cast<int>(draft_inputs.size()) - nc + idx2];
            auto new_tok = std::make_shared<Token>(
                base_kvcache_pool,
                draft_kvcache_pool,
                std::nullopt,
                std::nullopt,
                token_id,
                splitted_hidden[i],
                parent_tok->position_id + 1,
                parent_tok,
                /*is_fixed=*/false,
                prob,
                -1,
                std::vector<int>()
            );
#pragma omp critical
            {
                base_candidates_pq.push(Candidate{new_tok->score, new_tok->position_id, new_tok});
                draft_candidates_pq.push(Candidate{new_tok->score, new_tok->position_id, new_tok});
            }
        }
        TIMEIT_END(draft_create_extra);
    }
    // Clean up draft candidates
    draft_kvcache_pool->free(del_kv_cache_indices, std::vector<int>());
    TIMEIT_END(draft_create);
}

/**
 * Verify the draft tree with the base model and commit accepted tokens.
 *
 * Pops up to `top_base` candidates from base_candidates_pq and sorts them by
 * position_id. They are then run in one base-model forward pass whose KV set
 * is every committed ancestor's entry plus freshly allocated entries, under a
 * tree attention mask: a candidate sees its own chain of non-fixed ancestors
 * and every fixed entry up to its nearest fixed ancestor.
 *
 * Acceptance walks the sorted candidates. In greedy mode (temperature <= 0) a
 * child of the current last token is accepted when its id equals the base
 * argmax. In sampling mode it is accepted with probability p(x) (q(x) is taken
 * as 1), and p is renormalised after each rejection. Rejected KV entries are
 * freed, both queues are reset to one new leaf token, and the buffered output
 * ids are returned.
 *
 * @return Python list of token ids buffered since initialize() / the previous
 *         base_check(); the buffer is cleared.
 * @throws std::runtime_error if the base queue is empty or an internal
 *         invariant (KV indices, last_token position) is violated.
 */
py::list TreeAttentionManager::base_check() {
    std::vector<std::shared_ptr<Token>> base_inputs;
    TIMEIT(base_check);
    TIMEIT(base_check_pq);
    while (!base_candidates_pq.empty() && base_inputs.size() < static_cast<size_t>(top_base)) {
        auto cand = base_candidates_pq.top();
        base_candidates_pq.pop();
        // NOTE(review): overwrites the draft score with U(0,1); within this
        // function the score is only used as the same-position tie-breaker below.
        cand.token->score = torch::rand({1}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32)).item<float>();
        base_inputs.push_back(cand.token);
    }
    TIMEIT_END(base_check_pq);
    int nc = static_cast<int>(base_inputs.size());
    if (nc == 0) {
        // Without candidates the indexing below would fail on empty tensors.
        throw std::runtime_error("base_check: base_candidates_pq is empty (was initialize() called?)");
    }

    // Sort base_inputs by position_id
    TIMEIT(base_check_sort_inputs);
    std::sort(base_inputs.begin(), base_inputs.end(),
        [](auto& a, auto& b) { return a->position_id != b->position_id ? a->position_id < b->position_id : a->score > b->score; });
    TIMEIT_END(base_check_sort_inputs);

    // Gather all ancestors
    TIMEIT(base_check_anc);
    std::unordered_set<std::shared_ptr<Token>> all_anc;
    for (auto& tok : base_inputs) {
        for (auto p = tok->parent; p; p = p->parent) {
            // Once an ancestor is known, its whole chain is already collected.
            if (!all_anc.insert(p).second) {
                break;
            }
        }
    }
    TIMEIT_END(base_check_anc);

    // Ancestors that are not themselves being verified must already own base KV entries.
    TIMEIT(base_check_filter);
    const std::unordered_set<std::shared_ptr<Token>> input_set(base_inputs.begin(), base_inputs.end());
    std::vector<std::shared_ptr<Token>> prev_parents;
    prev_parents.reserve(all_anc.size());
    for (auto& tok : all_anc) {
        if (input_set.find(tok) == input_set.end())
            prev_parents.push_back(tok);
    }
    TIMEIT_END(base_check_filter);

    TIMEIT(base_check_sort);
    std::sort(
        prev_parents.begin(),
        prev_parents.end(),
        [](auto& a, auto& b){ return a->position_id < b->position_id; }
    );
    TIMEIT_END(base_check_sort);

    std::vector<int64_t> ids, pos;
    ids.reserve(base_inputs.size());
    pos.reserve(base_inputs.size());

    for (auto& tok : base_inputs) {
        ids.push_back(tok->input_id);
        pos.push_back(tok->position_id);
    }
    auto input_ids = torch::tensor(ids, torch::TensorOptions().dtype(torch::kInt64).device(device)).unsqueeze(0);
    auto position_ids = torch::tensor(pos, torch::TensorOptions().dtype(torch::kInt64).device(device)).unsqueeze(0);

    TIMEIT(base_check_kv);
    std::vector<int64_t> prev_indices;
    prev_indices.reserve(prev_parents.size());
    for (auto& t : prev_parents) {
        if (t->base_kv_cache_idx)
            prev_indices.push_back(*t->base_kv_cache_idx);
        else {
            throw std::runtime_error("prev_parent_indices not found");
        }
    }
    auto new_indices = base_kvcache_pool->allocate(nc);
    std::vector<int64_t> past_kv_indices = prev_indices;
    past_kv_indices.insert(past_kv_indices.end(), new_indices.begin(), new_indices.end());

    for (int i = 0; i < nc; i++) {
        base_inputs[i]->base_kv_cache_idx = new_indices[i];
    }
    TIMEIT_END(base_check_kv);

    // Build the (nc, total_len) tree mask on the host: 0 = visible, -1000 = blocked.
    // Built serially through an accessor: throwing from inside an OpenMP region
    // would call std::terminate, and per-element tensor ops are slow.
    TIMEIT(base_check_attn);
    int total_len = static_cast<int>(prev_parents.size()) + nc;
    std::unordered_map<int64_t, int64_t> column_of;  // KV index -> mask column (first occurrence)
    column_of.reserve(past_kv_indices.size());
    for (size_t c = 0; c < past_kv_indices.size(); ++c) {
        column_of.emplace(past_kv_indices[c], static_cast<int64_t>(c));
    }
    torch::Tensor mask_cpu = torch::full(
        {nc, total_len},
        -1000.0,
        torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCPU)
    );
    auto mask_a = mask_cpu.accessor<float, 2>();
    for (int i = 0; i < nc; ++i) {
        for (auto cur = base_inputs[i]; cur; cur = cur->parent) {
            if (!cur->base_kv_cache_idx) {
                throw std::runtime_error("base_kv_cache_idx not found");
            }
            auto it = column_of.find(*cur->base_kv_cache_idx);
            if (it == column_of.end()) {
                throw std::runtime_error("base_check: ancestor KV index missing from past_key_value_indices");
            }
            const int64_t col = it->second;
            if (cur->is_fixed) {
                // Fixed ancestors are committed context: everything up to and
                // including this column is visible.
                for (int64_t c = 0; c <= col; ++c) {
                    mask_a[i][c] = 0.0f;
                }
                break;
            }
            mask_a[i][col] = 0.0f;
        }
    }
    torch::Tensor attn_mask = mask_cpu.view({1, 1, nc, total_len}).to(device, base_dtype);
    TIMEIT_END(base_check_attn);

    TIMEIT(base_check_model);
    py::object outputs = base_model(
        py::arg("input_ids")=input_ids,
        py::arg("attention_mask")=attn_mask,
        py::arg("position_ids")=position_ids,
        py::arg("past_key_values")=base_kvcache_pool,
        py::arg("past_key_value_indices")=past_kv_indices,
        py::arg("new_past_key_value_indices")=py::list(),
        py::arg("use_cache")=true,
        py::arg("shift_tokens")=false,
        py::arg("cut_last_token")=false
    );
    auto hidden = outputs.attr("last_hidden_state").cast<torch::Tensor>();
    TIMEIT_END(base_check_model);

    TIMEIT(base_check_logits);
    torch::Tensor logits = lm_head(hidden.to(lm_head_dtype)).cast<torch::Tensor>();
    TIMEIT_END(base_check_logits);

    TIMEIT(base_check_update);
    int last_token_id = -1;
    torch::Tensor last_hidden_state;
    std::vector<int> del_kv_cache_indices;
    std::vector<int> res_kv_cache_indices;
    if (temperature > 0) {
        logits = logits / temperature;
        auto probs = torch::softmax(logits, -1).cpu();

        auto last_idx = std::distance(base_inputs.begin(), std::find(base_inputs.begin(), base_inputs.end(), last_token));
        if (last_idx != 0) {
            throw std::runtime_error("last_idx != 0");
        }
        auto last_prob = probs.index({0, last_idx, torch::indexing::Slice()});
        last_hidden_state = hidden.select(1, last_idx);
        res_kv_cache_indices.push_back(*base_inputs[last_idx]->base_kv_cache_idx);

        for (int idx = 1; idx < nc; idx++)  {
            auto& cand = base_inputs[idx];
            if (cand->parent == last_token) {
                float px = last_prob.index({cand->input_id}).item<float>();
                float qx = 1.0f;
                float acp = px / qx;
                // uniform sampling U(0, 1)
                auto u = torch::rand({1}, torch::TensorOptions().device(torch::kCPU).dtype(torch::kFloat32)).item<float>();
                if (u <= acp) {
                    // Accepted: it inherits the feature of its predecessor's output.
                    cand->hidden_state = last_hidden_state;

                    last_prob = probs.index({0, idx, torch::indexing::Slice()});
                    last_hidden_state = hidden.select(1, idx);
                    output_buffer.push_back(cand->input_id);
                    last_token = cand;

                    if (cand->draft_kv_cache_idx.has_value()) {
                        // Its draft slot was freed by draft_create(); drop the stale index.
                        cand->draft_kv_cache_idx = std::nullopt;
                    }
                    cand->is_fixed = true;
                } else {
                    // Rejected: remove x from p and renormalise.
                    last_prob.index_put_({cand->input_id}, 0.0f);
                    last_prob = last_prob / (1.0f - px);
                }
            }
            if (!cand->is_fixed) {
                del_kv_cache_indices.push_back(*cand->base_kv_cache_idx);
                cand->base_kv_cache_idx = std::nullopt;
            } else {
                res_kv_cache_indices.push_back(*cand->base_kv_cache_idx);
            }
        }
        last_token_id = torch::multinomial(last_prob, 1).item<int>();
        output_buffer.push_back(last_token_id);
    } else {
        auto next_ids = torch::argmax(logits, -1);
        next_ids = next_ids.flatten().cpu();

        int last_idx = std::distance(base_inputs.begin(), std::find(base_inputs.begin(), base_inputs.end(), last_token));
        if (last_idx != 0) {
            throw std::runtime_error("last_idx != 0");
        }
        last_token_id = next_ids[last_idx].item<int>();
        last_hidden_state = hidden.select(1, last_idx);
        res_kv_cache_indices.push_back(*base_inputs[last_idx]->base_kv_cache_idx);

        output_buffer.push_back(last_token_id);
        for (int idx = 1; idx < nc; idx++) {
            int output_id = next_ids[idx].item<int>();
            auto& cand = base_inputs[idx];
            if (last_token_id == cand->input_id && cand->parent == last_token) {
                // Accepted: it inherits the feature of its predecessor's output.
                cand->hidden_state = last_hidden_state;

                last_token_id = output_id;
                last_hidden_state = hidden.select(1, idx);
                output_buffer.push_back(last_token_id);
                last_token = cand;

                if (cand->draft_kv_cache_idx.has_value()) {
                    // Its draft slot was freed by draft_create(); drop the stale index.
                    cand->draft_kv_cache_idx = std::nullopt;
                }
                cand->is_fixed = true;
            }
            if (!cand->is_fixed) {
                del_kv_cache_indices.push_back(*cand->base_kv_cache_idx);
                cand->base_kv_cache_idx = std::nullopt;
            } else {
                res_kv_cache_indices.push_back(*cand->base_kv_cache_idx);
            }
        }
    }
    TIMEIT_END(base_check_update);

    if (static_cast<int>(del_kv_cache_indices.size() + res_kv_cache_indices.size()) != nc) {
        throw std::runtime_error("del_kv_cache_indices.size() + res_kv_cache_indices.size() != nc");
    }
    base_kvcache_pool->free(del_kv_cache_indices, res_kv_cache_indices);

    TIMEIT(base_check_clear_pq);
    while (!base_candidates_pq.empty()) base_candidates_pq.pop();
    while (!draft_candidates_pq.empty()) draft_candidates_pq.pop();
    TIMEIT_END(base_check_clear_pq);

    // The bonus token becomes the new (uncommitted) root of the next draft tree.
    TIMEIT(base_check_extra);
    auto new_last = std::make_shared<Token>(
        base_kvcache_pool, draft_kvcache_pool, std::nullopt, std::nullopt,
        last_token_id, last_hidden_state, last_token->position_id + 1, last_token, false, 0.0f,
        -1, // tag
        std::vector<int>() // draft_attention_mask
    );
    last_token = new_last;
    base_candidates_pq.push(Candidate{new_last->score, new_last->position_id, new_last});
    draft_candidates_pq.push(Candidate{new_last->score, new_last->position_id, new_last});

    py::list output_ids;
    for (auto id : output_buffer)
        output_ids.append(id);

    output_buffer.clear();
    TIMEIT_END(base_check_extra);
    TIMEIT_END(base_check);

    return output_ids;
}
