#include "police/var_order_addtree.hpp"

#if not(POLICE_VERITAS)

#include "macros.hpp"

namespace police {
// Fallback constructor compiled when Veritas support is disabled
// (POLICE_VERITAS is false). It has the same signature as the real
// implementation so callers link unchanged. The arguments are ignored and
// control goes straight to POLICE_MISSING_DEPENDENCY (from macros.hpp).
// NOTE(review): presumably that macro aborts/throws; confirm in macros.hpp.
AddTreeOccurrencesOrder::AddTreeOccurrencesOrder(
    size_t /* num_vars */,
    const addtree_type& /* add_tree */,
    const vector<size_t>& /* feature_to_var */)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

} // namespace police

#else

#include "addtree_policy.hpp"
#include "arguments.hpp"
#include "model.hpp"
#include "option.hpp"
#include "storage/vector.hpp"

#include <cassert>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include <veritas/addtree.hpp>
#include <veritas/basics.hpp>

namespace police {

// Builds a variable order from how often each model variable is tested by
// the split nodes of a tree ensemble.
//
// Every variable starts at std::numeric_limits<size_t>::max(). It is
// decremented once for each internal (non-leaf) node, across all trees of
// `add_tree`, whose split feature maps to that variable. A variable tested
// more often therefore ends up with a smaller value, and variables that are
// never tested keep the maximum. The resulting vector is handed to
// set_variable_ranks().
// NOTE(review): presumably a smaller rank means "ordered earlier"; confirm
// against VariableOrderChooser.
//
// num_vars        number of model variables; size of the rank vector.
// add_tree        the tree ensemble whose split nodes are counted.
// feature_to_var  maps each Veritas feature id to a model variable index.
AddTreeOccurrencesOrder::AddTreeOccurrencesOrder(
    size_t num_vars,
    const addtree_type& add_tree,
    const vector<size_t>& feature_to_var)
{
    vector<size_t> occurrences(num_vars, std::numeric_limits<size_t>::max());
    // Use an explicit work-list instead of a recursive std::function. This
    // avoids a type-erased call per node and keeps native stack usage
    // independent of tree depth. Visiting order does not matter because only
    // per-variable counts are collected.
    std::vector<veritas::NodeId> pending;
    for (size_t i = 0; i < add_tree.size(); ++i) {
        const tree_type& tree = add_tree[i];
        pending.push_back(tree.root());
        while (!pending.empty()) {
            const veritas::NodeId node_id = pending.back();
            pending.pop_back();
            if (tree.is_leaf(node_id)) {
                continue;
            }
            const auto& split = tree.get_split(node_id);
            // feat_id is a signed id used directly as an index, so check
            // both bounds.
            assert(split.feat_id >= 0);
            assert(split.feat_id < static_cast<int>(feature_to_var.size()));
            assert(feature_to_var[split.feat_id] < occurrences.size());
            --occurrences[feature_to_var[split.feat_id]];
            pending.push_back(tree.left(node_id));
            pending.push_back(tree.right(node_id));
        }
    }
    set_variable_ranks(std::move(occurrences));
}

namespace {

// Registers the "tree_occurrences" variable-ordering option. The factory
// requires a tree-ensemble (addtree) policy. It builds an
// AddTreeOccurrencesOrder from three inputs: the model's variable count, the
// policy's ensemble, and the policy's feature-to-variable input mapping.
PointerOption<VariableOrderChooser>
    _option("tree_occurrences", [](const Arguments& args) {
        if (!args.has_addtree_policy()) {
            // Name the option exactly as registered above so the message
            // points users at the right option.
            POLICE_INTERNAL_ERROR(
                "tree_occurrences variable ordering can only "
                "be used with tree-ensemble policies");
        }
        const auto& policy = args.get_addtree_policy();
        const auto& model = args.get_model();
        return std::make_shared<AddTreeOccurrencesOrder>(
            model.variables.size(),
            *policy.get_addtree(),
            policy.get_input());
    });

} // namespace

} // namespace police

#endif
