#pragma once

#include "police/compute_graph.hpp"
#include "police/vbased_policy.hpp"

#include <memory>

namespace police {

namespace internal {
class CGEvaluator {
public:
    explicit CGEvaluator(std::shared_ptr<cg::Node> node);

    [[nodiscard]]
    vector<real_t> operator()(const vector<real_t>& input) const;

    [[nodiscard]]
    const std::shared_ptr<cg::Node>& get_compute_graph() const
    {
        return node_;
    }

private:
    std::shared_ptr<cg::Node> node_;
};
} // namespace internal

/// Policy whose value function is an internal::CGEvaluator wrapping a
/// compute graph. All policy behaviour comes from the VBasedPolicy base.
class CGPolicy final : public VBasedPolicy<internal::CGEvaluator> {
public:
    /// Wraps `vfunc` in an internal::CGEvaluator and forwards it, together
    /// with the remaining arguments, to the VBasedPolicy constructor.
    ///
    /// @param vfunc          compute-graph node used as the value function
    /// @param input_vars     forwarded unchanged to VBasedPolicy
    /// @param output_actions forwarded unchanged to VBasedPolicy
    /// @param num_actions    forwarded unchanged to VBasedPolicy
    CGPolicy(
        std::shared_ptr<cg::Node> vfunc,
        vector<size_t> input_vars,
        vector<size_t> output_actions,
        size_t num_actions)
        : VBasedPolicy<internal::CGEvaluator>(
              internal::CGEvaluator{std::move(vfunc)},
              std::move(input_vars),
              std::move(output_actions),
              num_actions)
    {
    }

    /// Returns the compute graph held by this policy's value function.
    [[nodiscard]] auto get_compute_graph() const
        -> const std::shared_ptr<cg::Node>&
    {
        const auto& evaluator = get_value_function();
        return evaluator.get_compute_graph();
    }
};

} // namespace police
