#pragma once

#include "police/expressions/boolean_combination.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/variable.hpp"
#include "police/smt.hpp"

#include <cstddef>
#include <ranges>
#include <type_traits>
#include <utility>

namespace police {

/// A callable that can receive each model produced by SMTModelEnumerator.
///
/// SMTModelEnumerator invokes the callback as a (possibly non-const) lvalue
/// and passes the model as a `const SMT::model_type&`. This concept checks
/// exactly that call shape. A callback accepted here will therefore not fail
/// to compile inside the enumerator; for example, one taking only
/// `SMT::model_type&&` is now rejected up front.
template <typename Callback>
concept smt_model_callback =
    requires(Callback& c, const SMT::model_type& model) { c(model); };

class SMTModelEnumerator {
private:
    template <smt_model_callback Callback>
    bool
    invoke(std::true_type, Callback&& callback, const SMT::model_type& model)
        const
    {
        return callback(model);
    }

    template <smt_model_callback Callback>
    bool invoke(Callback&& callback, const SMT::model_type& model) const
    {
        if constexpr (std::is_convertible_v<
                          std::
                              invoke_result_t<Callback, const SMT::model_type&>,
                          bool>) {
            return callback(model);
        } else {
            callback(model);
            return false;
        }
    }

public:
    template <smt_model_callback Callback, std::ranges::range R>
    size_t operator()(SMT* smt, Callback callback, R&& vars) const
    {
        size_t num = 0;
        smt->push_snapshot();
        auto result = smt->solve();
        while (result == SMT::Status::SAT) {
            ++num;
            const auto model = smt->get_model();
            if (invoke(callback, model)) {
                break;
            }
            expressions::Disjunction excl({});
            for (auto i = vars.begin(); i != vars.end(); ++i) {
                excl.children.push_back(expressions::not_equal(
                    expressions::Variable(*i),
                    model.get_value(*i)));
            }
            smt->add_constraint(excl);
            result = smt->solve();
        }
        smt->pop_snapshot();
        return num;
    }

    template <smt_model_callback Callback>
    size_t operator()(SMT* smt, Callback callback) const
    {
        return operator()(
            smt,
            std::move(callback),
            std::ranges::iota_view{0u, smt->num_variables()});
    }
};

} // namespace police
