#pragma once

#include "police/base_types.hpp"
#include "police/lp.hpp"
#include "police/relu_constraint.hpp"
#include "police/smt.hpp"
#include "police/storage/segmented_vector.hpp"
#include "police/storage/unordered_map.hpp"
#include "police/storage/variable_space.hpp"

#include <compare>
#include <limits>

namespace police {

class NNLP {
public:
    enum Status {
        SAT,
        UNSAT,
    };
    using model_type = SMT::model_type;
    using linear_constraint_type = LinearConstraint;
    using relu_constraint_type = ReluConstraint;
    using max_constraint_type = MaxConstraint;
    using linear_constraint_disjunction_type = LinearConstraintDisjunction;

    static constexpr real_t NO_LB = -std::numeric_limits<real_t>::infinity();
    static constexpr real_t NO_UB = std::numeric_limits<real_t>::infinity();

    virtual ~NNLP() = default;

    virtual void set_input_index(
        [[maybe_unused]] size_t var_id,
        [[maybe_unused]] size_t index)
    {
    }

    virtual void set_output_index(
        [[maybe_unused]] size_t var_id,
        [[maybe_unused]] size_t index)
    {
    }

    virtual void push_snapshot() = 0;

    virtual void pop_snapshot() = 0;

    virtual void clear() = 0;

    virtual size_t add_variable(const VariableType& var_type) = 0;

    size_t add_variables(const VariableSpace& vspace);

    virtual void set_variable_upper_bound(size_t var_ref, real_t ub) = 0;

    virtual void set_variable_lower_bound(size_t var_ref, real_t lb) = 0;

    void tighten_variable_bounds(size_t var_ref, real_t lb, real_t ub);

    virtual void add_assumption(const linear_constraint_type& constraint) = 0;

    // virtual void
    // add_assumption(const linear_constraint_disjunction_type& constraint) = 0;

    virtual void add_constraint(const linear_constraint_type& constraint) = 0;

    virtual void add_constraint(const relu_constraint_type& constraint) = 0;

    virtual void add_constraint(const max_constraint_type& constraint) = 0;

    virtual void
    add_constraint(const linear_constraint_disjunction_type& constraint) = 0;

    [[nodiscard]]
    virtual Status solve() = 0;

    [[nodiscard]]
    virtual model_type get_model() const = 0;

    [[nodiscard]]
    virtual bool has_integer_variable() const = 0;

    [[nodiscard]]
    virtual size_t num_variables() const = 0;

    [[nodiscard]]
    virtual real_t get_variable_lower_bound(size_t var_ref) const = 0;

    [[nodiscard]]
    virtual real_t get_variable_upper_bound(size_t var_ref) const = 0;

    [[nodiscard]]
    virtual bool has_lower_bound(size_t var_ref) const
    {
        return get_variable_lower_bound(var_ref) != NO_LB;
    }

    [[nodiscard]]
    virtual bool has_upper_bound(size_t var_ref) const
    {
        return get_variable_upper_bound(var_ref) != NO_UB;
    }

    virtual void dump() const {}

    virtual bool supports_unsolvable_core() const { return false; }

    virtual LinearConstraintConjunction get_unsolvable_core() const
    {
        return {};
    }
};

// Partial implementation of NNLP that keeps the variable space, per-variable
// lower/upper bound vectors and a change journal in the base class.
// Concrete solvers implement the private do_* hooks; the public `final`
// overrides are defined out of line (presumably they update this bookkeeping
// and forward to the matching do_* hook -- confirm in the .cpp).
class NNLPBase : public NNLP {
public:
    NNLPBase();

    void push_snapshot() final;

    void pop_snapshot() final;

    void clear() final;

    // Presumably records the constraint in ass_constraints_ for the next
    // solve() -- confirm in the implementation.
    void add_assumption(const linear_constraint_type& constraint) override;

    // void add_assumption(
    //     const linear_constraint_disjunction_type& constraint) override;

    size_t add_variable(const VariableType& var_type) final;

    void set_variable_upper_bound(size_t var_ref, real_t ub) final;

    void set_variable_lower_bound(size_t var_ref, real_t lb) final;

    [[nodiscard]]
    const VariableSpace& get_variable_space() const;

    [[nodiscard]]
    bool has_integer_variable() const final;

    [[nodiscard]]
    size_t num_variables() const final;

    [[nodiscard]]
    real_t get_variable_lower_bound(size_t var_ref) const final;

    [[nodiscard]]
    real_t get_variable_upper_bound(size_t var_ref) const final;

    [[nodiscard]]
    Status solve() final;

protected:
    // Let derived classes report bound changes to the base-side bookkeeping
    // (defined out of line).
    void notify_variable_upper_bound(size_t var_ref, real_t ub);

    void notify_variable_lower_bound(size_t var_ref, real_t lb);

    // Lower/upper bound pair of one variable; NO_LB / NO_UB mean "unbounded".
    struct bound_type {
        [[nodiscard]]
        bool has_lb() const
        {
            return lb != NO_LB;
        }

        [[nodiscard]]
        bool has_ub() const
        {
            return ub != NO_UB;
        }

        real_t lb;
        real_t ub;
    };

