#pragma once

#include "police/addtree.hpp"
#include "police/base_types.hpp"
#include "police/storage/vector.hpp"
#include "police/vbased_policy.hpp"

#include <cassert>
#include <memory>

namespace police {

namespace details {
class AddTreeEvaluator {
public:
    explicit AddTreeEvaluator(std::shared_ptr<AddTree> addtree);

    [[nodiscard]]
    vector<real_t> operator()(const vector<real_t>& features) const;

    [[nodiscard]]
    const std::shared_ptr<AddTree>& get_addtree() const
    {
        return addtree_;
    }

private:
    std::shared_ptr<AddTree> addtree_;
};
} // namespace details

/// Value-based policy whose value function is an AddTree model, evaluated
/// through details::AddTreeEvaluator.
class AddTreePolicy : public VBasedPolicy<details::AddTreeEvaluator> {
public:
    /// Builds the policy around @p addtree (defined out of line).
    /// @param addtree         model used as the value function.
    /// @param input_vars      presumably the variable indices fed to the model
    ///                        as features — TODO confirm in the implementation.
    /// @param output_actions  presumably maps model outputs to action ids —
    ///                        TODO confirm in the implementation.
    /// @param num_actions     presumably the total number of actions — TODO
    ///                        confirm in the implementation.
    AddTreePolicy(
        std::shared_ptr<AddTree> addtree,
        vector<size_t> input_vars,
        vector<size_t> output_actions,
        size_t num_actions);

    /// @return the AddTree held by this policy's value function, as a
    ///         reference to the evaluator's stored shared pointer.
    [[nodiscard]] auto get_addtree() const -> const std::shared_ptr<AddTree>&
    {
        const auto& evaluator = get_value_function();
        return evaluator.get_addtree();
    }
};

} // namespace police
