#pragma once

#include "police/nnlp.hpp"
#include "police/nnlp_factory.hpp"

#if POLICE_MARABOU
#include <memory>
#endif

namespace police {

class MarabouLP final : public NNLPBase {
public:
    struct PreprocessedBounds {
        [[nodiscard]]
        bool has_lb() const
        {
            return lb != NNLP::NO_LB;
        }

        [[nodiscard]]
        bool has_ub() const
        {
            return ub != NNLP::NO_UB;
        }

        real_t lb = NNLP::NO_LB;
        real_t ub = NNLP::NO_UB;
    };

    struct PreprocessorResult {
        vector<PreprocessedBounds> bounds;
        bool infeasible = false;
    };

    MarabouLP();

    void set_input_index(size_t var_id, size_t index) override;

    virtual void set_output_index(size_t var_id, size_t index) override;

    void add_constraint(const linear_constraint_type& constraint) override;

    void add_constraint(const relu_constraint_type& constraint) override;

    void add_constraint(const max_constraint_type& constraint) override;

    void add_constraint(
        const linear_constraint_disjunction_type& constraint) override;

    [[nodiscard]]
    model_type get_model() const override;

    PreprocessorResult preprocess() const;

    void dump() const override;

private:
    void do_push_snapshot() override;
    void do_pop_snapshot() override;
    void do_clear() override;
    void do_add_variable(const VariableType& var_type) override;
    void do_set_variable_upper_bound(size_t var_ref, real_t ub) override;
    void do_set_variable_lower_bound(size_t var_ref, real_t lb) override;
    [[nodiscard]]
    Status do_solve() override;

#if POLICE_MARABOU
    void cleanup_marabou_query() const;
    Status solve_internal() const;

    struct MarabouInternals;

    struct MarabouInternalsDeleter {
        void operator()(MarabouInternals* internals) const;
    };

    vector<std::pair<size_t, size_t>> input_vars_;
    vector<std::pair<size_t, size_t>> output_vars_;
    std::unique_ptr<MarabouInternals, MarabouInternalsDeleter> mb_;
#endif
};

/// NNLPFactory implementation that produces MarabouLP solvers.
class MarabouLPFactory : public NNLPFactory {
public:
    /// Allocates a fresh, default-constructed MarabouLP.
    ///
    /// The object is created with `new` and not retained here, so the caller
    /// takes ownership of the returned pointer.
    [[nodiscard]]
    NNLP* make() const override
    {
        NNLP* const solver = new MarabouLP{};
        return solver;
    }
};

} // namespace police
