#pragma once

#include "police/nnlp.hpp"

#include <memory>

namespace police {

/**
 * @brief An NNLP implementation that owns and wraps another NNLP instance.
 *
 * The query methods defined inline in this header (has_integer_variable,
 * num_variables, get_variable_lower_bound, get_variable_upper_bound, dump)
 * delegate directly to the wrapped instance. The remaining overrides are
 * defined out of line. NOTE(review): presumably they forward to the wrapped
 * instance as well; confirm against the implementation file.
 *
 * The inline methods dereference the wrapped pointer unconditionally, so they
 * require this wrapper to hold a non-null NNLP.
 */
class NNLPWrapper : public NNLP {
public:
    /// Constructs the wrapper, taking exclusive ownership of @p lp.
    /// NOTE(review): whether a null @p lp is rejected is not visible in this
    /// header. The inline accessors below assume it is non-null.
    explicit NNLPWrapper(std::unique_ptr<NNLP> lp);

    // --- NNLP interface overrides (defined out of line) -------------------

    void push_snapshot() override;

    void pop_snapshot() override;

    void clear() override;

    size_t add_variable(const VariableType& var_type) override;

    void set_input_index(size_t var_id, size_t index) override;

    void set_output_index(size_t var_id, size_t index) override;

    void set_variable_upper_bound(size_t var_ref, real_t ub) override;

    void set_variable_lower_bound(size_t var_ref, real_t lb) override;

    void add_assumption(const linear_constraint_type& constraint) override;

    void add_constraint(const linear_constraint_type& constraint) override;

    void add_constraint(const relu_constraint_type& constraint) override;

    void add_constraint(const max_constraint_type& constraint) override;

    void add_constraint(
        const linear_constraint_disjunction_type& constraint) override;

    [[nodiscard]]
    Status solve() override;

    [[nodiscard]]
    model_type get_model() const override;

    // --- Queries delegated directly to the wrapped instance ---------------

    [[nodiscard]]
    bool has_integer_variable() const override
    {
        return lp_->has_integer_variable();
    }

    [[nodiscard]]
    size_t num_variables() const override
    {
        return lp_->num_variables();
    }

    [[nodiscard]]
    real_t get_variable_lower_bound(size_t var_ref) const override
    {
        return lp_->get_variable_lower_bound(var_ref);
    }

    [[nodiscard]]
    real_t get_variable_upper_bound(size_t var_ref) const override
    {
        return lp_->get_variable_upper_bound(var_ref);
    }

    void dump() const override { lp_->dump(); }

    // --- Access to the wrapped instance ------------------------------------

    /// Returns a non-owning pointer to the wrapped NNLP (read-only view).
    /// NOTE(review): presumably returns lp_.get(); confirm in the
    /// implementation file. The pointer must not be deleted by the caller,
    /// since the wrapped instance is held by this wrapper's unique_ptr.
    [[nodiscard]]
    const NNLP* get_underlying_lp() const;

    /// Mutable counterpart of the const overload above; same non-owning
    /// semantics.
    [[nodiscard]]
    NNLP* get_underlying_lp();

private:
    /// The wrapped NNLP, exclusively owned by this wrapper.
    std::unique_ptr<NNLP> lp_;
};

} // namespace police
