#pragma once

#include "police/marabou_preprocessor.hpp"
#include "police/nnlp.hpp"
#include "police/nnlp_factory.hpp"
#include "police/smt_factory.hpp"

#include <memory>

namespace police {

/// NNLP backend that delegates solving to an SMT solver.
///
/// The SMT instance handed to the constructor is owned by this object
/// (held in `smt_` as a std::unique_ptr). All public operations override
/// NNLPBase; their definitions are out of line.
/// NOTE(review): definitions presumably live in the matching .cpp — not
/// visible from this header.
class NNLPSMT final : public NNLPBase {
public:
    /// @param smt SMT solver this backend drives; ownership is transferred.
    explicit NNLPSMT(std::unique_ptr<SMT> smt);

    /// Adds a linear constraint to the problem.
    void add_constraint(const linear_constraint_type& constraint) override;

    /// Adds a ReLU constraint to the problem.
    void add_constraint(const relu_constraint_type& constraint) override;

    /// Adds a max constraint to the problem.
    void add_constraint(const max_constraint_type& constraint) override;

    /// Adds a disjunction of linear constraints to the problem.
    void add_constraint(
        const linear_constraint_disjunction_type& constraint) override;

    /// Returns the model (variable assignment) found by the solver.
    /// NOTE(review): presumably only meaningful after a satisfiable
    /// solve — confirm against the .cpp.
    [[nodiscard]]
    model_type get_model() const override;

    /// Dumps the backend state for debugging (output target is defined
    /// out of line and not visible here).
    void dump() const override;

    /// Reports whether get_unsolvable_core() may be called on this backend.
    bool supports_unsolvable_core() const override;

    /// Returns a conjunction of linear constraints describing why the last
    /// solve was unsolvable.
    /// NOTE(review): assumed to require a preceding unsolvable result and
    /// supports_unsolvable_core() == true — confirm.
    LinearConstraintConjunction get_unsolvable_core() const override;

private:
    // Hooks overriding NNLPBase; the do_ prefix suggests NNLPBase's public
    // wrappers dispatch to them (not visible here — confirm in nnlp.hpp).
    void do_push_snapshot() override;
    void do_pop_snapshot() override;
    void do_clear() override;
    void do_add_variable(const VariableType& var_type) override;
    void do_set_variable_upper_bound(size_t var_ref, real_t ub) override;
    void do_set_variable_lower_bound(size_t var_ref, real_t lb) override;
    [[nodiscard]]
    Status do_solve() override;
    // Solve variant that additionally takes `ass_constraints`.
    // NOTE(review): presumably these are temporary assumptions that are not
    // kept after the call — confirm.
    [[nodiscard]]
    Status
    do_solve(const vector<linear_constraint_type>& ass_constraints) override;
    // Helpers that push/pull variable bounds into/out of the SMT solver.
    // They are const yet may mutate solver state: the SMT object is reached
    // through the unique_ptr (pointer constness only), and needs_a_clean_
    // is mutable.
    void add_bounds_to_smt() const;
    void remove_bounds_from_smt() const;

    // Owned SMT solver instance.
    std::unique_ptr<SMT> smt_;
    // Mutable so const members (e.g. the bound helpers above) can update it.
    // NOTE(review): presumably marks that bounds pushed into the SMT must be
    // removed before the next operation — confirm in the .cpp.
    mutable bool needs_a_clean_ = false;
};

// class NNLPSMTWithPreprocessing final : public NNLPMarabouHelper {
// public:
//     explicit NNLPSMTWithPreprocessing(std::unique_ptr<SMT> smt);
//
//     [[nodiscard]]
//     model_type get_model() const override;
//
// private:
//     void do_add_constraint(const linear_constraint_type& constraint) override;
//     void do_add_constraint(const relu_constraint_type& constraint) override;
//     void do_add_constraint(const max_constraint_type& constraint) override;
//     void do_add_constraint(
//         const linear_constraint_disjunction_type& constraint) override;
//
//     void do_push_snapshot() override;
//     void do_pop_snapshot() override;
//     void do_clear() override;
//     void do_add_variable(const VariableType& var_type) override;
//
//     [[nodiscard]]
//     Status do_solve() override;
//
//     [[nodiscard]]
//     Status
//     do_solve(const vector<linear_constraint_type>& ass_constraints) override;
//
//     std::unique_ptr<SMT> smt_;
// };

/// Factory producing SMT-backed NNLP solvers.
///
/// Each make() call asks the wrapped SMTFactory for a fresh SMT instance and
/// wraps it in an NNLPSMT. When preprocessing is enabled, that NNLPSMT is in
/// turn wrapped in an NNLPMarabouPreprocessor (shared via std::shared_ptr).
///
/// The SMTFactory is referenced through a non-owning pointer that make()
/// dereferences: it must be non-null and outlive every make() call.
class NNLPSMTFactory : public NNLPFactory {
public:
    /// @param factory    Source of SMT instances; not owned.
    /// @param preprocess If true, results are wrapped in an
    ///                   NNLPMarabouPreprocessor.
    NNLPSMTFactory(const SMTFactory* factory, bool preprocess)
        : factory_{factory}
        , preprocess_{preprocess}
    {
    }

    /// Creates a new solver instance with `new`.
    /// NOTE(review): the raw pointer return is dictated by NNLPFactory;
    /// presumably the caller takes ownership — confirm at call sites.
    NNLP* make() const override
    {
        // Plain backend: no preprocessing layer requested.
        if (!preprocess_) {
            return new NNLPSMT(factory_->make_unique());
        }
        // Preprocessing requested: the SMT backend becomes the inner solver
        // of the Marabou preprocessor.
        return new NNLPMarabouPreprocessor(
            std::make_shared<NNLPSMT>(factory_->make_unique()));
    }

private:
    // Non-owning; see class comment for lifetime requirements.
    const SMTFactory* factory_;
    // Whether make() wraps results in NNLPMarabouPreprocessor.
    bool preprocess_;
};

} // namespace police
