#ifndef TOKEN_HPP
#define TOKEN_HPP

#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include <memory>
#include <optional>
#include <utility>
#include <vector>
#include <torch/torch.h>

#include "kv_cache.hpp"

namespace py = pybind11;

/// Plain record describing one token together with handles to two KV caches.
///
/// All members are public and are set once by the constructor; the class
/// itself enforces no invariants. Ordering (`<`, `<=`) looks only at `score`.
///
/// No destructor is declared on purpose (Rule of Zero): a user-declared
/// destructor would suppress the implicitly generated move constructor and
/// move assignment, turning every move of a Token into a full copy.
class Token {
public:
    /// Shared handle to a KVCache (presumably the base model's cache — confirm).
    std::shared_ptr<KVCache> base_kv_cache;
    /// Shared handle to a KVCache (presumably the draft model's cache — confirm).
    std::shared_ptr<KVCache> draft_kv_cache;

    /// Optional index associated with `base_kv_cache`; may be empty.
    std::optional<int> base_kv_cache_idx;
    /// Optional index associated with `draft_kv_cache`; may be empty.
    std::optional<int> draft_kv_cache_idx;

    /// Integer id of this token (presumably a vocabulary id — confirm).
    int input_id;
    /// Tensor stored alongside the token; shape/dtype are not fixed here.
    torch::Tensor hidden_state;
    /// Integer position of this token (presumably its sequence position — confirm).
    int position_id;

    /// Owning link to another Token; may be null. Because the link is owning,
    /// a Token keeps its whole chain of parents alive.
    std::shared_ptr<Token> parent;

    /// Caller-defined flag; this class never reads it.
    bool is_fixed;
    /// Value used by the comparison operators below.
    float score;

    /// Caller-defined integer label; this class never reads it.
    int tag;

    /// Integer mask values stored with the token (layout defined by callers).
    std::vector<int> draft_attention_mask;

    /// Builds a Token from one value per member, in member declaration order.
    ///
    /// Arguments are taken by value and moved into the members, so passing
    /// an rvalue costs no copy and passing an lvalue costs exactly one.
    Token(
        std::shared_ptr<KVCache> base_kv_cache,
        std::shared_ptr<KVCache> draft_kv_cache,
        std::optional<int> base_kv_cache_idx,
        std::optional<int> draft_kv_cache_idx,
        int input_id,
        torch::Tensor hidden_state,
        int position_id,
        std::shared_ptr<Token> parent,
        bool is_fixed,
        float score,
        int tag,
        std::vector<int> draft_attention_mask
    ) : // Inside each initializer the name refers to the constructor
        // parameter (it shadows the member), so std::move(x) moves the
        // argument into the member of the same name.
        base_kv_cache(std::move(base_kv_cache)),
        draft_kv_cache(std::move(draft_kv_cache)),
        base_kv_cache_idx(base_kv_cache_idx),
        draft_kv_cache_idx(draft_kv_cache_idx),
        input_id(input_id),
        hidden_state(std::move(hidden_state)),
        position_id(position_id),
        parent(std::move(parent)),
        is_fixed(is_fixed),
        score(score),
        tag(tag),
        draft_attention_mask(std::move(draft_attention_mask))
    {}

    /// @return true when this token's score is strictly less than `other`'s.
    bool operator<(const Token& other) const {
        return score < other.score;
    }

    /// @return true when this token's score is less than or equal to `other`'s.
    bool operator<=(const Token& other) const {
        return score <= other.score;
    }
};

#endif // TOKEN_HPP