#include "police/marabou_preprocessor.hpp"
#include "police/base_types.hpp"
#include "police/nnlp_marabou.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <memory>
#include <type_traits>
#include <utility>
#include <variant>

namespace police {

NNLPMarabouPreprocessor::NNLPMarabouPreprocessor(std::shared_ptr<NNLP> sub_nnlp)
    : sub_nnlp_(std::move(sub_nnlp))
    , marabou_(std::make_unique<MarabouLP>())
{
}

void NNLPMarabouPreprocessor::push_snapshot()
{
    // Both back ends checkpoint together so that pop_snapshot() can roll them
    // back in lock-step.
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.push_snapshot();
    solver.push_snapshot();
}

void NNLPMarabouPreprocessor::pop_snapshot()
{
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.pop_snapshot();
    solver.pop_snapshot();
    // is_integer_var_ has no snapshot of its own; trim it back to the number
    // of variables Marabou reports after the rollback.
    const size_t restored_count = num_variables();
    is_integer_var_.resize(restored_count);
}

void NNLPMarabouPreprocessor::clear()
{
    marabou_->clear();
    sub_nnlp_->clear();
    is_integer_var_.clear();
}

size_t NNLPMarabouPreprocessor::add_variable(const VariableType& var_type)
{
    sub_nnlp_->add_variable(var_type.unbounded());
    std::visit(
        [this](auto&& t) {
            using T = std::decay_t<decltype(t)>;
            is_integer_var_.push_back(
                !std::is_same_v<T, RealType> &&
                !std::is_same_v<T, BoundedRealType>);
        },
        var_type);
    return marabou_->add_variable(var_type);
}

void NNLPMarabouPreprocessor::set_variable_upper_bound(
    size_t var_ref,
    real_t ub)
{
    // Only Marabou stores the bound directly. The sub-solver receives bounds
    // from Marabou's preprocessing result when solve() runs preprocess().
    auto& bounds_lp = *marabou_;
    bounds_lp.set_variable_upper_bound(var_ref, ub);
}

void NNLPMarabouPreprocessor::set_variable_lower_bound(
    size_t var_ref,
    real_t lb)
{
    // Only Marabou stores the bound directly. The sub-solver receives bounds
    // from Marabou's preprocessing result when solve() runs preprocess().
    auto& bounds_lp = *marabou_;
    bounds_lp.set_variable_lower_bound(var_ref, lb);
}

void NNLPMarabouPreprocessor::add_constraint(
    const linear_constraint_type& constraint)
{
    // Mirrored into both back ends: Marabou's copy feeds bound preprocessing,
    // the sub-solver's copy is used for the actual solve.
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.add_constraint(constraint);
    solver.add_constraint(constraint);
}

void NNLPMarabouPreprocessor::add_constraint(
    const relu_constraint_type& constraint)
{
    // Mirrored into both back ends: Marabou's copy feeds bound preprocessing,
    // the sub-solver's copy is used for the actual solve.
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.add_constraint(constraint);
    solver.add_constraint(constraint);
}

void NNLPMarabouPreprocessor::add_constraint(
    const max_constraint_type& constraint)
{
    // Mirrored into both back ends: Marabou's copy feeds bound preprocessing,
    // the sub-solver's copy is used for the actual solve.
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.add_constraint(constraint);
    solver.add_constraint(constraint);
}

void NNLPMarabouPreprocessor::add_constraint(
    const linear_constraint_disjunction_type& constraint)
{
    // Mirrored into both back ends: Marabou's copy feeds bound preprocessing,
    // the sub-solver's copy is used for the actual solve.
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.add_constraint(constraint);
    solver.add_constraint(constraint);
}

bool NNLPMarabouPreprocessor::add_variable_bounds()
{
    const auto pre = marabou_->preprocess();
    if (pre.infeasible) {
        return false;
    }
    assert(pre.bounds.size() == num_variables());
    for (int var_id = num_variables() - 1; var_id >= 0; --var_id) {
        expressions::Variable var(var_id);
        if (pre.bounds[var_id].has_lb()) {
            if (is_integer_var_[var_id]) {
                sub_nnlp_->set_variable_lower_bound(
                    var_id,
                    std::ceil(pre.bounds[var_id].lb - number_utils::EPSILON));
            } else {
                sub_nnlp_->set_variable_lower_bound(
                    var_id,
                    pre.bounds[var_id].lb);
            }
        } else {
            sub_nnlp_->set_variable_lower_bound(var_id, NO_LB);
        }
        if (pre.bounds[var_id].has_ub()) {
            if (is_integer_var_[var_id]) {
                sub_nnlp_->set_variable_upper_bound(
                    var_id,
                    std::floor(pre.bounds[var_id].ub + number_utils::EPSILON));
            } else {
                sub_nnlp_->set_variable_upper_bound(
                    var_id,
                    pre.bounds[var_id].ub);
            }
        } else {
            sub_nnlp_->set_variable_upper_bound(var_id, NO_UB);
        }
    }
    return true;
}

bool NNLPMarabouPreprocessor::preprocess()
{
    if (assumptions_.size() > 0u) {
        marabou_->push_snapshot();
        for (const auto& ass : assumptions_) {
            marabou_->add_constraint(ass);
        }
    }
    const bool feasible = add_variable_bounds();
    if (assumptions_.size() > 0u) {
        marabou_->pop_snapshot();
    }
    return feasible;
}

// Preprocesses the problem with Marabou (under the pending assumptions) and,
// unless that already shows infeasibility, forwards the assumptions to the
// sub-solver and solves there. Assumptions apply to this call only and are
// discarded on every path.
NNLP::Status NNLPMarabouPreprocessor::solve()
{
    if (!preprocess()) {
        // Without clearing here, the assumptions that made Marabou report
        // infeasibility would leak into the next solve() call.
        // NOTE(review): the sub-solver is not run on this path, so
        // get_unsolvable_core()/get_model() still reflect an earlier call.
        assumptions_.clear();
        return NNLP::Status::UNSAT;
    }
    for (const auto& ass : assumptions_) {
        sub_nnlp_->add_assumption(ass);
    }
    assumptions_.clear();
    return sub_nnlp_->solve();
}

// True as soon as any registered variable is integer-typed (i.e. not a real
// type).
bool NNLPMarabouPreprocessor::has_integer_variable() const
{
    for (const bool flag : is_integer_var_) {
        if (flag) {
            return true;
        }
    }
    return false;
}

size_t NNLPMarabouPreprocessor::num_variables() const
{
    // add_variable() registers every variable with both back ends; Marabou's
    // count is the one reported.
    auto& bounds_lp = *marabou_;
    return bounds_lp.num_variables();
}

void NNLPMarabouPreprocessor::set_input_index(size_t var, size_t index)
{
    // Keep both back ends' input-index mapping in sync.
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.set_input_index(var, index);
    solver.set_input_index(var, index);
}

void NNLPMarabouPreprocessor::set_output_index(size_t var, size_t index)
{
    // Keep both back ends' output-index mapping in sync.
    auto& bounds_lp = *marabou_;
    auto& solver = *sub_nnlp_;
    bounds_lp.set_output_index(var, index);
    solver.set_output_index(var, index);
}

real_t NNLPMarabouPreprocessor::get_variable_lower_bound(size_t var_ref) const
{
    // Read from Marabou, where set_variable_lower_bound() stores bounds, not
    // from the sub-solver.
    auto& bounds_lp = *marabou_;
    return bounds_lp.get_variable_lower_bound(var_ref);
}

real_t NNLPMarabouPreprocessor::get_variable_upper_bound(size_t var_ref) const
{
    // Read from Marabou, where set_variable_upper_bound() stores bounds, not
    // from the sub-solver.
    auto& bounds_lp = *marabou_;
    return bounds_lp.get_variable_upper_bound(var_ref);
}

void NNLPMarabouPreprocessor::dump() const
{
    // Only the Marabou view of the problem is printed; the sub-solver is not
    // dumped.
    auto& bounds_lp = *marabou_;
    bounds_lp.dump();
}

void NNLPMarabouPreprocessor::add_assumption(
    const linear_constraint_type& constraint)
{
    // Buffered rather than forwarded immediately: preprocess() adds the
    // buffered assumptions to Marabou inside a temporary snapshot, and solve()
    // then hands them to the sub-solver.
    assumptions_.push_back(constraint);
}

bool NNLPMarabouPreprocessor::supports_unsolvable_core() const
{
    // Core extraction is entirely the sub-solver's capability.
    auto& solver = *sub_nnlp_;
    return solver.supports_unsolvable_core();
}

LinearConstraintConjunction NNLPMarabouPreprocessor::get_unsolvable_core() const
{
    // Delegated unchanged to the sub-solver.
    auto& solver = *sub_nnlp_;
    return solver.get_unsolvable_core();
}

NNLP::model_type NNLPMarabouPreprocessor::get_model() const
{
    // The model comes from the sub-solver, which performs the actual solve.
    auto& solver = *sub_nnlp_;
    return solver.get_model();
}

} // namespace police