    // Random-access style iterator over the variables' bounds. Dereferencing
    // yields a bound_type *by value* built from the lb/ub vectors at the
    // current index (bounds-checked via at()).
    struct bound_iterator {
        using value_type = bound_type;
        using difference_type = int_t;

        bound_iterator() = default;

        bound_iterator(
            const vector<real_t>* lbs,
            const vector<real_t>* ubs,
            size_t index)
            : lbs_(lbs)
            , ubs_(ubs)
            , idx_(index)
        {
        }

        // Only positions are compared; both iterators are assumed to refer
        // to the same pair of bound vectors.
        [[nodiscard]]
        bool operator==(const bound_iterator& other) const
        {
            return idx_ == other.idx_;
        }

        [[nodiscard]]
        auto operator<=>(const bound_iterator& other) const
        {
            return idx_ <=> other.idx_;
        }

        [[nodiscard]]
        difference_type operator-(const bound_iterator& other) const
        {
            return static_cast<difference_type>(idx_) -
                   static_cast<difference_type>(other.idx_);
        }

        [[nodiscard]]
        value_type operator*() const
        {
            return {lbs_->at(idx_), ubs_->at(idx_)};
        }

        [[nodiscard]]
        value_type operator[](difference_type pos) const
        {
            return *(*this + pos);
        }

        bound_iterator& operator++()
        {
            ++idx_;
            return *this;
        }

        bound_iterator operator++(int)
        {
            bound_iterator copy(*this);
            ++*this;
            return copy;
        }

        bound_iterator& operator--()
        {
            --idx_;
            return *this;
        }

        // Post-decrement: step backwards and return the previous position.
        bound_iterator operator--(int)
        {
            bound_iterator copy(*this);
            --*this;
            return copy;
        }

        // Signed offsets are applied to the unsigned index with modular
        // arithmetic, so negative steps move backwards as expected.
        bound_iterator& operator+=(difference_type pos)
        {
            idx_ += pos;
            return *this;
        }

        bound_iterator& operator-=(difference_type pos)
        {
            idx_ -= pos;
            return *this;
        }

        [[nodiscard]]
        bound_iterator operator+(difference_type pos) const
        {
            bound_iterator copy(*this);
            copy += pos;
            return copy;
        }

        [[nodiscard]]
        bound_iterator operator-(difference_type pos) const
        {
            bound_iterator copy(*this);
            copy -= pos;
            return copy;
        }

        [[nodiscard]]
        friend bound_iterator
        operator+(difference_type pos, const bound_iterator& it)
        {
            bound_iterator copy(it);
            copy += pos;
            return copy;
        }

        // NOTE(review): `pos - it` is not an iterator-requirement operation;
        // it evaluates to `it - pos`. Kept only for source compatibility.
        [[nodiscard]]
        friend bound_iterator
        operator-(difference_type pos, const bound_iterator& it)
        {
            bound_iterator copy(it);
            copy -= pos;
            return copy;
        }

    private:
        const vector<real_t>* lbs_ = nullptr;
        const vector<real_t>* ubs_ = nullptr;
        size_t idx_ = 0;
    };

    // Iteration range over the bounds of variables [0, num_variables()).
    [[nodiscard]]
    bound_iterator vars_begin() const
    {
        return {&var_lbs_, &var_ubs_, 0};
    }

    [[nodiscard]]
    bound_iterator vars_end() const
    {
        return {&var_lbs_, &var_ubs_, num_variables()};
    }

private:
    // Solver-specific hooks implemented by derived classes.
    virtual void do_push_snapshot() = 0;
    virtual void do_pop_snapshot() = 0;
    virtual void do_clear() = 0;
    virtual void do_add_variable(const VariableType& var_type) = 0;
    virtual void do_set_variable_upper_bound(size_t var_ref, real_t ub) = 0;
    virtual void do_set_variable_lower_bound(size_t var_ref, real_t lb) = 0;
    [[nodiscard]]
    virtual Status do_solve() = 0;
    // Overridable variant that receives the pending assumptions; its default
    // definition is out of line.
    [[nodiscard]]
    virtual Status
    do_solve(const vector<linear_constraint_type>& ass_constraints);

    // One journal record. Presumably stores the bounds overwritten since the
    // matching snapshot plus the variable counts at snapshot time, so that
    // pop_snapshot() can restore them -- confirm in the .cpp.
    struct JournalEntry {
        JournalEntry() = default;

        JournalEntry(size_t num_vars, size_t num_int_vars)
            : num_vars(num_vars)
            , num_int_vars(num_int_vars)
        {
        }

        unordered_map<size_t, real_t> lbs;
        unordered_map<size_t, real_t> ubs;
        size_t num_vars = 0;
        size_t num_int_vars = 0;
    };

    // Current bounds, indexed by var_ref.
    vector<real_t> var_lbs_;
    vector<real_t> var_ubs_;
    segmented_vector<JournalEntry> chg_log_;
    VariableSpace vspace_;

    // Assumption constraints collected via add_assumption().
    vector<linear_constraint_type> ass_constraints_;
    bool has_assumptions_ = false;

    size_t num_integer_vars_ = 0;
};

} // namespace police
