#include "police/nnlp_marabou.hpp"
#include "police/macros.hpp"

#if POLICE_MARABOU

#include "police/lp.hpp"
#include "police/option.hpp"
#include "police/storage/segmented_vector.hpp"
#include "police/storage/value.hpp"
#include "police/storage/variable_space.hpp"

#include <Marabou/DisjunctionConstraint.h>
#include <Marabou/Engine.h>
#include <Marabou/EngineState.h>
#include <Marabou/Equation.h>
#include <Marabou/FloatUtils.h>
#include <Marabou/MarabouError.h>
#include <Marabou/MaxConstraint.h>
#include <Marabou/NetworkLevelReasoner.h>
#include <Marabou/Options.h>
#include <Marabou/PiecewiseLinearCaseSplit.h>
#include <Marabou/PiecewiseLinearConstraint.h>
#include <Marabou/Query.h>
#include <Marabou/ReluConstraint.h>
#include <Marabou/SymbolicBoundTighteningType.h>
#include <Marabou/Tightening.h>

#include <algorithm>
#include <cassert>
#include <iostream>
#include <memory>

// Prints a Marabou exception `err` (any object exposing getCode() and
// getUserMessage(), e.g. a MarabouError) to stderr. The argument is
// parenthesized so expressions such as `*ptr` expand correctly; note that it
// is evaluated twice, so do not pass expressions with side effects.
#define MARABOU_ERROR(err)                                                     \
    std::cerr << "Marabou error code " << (err).getCode() << ": "              \
              << (err).getUserMessage() << std::endl

namespace police {

namespace {

// Maps the relation of a police linear constraint onto the corresponding
// Marabou equation relation.
Equation::EquationType to_equation_type(LinearConstraint::Type t)
{
    switch (t) {
    case LinearConstraint::Type::LESS_EQUAL:
        return Equation::LE;
    case LinearConstraint::Type::GREATER_EQUAL:
        return Equation::GE;
    case LinearConstraint::Type::EQUAL:
        return Equation::EQ;
    }
    // every enumerator is handled above; only an out-of-range value gets here
    POLICE_UNREACHABLE();
}

// Converts a police linear constraint `sum(coef * var) <type> rhs` into a
// Marabou Equation with the same addends, relation and right-hand side.
Equation to_equation(const LinearConstraint& constraint)
{
    Equation result(to_equation_type(constraint.type));
    for (const auto& addend : constraint) {
        // police entries are (variable, coefficient); Marabou's addAddend
        // takes (coefficient, variable)
        result.addAddend(addend.second, addend.first);
    }
    result.setScalar(constraint.rhs);
    return result;
}

// Creates a heap-allocated Marabou ReLU constraint linking `constraint.x`
// (input) and `constraint.y` (output). Callers hand the pointer straight to
// Query::addPiecewiseLinearConstraint, which presumably takes ownership.
PiecewiseLinearConstraint* to_pwl_constraint(const ReluConstraint& constraint)
{
    auto* relu = new ::ReluConstraint(constraint.x, constraint.y);
    return relu;
}

// Collects the values of the range [begin, end) into a Marabou Set<unsigned>.
template <typename Iterator>
::Set<unsigned> to_set(Iterator begin, Iterator end)
{
    ::Set<unsigned> collected;
    std::for_each(begin, end, [&collected](const auto& value) {
        collected.insert(value);
    });
    return collected;
}

// Creates a heap-allocated Marabou max constraint `y = max(elements)`. As with
// the ReLU overload, the pointer is handed directly to the query.
PiecewiseLinearConstraint* to_pwl_constraint(const MaxConstraint& constraint)
{
    const auto& inputs = constraint.elements;
    auto input_set = to_set(inputs.begin(), inputs.end());
    return new ::MaxConstraint(constraint.y, std::move(input_set));
}

// Model view over a solved Marabou query: looks up the value Marabou assigned
// to each police variable. Stores non-owning pointers, so the Query and the
// VariableSpace must outlive this object.
struct MarabouModel {
    MarabouModel(const Query* query, const VariableSpace* vspace)
        : query(query)
        , vspace(vspace)
    {
    }

    // Number of variables in the underlying variable space.
    size_t size() const { return vspace->size(); }

