#ifndef TREE_ATTENTION_MANAGER_HPP
#define TREE_ATTENTION_MANAGER_HPP

#include <torch/torch.h>
#include <torch/python.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <algorithm>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <queue>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "kv_cache.hpp"
#include "token_module.hpp"

// Short alias for the pybind11 namespace; the declarations below spell
// py::object and py::list with it.
namespace py = pybind11;
// NOTE(review): a using-directive at header scope exposes every std::chrono
// name to all files that include this header. Nothing visible in this header
// relies on it (get_ticks() qualifies std::chrono:: explicitly), but code in
// other files may -- confirm before removing.
using namespace std::chrono;


/// Returns the current reading of a monotonic clock, expressed in nanoseconds.
///
/// The TIMEIT macros subtract two readings into an unsigned value, and
/// print_times() divides the accumulated total by 1000000 before printing it
/// with an "ms" suffix. Both uses need a clock that never goes backwards and a
/// unit that is nanoseconds on every platform, so the reading is taken from
/// std::chrono::steady_clock and converted explicitly instead of relying on
/// the implementation-defined tick period of high_resolution_clock.
///
/// @return Nanoseconds since steady_clock's (unspecified) epoch. Only
///         differences between two readings are meaningful.
inline uint64_t get_ticks() {
    const auto since_epoch = std::chrono::steady_clock::now().time_since_epoch();
    return static_cast<uint64_t>(
        std::chrono::duration_cast<std::chrono::nanoseconds>(since_epoch).count());
}

// Region-timing helpers. Every macro expects two variables to be visible at
// the expansion site, shaped as TreeAttentionManager declares them:
//   time_events : std::vector<std::pair<const char*, uint64_t>>
//   times       : std::map<std::string, uint64_t>
// and a callable get_ticks() returning uint64_t.

