#include "police/addtree_policy.hpp"
#include "police/addtree.hpp"
#include "police/base_types.hpp"
#include "police/utils/io.hpp"

#if POLICE_VERITAS

#include <algorithm>
#include <cassert>
#include <functional>
#include <veritas/addtree.hpp>
#include <veritas/basics.hpp>

namespace police::details {

/// Wraps a Veritas additive tree ensemble so it can be evaluated on
/// police feature vectors.
/// @param addtree ensemble to evaluate; stored as-is (no null check is
///        performed, operator() dereferences it unconditionally).
AddTreeEvaluator::AddTreeEvaluator(std::shared_ptr<AddTree> addtree)
    : addtree_(std::move(addtree))
{
}

/// Evaluates the ensemble on a single input row.
///
/// The features are converted to DecisionTreeValueType, handed to
/// AddTree::eval as one row, and the num_leaf_values() outputs written by
/// eval are converted back to real_t.
/// @param features input row. NOTE(review): presumably one entry per input
///        variable of the owning policy; must cover every feature id used
///        by the trees — confirm with callers.
/// @return vector with num_leaf_values() entries holding the ensemble output.
vector<real_t>
AddTreeEvaluator::operator()(const vector<real_t>& features) const
{
    // Veritas operates on its own value type, so copy-convert the row.
    vector<DecisionTreeValueType> in(features.begin(), features.end());
    vector<DecisionTreeValueType> out(addtree_->num_leaf_values());
    veritas::data<DecisionTreeValueType> out_data(out);
    addtree_->eval(veritas::data<DecisionTreeValueType>(in), out_data);
    // Element-wise conversion back to real_t via the range constructor; no
    // signed/unsigned index arithmetic needed.
    return vector<real_t>(out.begin(), out.end());
}

} // namespace police::details

#else

namespace police {
namespace details {

// Stand-in definitions compiled when police is built without Veritas
// (POLICE_VERITAS is false). Every entry point simply reports the missing
// dependency through POLICE_MISSING_DEPENDENCY("Veritas").

/// Unavailable without Veritas; reports the missing dependency.
AddTreeEvaluator::AddTreeEvaluator(std::shared_ptr<addtree_type> /*addtree*/)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

/// Unavailable without Veritas; reports the missing dependency.
// NOTE(review): there is no return statement, so this relies on
// POLICE_MISSING_DEPENDENCY never returning (throw/abort) — confirm
// against the macro's definition.
vector<real_t>
AddTreeEvaluator::operator()(const vector<real_t>& /*features*/) const
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

} // namespace details
} // namespace police

#endif

namespace police {

/// Builds a value-based policy whose values come from a Veritas tree
/// ensemble wrapped in a details::AddTreeEvaluator.
///
/// @param addtree        ensemble moved into the evaluator.
/// @param input_vars     forwarded to the VBasedPolicy base.
/// @param output_actions forwarded to the VBasedPolicy base.
/// @param num_actions    forwarded to the VBasedPolicy base.
///
/// When built with Veritas and with assertions enabled, the constructor
/// additionally checks that
///  - every split of every tree uses a feature id that indexes into the
///    policy's input vector, and
///  - the ensemble produces exactly one leaf value per policy output.
AddTreePolicy::AddTreePolicy(
    std::shared_ptr<AddTree> addtree,
    vector<size_t> input_vars,
    vector<size_t> output_actions,
    size_t num_actions)
    : VBasedPolicy<details::AddTreeEvaluator>(
          details::AddTreeEvaluator(std::move(addtree)),
          std::move(input_vars),
          std::move(output_actions),
          num_actions)
{
#if POLICE_VERITAS && !defined(NDEBUG)
    // The traversal below only feeds assert(), so it is skipped entirely
    // when assertions are compiled out.
    const size_t num_inputs = this->get_input().size();
    const auto& ensemble = this->get_addtree();
    for (const DecisionTree& tree : *ensemble) {
        // Iterative depth-first walk; each internal node's split feature is
        // checked directly. This is well defined for trees that consist of a
        // single leaf (no splits), unlike taking the maximum of an empty
        // set of collected ids.
        vector<veritas::NodeId> pending;
        pending.push_back(tree.root());
        while (!pending.empty()) {
            const veritas::NodeId node = pending.back();
            pending.pop_back();
            if (tree.is_leaf(node)) {
                continue;
            }
            // A negative feat_id wraps to a huge size_t and fails the check.
            assert(
                static_cast<size_t>(tree.get_split(node).feat_id) <
                num_inputs);
            pending.push_back(tree.left(node));
            pending.push_back(tree.right(node));
        }
    }
    assert(
        static_cast<size_t>(get_addtree()->num_leaf_values()) ==
        this->get_output().size());
#endif
}

} // namespace police