    const VariableSpace* get_variable_space() const { return vspace; }

    // Value Marabou assigned to the police variable `var_id`.
    Value get_value(size_t var_id) const
    {
        const VariableSpace::value& entry = vspace->at(var_id);
        const auto raw = query->getSolutionValue(entry.id);
        return Value(static_cast<real_t>(raw));
    }

    const Query* query;
    const VariableSpace* vspace;
};

// void performSymbolicBoundTightening(
//     InputQuery& query,
//     SymbolicBoundTighteningType type =
//         SymbolicBoundTighteningType::SYMBOLIC_BOUND_TIGHTENING)
// {
//     auto* reasoner = query.getNetworkLevelReasoner();
//     assert(reasoner);
//     reasoner->obtainCurrentBounds();
//     switch (type) {
//     case SymbolicBoundTighteningType::SYMBOLIC_BOUND_TIGHTENING:
//         reasoner->symbolicBoundPropagation();
//         break;
//     case SymbolicBoundTighteningType::DEEP_POLY:
//         reasoner->deepPolyPropagation();
//         break;
//     default: break;
//     }
//     List<Tightening> tightenings;
//     reasoner->getConstraintTightenings(tightenings);
//     for (auto& tight : tightenings) {
//         switch (tight._type) {
//         case Tightening::LB: {
//             const auto current = query.getLowerBound(tight._variable);
//             if (tight._value > current) {
//                 query.setLowerBound(tight._variable, tight._value);
//             }
//         }
//         case Tightening::UB: {
//             const auto current = query.getUpperBound(tight._variable);
//             if (tight._value < current) {
//                 query.setUpperBound(tight._variable, tight._value);
//             }
//         }
//         }
//     }
// }

[[maybe_unused]]
bool disjunctive_constraints_are_valid(const Query& q)
{
    return std::all_of(
        q.getPiecewiseLinearConstraints().begin(),
        q.getPiecewiseLinearConstraints().end(),
        [](const PiecewiseLinearConstraint* c) {
            const DisjunctionConstraint* d =
                dynamic_cast<const DisjunctionConstraint*>(c);
            if (d == nullptr) {
                return true;
            }
            const auto splits = d->getCaseSplits();
            return std::all_of(
                splits.begin(),
                splits.end(),
                [](const PiecewiseLinearCaseSplit& split) {
                    return !split.getEquations().empty() ||
                           !split.getBoundTightenings().empty();
                });
        });
}

// Configures Marabou's global option singleton. Runs its body only on the
// first call; subsequent calls return immediately.
void set_marabou_options()
{
    static ::Options* opts = nullptr;
    if (opts != nullptr) {
        return;
    }
    opts = ::Options::get();
    // this backend expects Marabou to use Gurobi as its LP solver
    assert(opts->getString(::Options::StringOptions::LP_SOLVER) == "gurobi");
    opts->setFloat(::Options::FloatOptions::PREPROCESSOR_BOUND_TOLERANCE, 1e-2);
    opts->setInt(::Options::IntOptions::SEED, 1734);
    // Disabled tuning knobs, kept for experimentation:
    //   BoolOptions::PERFORM_LP_TIGHTENING_AFTER_SPLIT = false
    //   BoolOptions::SOLVE_WITH_MILP = true
    //   StringOptions::SOI_SEARCH_STRATEGY = "walksat"
    //   StringOptions::SYMBOLIC_BOUND_TIGHTENING_TYPE = "none"
    //   StringOptions::MILP_SOLVER_BOUND_TIGHTENING_TYPE = "none"
    opts->setInt(Options::VERBOSITY, 0);
}

} // namespace

// Backend state: the Marabou query under construction, one saved copy of the
// query (plus the police variable count) per snapshot, and the engine of the
// most recent solve, which get_model() later extracts the solution from.
struct MarabouLP::MarabouInternals {
    MarabouInternals()
    {
        set_marabou_options();
        // Marabou leaves the query's variable counter uninitialized, so start
        // it explicitly at zero.
        query.setNumberOfVariables(0);
    }

    // Saves the current variable count and a copy of the current query.
    void push(size_t num_vars)
    {
        variables.push_back(num_vars);
        query_stack.push_back(query);
    }