// TIMEIT(x): opens a timed region labelled `x`. Each expansion site gets its
// own slot in time_events, chosen by __COUNTER__. The slot is created on
// demand, so the macro is safe when time_events is empty (before TIMEIT_INIT
// or after TIMEIT_FINI) and when other users of __COUNTER__ in the translation
// unit push the index past the preallocated size. The start reading is taken
// last so the bookkeeping above it is not counted in the region.
// Declares the locals index_x and time_start_x in the enclosing scope.
#define TIMEIT(x) \
    size_t index_##x = __COUNTER__; \
    if (time_events.size() <= index_##x) time_events.resize(index_##x + 1); \
    time_events[index_##x].first = #x; \
    uint64_t time_start_##x = get_ticks()
// TIMEIT_END(x): closes the region opened by TIMEIT(x) in the same scope and
// adds the elapsed get_ticks() difference to the slot, so a region inside a
// loop accumulates across iterations.
#define TIMEIT_END(x) \
    time_events[index_##x].second += get_ticks() - time_start_##x
// TIMEIT_INIT: discards all pending samples and preallocates 100 empty slots.
#define TIMEIT_INIT \
    time_events.clear(); \
    time_events.resize(100)
// TIMEIT_FINI: copies every slot that has a label into `times` (overwriting
// any existing entry with the same label), then empties time_events.
#define TIMEIT_FINI \
    for (size_t timeit_slot_ = 0; timeit_slot_ < time_events.size(); ++timeit_slot_) { \
        if (time_events[timeit_slot_].first == nullptr) continue; \
        times[time_events[timeit_slot_].first] = time_events[timeit_slot_].second; \
    } \
    time_events.clear()

/// One entry of a candidate priority queue: a token together with the two keys
/// CandidateComparator orders by.
///
/// The scalar members carry default initializers so a default-constructed
/// Candidate never holds indeterminate values. The type remains an aggregate
/// under C++14 and later, so `Candidate{score, position_id, token}` keeps
/// working.
struct Candidate {
    /// Primary ordering key; with CandidateComparator the larger score is
    /// served first by std::priority_queue.
    double score = 0.0;
    /// Tie-breaker for equal scores; the smaller position_id is served first.
    int position_id = 0;
    /// Token this candidate refers to (shared ownership); null by default.
    std::shared_ptr<Token> token;
};

/// Ordering predicate for the candidate priority queues.
///
/// Returns true when `a` ranks below `b`: a lower score ranks lower, and when
/// the scores do not differ the candidate with the larger position_id ranks
/// lower. Used with std::priority_queue this puts the highest score on top,
/// with the smallest position_id winning ties.
struct CandidateComparator {
    bool operator()(const Candidate& a, const Candidate& b) const {
        // Written as the negation of != so the result is the same for every
        // pair of doubles, including values that compare unordered.
        const bool scores_tie = !(a.score != b.score);
        return scores_tie ? (a.position_id > b.position_id)
                          : (a.score < b.score);
    }
};

/// Bundles the Python model handles, tuning parameters, KV-cache pools,
/// candidate queues and timing counters declared below.
///
/// Only print_times() is defined in this header; every other member function
/// is declared here and defined elsewhere. All members are public.
class TreeAttentionManager {
public:
    /// Constructs the manager from four Python objects and six tuning values.
    /// The parameter names match data members of this class; how each one is
    /// consumed is decided by the out-of-line definition (not visible here).
    TreeAttentionManager(py::object base_model,
                         py::object draft_model,
                         py::object lm_head,
                         py::object draft_lm_head,
                         double temperature = 0.0,
                         double threshold = 0.0,
                         int top_draft = 4,
                         int top_node = 32,
                         int depth = 8,
                         int top_base = 16);

    // Queue accessors. Presumably thin wrappers over base_candidates_pq and
    // draft_candidates_pq -- confirm against the definitions.
    bool is_draft_candidates_pq_empty();
    bool is_base_candidates_pq_empty();
    std::shared_ptr<Token> pop_base_candidates_pq();
    std::shared_ptr<Token> pop_draft_candidates_pq();
    void push_base_candidates_pq(std::shared_ptr<Token> token);
    void push_draft_candidates_pq(std::shared_ptr<Token> token);

    // Takes two tensors by value; behaviour is defined out of line.
    void initialize(torch::Tensor input_ids_,
                    torch::Tensor attention_mask_);

    // Declared only; behaviour is defined out of line.
    void draft_create();

    // Declared only; returns a Python list built by the out-of-line definition.
    py::list base_check();

    // ---- State (declaration order is kept: it fixes initialization order) ----

    std::shared_ptr<Token> last_token;

    // Python-side objects handed to the constructor.
    py::object base_model;
    py::object draft_model;
    py::object lm_head;
    py::object draft_lm_head;

    // Tuning values; same names as the constructor parameters.
    double temperature;
    double threshold;
    int top_base;
    int top_draft;
    int top_node;
    int depth;

    // Torch device and element types.
    torch::Device device;
    torch::Dtype base_dtype;
    torch::Dtype draft_dtype;
    torch::Dtype lm_head_dtype;

    // Model / cache dimensions. NOTE(review): presumably read from the model
    // configuration -- confirm in the constructor definition.
    int num_hidden_layers;
    int draft_num_hidden_layers;
    int num_key_value_heads;
    int head_dim;
    int max_cache_len;

    // KV-cache objects (shared ownership), one per model.
    std::shared_ptr<KVCache> base_kvcache_pool;
    std::shared_ptr<KVCache> draft_kvcache_pool;

    // Candidate heaps ordered by CandidateComparator.
    std::priority_queue<Candidate, std::vector<Candidate>, CandidateComparator> base_candidates_pq;
    std::priority_queue<Candidate, std::vector<Candidate>, CandidateComparator> draft_candidates_pq;

    std::vector<std::shared_ptr<Token>> input_tokens;

    std::vector<int> output_buffer;

    // Timing storage used by the TIMEIT* macros: per-site (label, accumulated
    // ticks) slots, and the label -> ticks table filled by TIMEIT_FINI.
    std::vector<std::pair<const char*, uint64_t>> time_events;
    std::map<std::string, uint64_t> times;

    // Moves pending TIMEIT samples into `times` (TIMEIT_FINI, which also
    // empties time_events), then writes one line per entry of `times` to
    // std::cout: the label, ": ", a tab, the stored value divided by 1000000
    // (integer division), and "ms".
    void print_times() {
        TIMEIT_FINI;
        for (auto it = times.begin(); it != times.end(); ++it) {
            const uint64_t scaled = it->second / 1000000;
            std::cout << it->first << ": " << "\t" << scaled << "ms" << std::endl;
        }
    }
};

#endif // TREE_ATTENTION_MANAGER_HPP