    // Restores the query saved by the matching push().
    void pop()
    {
        variables.pop_back();
        query = std::move(query_stack.back());
        query_stack.pop_back();
    }

    segmented_vector<size_t> variables;
    vector<Query> query_stack;
    Query query;
    std::unique_ptr<::Engine> engine;
};

// Defined out of line, where MarabouInternals is a complete type (presumably
// it is only forward-declared in the header).
void MarabouLP::MarabouInternalsDeleter::operator()(
    MarabouInternals* internals) const
{
    std::default_delete<MarabouInternals>()(internals);
}

MarabouLP::MarabouLP()
{
    // fresh backend state; its constructor also applies the Marabou options
    mb_.reset(new MarabouInternals);
}

void MarabouLP::do_push_snapshot()
{
    // save the query together with the current number of variables
    const auto current_vars = num_variables();
    mb_->push(current_vars);
}

void MarabouLP::do_pop_snapshot()
{
    // roll the query back to the state saved by the most recent push
    auto& internals = *mb_;
    internals.pop();
}

void MarabouLP::do_clear()
{
    mb_.reset(new MarabouInternals());
    input_vars_.clear();
    output_vars_.clear();
}

void MarabouLP::do_add_variable(const VariableType& /*type*/)
{
    // The variable type is ignored; only a new Marabou variable index is
    // allocated. The assertion checks that the query's variable count stays in
    // sync with num_variables().
    auto& query = mb_->query;
    query.getNewVariable();
    assert(query.getNumberOfVariables() == num_variables());
}

// Adds a linear constraint. Empty constraints are ignored. Single-variable
// constraints become variable bounds; all others become Marabou equations.
void MarabouLP::add_constraint(const linear_constraint_type& constraint)
{
    if (constraint.empty()) {
        return;
    }
    if (constraint.size() != 1u) {
        mb_->query.addEquation(to_equation(constraint));
        return;
    }
    // coef * var <type> rhs: normalize to a non-negative coefficient (negation
    // is assumed to flip the relation as well), then express it as a bound.
    const auto normalized =
        constraint.begin()->second < 0. ? -constraint : constraint;
    const auto var = normalized.begin()->first;
    const auto bound = normalized.rhs / normalized.begin()->second;
    switch (normalized.type) {
    case police::LinearConstraint::EQUAL:
        tighten_variable_bounds(var, bound, bound);
        break;
    case police::LinearConstraint::LESS_EQUAL:
        tighten_variable_bounds(var, NO_LB, bound);
        break;
    case police::LinearConstraint::GREATER_EQUAL:
        tighten_variable_bounds(var, bound, NO_UB);
        break;
    }
}

void MarabouLP::add_constraint(const ReluConstraint& constraint)
{
    // the heap-allocated constraint is handed over to the query
    auto* relu = to_pwl_constraint(constraint);
    mb_->query.addPiecewiseLinearConstraint(relu);
}

void MarabouLP::add_constraint(const MaxConstraint& constraint)
{
    // the heap-allocated constraint is handed over to the query
    auto* max = to_pwl_constraint(constraint);
    mb_->query.addPiecewiseLinearConstraint(max);
}

void MarabouLP::add_constraint(
    const linear_constraint_disjunction_type& constraint)
{
#if 1
    List<PiecewiseLinearCaseSplit> splits;
    std::for_each(
        constraint.begin(),
        constraint.end(),
        [&](const LinearConstraint& constraint) {
            assert(!constraint.empty());
            PiecewiseLinearCaseSplit split;
            if (constraint.size() == 1) {
                const auto norm =
                    (constraint.begin()->second < 0. ? -constraint
                                                     : constraint);
                const auto var = norm.begin()->first;
                const auto value = norm.rhs / norm.begin()->second;
                assert(norm.begin()->second > 0.);
                if (norm.type != LinearConstraint::Type::GREATER_EQUAL) {
                    split.storeBoundTightening(
                        Tightening(var, value, Tightening::UB));
                }
                if (norm.type != LinearConstraint::Type::LESS_EQUAL) {
                    split.storeBoundTightening(
                        Tightening(var, value, Tightening::LB));
                }
            } else {
                Equation eq = to_equation(constraint);
                split.addEquation(eq);
            }
            assert(
                !split.getBoundTightenings().empty() ||
                !split.getEquations().empty());
            splits.append(std::move(split));
        });
    assert(std::all_of(splits.begin(), splits.end(), [](const auto& split) {
        return !split.getBoundTightenings().empty() ||
               !split.getEquations().empty();
    }));
    mb_->query.addPiecewiseLinearConstraint(
        new ::DisjunctionConstraint(std::move(splits)));
#else
    PiecewiseLinearCaseSplit split;
    std::for_each(
        constraint.begin(),
        constraint.end(),
        [&](const LinearConstraint& constraint) {
            assert(!constraint.empty());
            if (constraint.size() == 1) {
                const auto& elem = *constraint.begin();
                if (constraint.type != LinearConstraint::Type::GREATER_EQUAL) {
                    split.storeBoundTightening(Tightening(
                        elem.first,
                        constraint.rhs / elem.second,
                        Tightening::LB));
                }
                if (constraint.type != LinearConstraint::Type::LESS_EQUAL) {
                    split.storeBoundTightening(Tightening(
                        elem.first,
                        constraint.rhs / elem.second,
                        Tightening::UB));
                }
            } else {
                Equation eq = to_equation(constraint);
                split.addEquation(eq);
            }
        });
    List<PiecewiseLinearCaseSplit> splits;
    splits.append(std::move(split));
    mb_->query.addPiecewiseLinearConstraint(
        new ::DisjunctionConstraint(std::move(splits)));

#endif
}

void MarabouLP::do_set_variable_upper_bound(size_t /*var*/, real_t /*ub*/)
{
    // Intentionally empty: Marabou rewrites the bounds stored in its query
    // while solving, so the base-class bounds are re-applied to the query
    // right before every solver call (see cleanup_marabou_query()).
}

void MarabouLP::do_set_variable_lower_bound(size_t /*var*/, real_t /*lb*/)
{
    // Intentionally empty: Marabou rewrites the bounds stored in its query
    // while solving, so the base-class bounds are re-applied to the query
    // right before every solver call (see cleanup_marabou_query()).
}

// Replaces all bounds stored in the Marabou query with the bounds tracked by
// the base class (variables without a bound stay unbounded). Called before
// every Marabou run because Marabou modifies the query's bounds.
void MarabouLP::cleanup_marabou_query() const
{
    auto& query = mb_->query;
    query.clearBounds();
    size_t var = 0;
    for (auto it = vars_begin(); it != vars_end(); ++it) {
        const auto& bounds = *it;
        if (bounds.has_lb()) query.setLowerBound(var, bounds.lb);
        if (bounds.has_ub()) query.setUpperBound(var, bounds.ub);
        ++var;
    }
}

// Runs a fresh Marabou engine on the current query after restoring the
// base-class bounds. The engine is kept in mb_->engine so that get_model()
// can extract the solution afterwards.
//
// Returns SAT / UNSAT as decided by the engine; UNSAT is also returned when
// processInputQuery() rejects the query (treated as infeasible preprocessing).
// Any other engine exit code (error, timeout, unknown, quit request) is
// reported via POLICE_RUNTIME_ERROR instead of being misreported as UNSAT,
// which would be unsound for a verification backend.
NNLP::Status MarabouLP::solve_internal() const
{
    auto& engine = mb_->engine;
    engine = std::make_unique<::Engine>();
    cleanup_marabou_query();
    if (!engine->processInputQuery(mb_->query)) {
        return Status::UNSAT;
    }
    engine->solve();
    switch (engine->getExitCode()) {
    case ::Engine::SAT: return Status::SAT;
    case ::Engine::UNSAT: return Status::UNSAT;
    default: break;
    }
    POLICE_RUNTIME_ERROR("Marabou terminated without a conclusive result");
}

// Solves the current query; Marabou exceptions are printed and turned into a
// police runtime error.
NNLP::Status MarabouLP::do_solve()
{
    try {
        return solve_internal();
    } catch (const MarabouError& err) {
        MARABOU_ERROR(err);
        POLICE_RUNTIME_ERROR("Marabou has thrown an exception");
    }
}

// Extracts the assignment found by the last solve() into the query and
// returns a model reading from it. Only meaningful after a solve that
// returned SAT. The model stores pointers to mb_->query and the variable
// space, so it may become stale or dangle after later solves or clear().
NNLP::model_type MarabouLP::get_model() const
{
    auto& engine = mb_->engine;
    // the engine only exists once solve_internal() has run
    assert(engine != nullptr && "get_model() called before solve()");
    try {
        engine->extractSolution(mb_->query);
    } catch (const MarabouError& err) {
        // report Marabou failures the same way as do_solve()/preprocess()
        MARABOU_ERROR(err);
        POLICE_RUNTIME_ERROR("Marabou has thrown an exception");
    }
    return MarabouModel(&mb_->query, &get_variable_space());
}

// Runs Marabou's bound computation on the current query (after restoring the
// base-class bounds) and reports the resulting bound of every variable;
// infinite bounds are mapped to NO_LB / NO_UB.
//
// Returns {bounds, false} when Engine::calculateBounds() succeeds, and
// {{}, true} when it returns false (presumably: the query was found
// infeasible). Marabou exceptions are printed and turned into a police
// runtime error.
MarabouLP::PreprocessorResult MarabouLP::preprocess() const
{
    auto& query = mb_->query;
    try {
        cleanup_marabou_query();
        ::Engine engine;
        if (!engine.calculateBounds(query)) {
            return {{}, true};
        }
        engine.extractBounds(query);
        const size_t var_count = num_variables();
        vector<PreprocessedBounds> bounds(var_count);
        // unsigned index: avoids narrowing the size_t count into an int
        for (size_t var = 0; var < var_count; ++var) {
            const auto lb = query.getLowerBound(var);
            const auto ub = query.getUpperBound(var);
            bounds[var].lb = ::FloatUtils::isInf(lb) ? NO_LB : lb;
            bounds[var].ub = ::FloatUtils::isInf(ub) ? NO_UB : ub;
        }
        return {std::move(bounds), false};
    } catch (const MarabouError& err) {
        MARABOU_ERROR(err);
        POLICE_RUNTIME_ERROR("Marabou has thrown an exception");
    }
}

// Debug aid: saves the current Marabou query to "query.marabou" in the
// current working directory.
void MarabouLP::dump() const
{
    auto& query = mb_->query;
    query.saveQuery("query.marabou");
}

// Records that `var_id` is network input number `index`, both locally and in
// the Marabou query.
void MarabouLP::set_input_index(size_t var_id, size_t index)
{
    input_vars_.emplace_back(var_id, index);
    auto& query = mb_->query;
    query.markInputVariable(var_id, index);
}

// Records that `var_id` is network output number `index`, both locally and in
// the Marabou query.
void MarabouLP::set_output_index(size_t var_id, size_t index)
{
    output_vars_.emplace_back(var_id, index);
    auto& query = mb_->query;
    query.markOutputVariable(var_id, index);
}

namespace {
// Registers a factory for this backend under the option name "marabou"
// (presumably making it selectable through the option system).
PointerOption<NNLPFactory> marabou_option(
    "marabou",
    [](const Arguments& /*args*/) -> std::shared_ptr<NNLPFactory> {
        return std::make_shared<MarabouLPFactory>();
    });
} // namespace

} // namespace police

#else

namespace police {

MarabouLP::MarabouLP()
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::do_push_snapshot()
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::do_pop_snapshot()
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::do_clear()
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::do_add_variable(const VariableType&)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::add_constraint(const linear_constraint_type&)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::add_constraint(const relu_constraint_type&)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::add_constraint(const max_constraint_type&)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::add_constraint(const linear_constraint_disjunction_type&)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

NNLP::Status MarabouLP::do_solve()
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

NNLP::model_type MarabouLP::get_model() const
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

MarabouLP::PreprocessorResult MarabouLP::preprocess() const
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::do_set_variable_upper_bound(size_t, real_t)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::do_set_variable_lower_bound(size_t, real_t)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::set_input_index(size_t, size_t)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::set_output_index(size_t, size_t)
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

void MarabouLP::dump() const
{
    POLICE_MISSING_DEPENDENCY("Marabou");
}

} // namespace police

#endif
