#include "police/nnlp_smt.hpp"
#include "police/expressions/expression.hpp"
#include "police/linear_condition.hpp"
#include "police/macros.hpp"
#include "police/nnlp.hpp"
#include "police/nnlp_factory.hpp"
#include "police/option.hpp"
#include "police/option_parser.hpp"
#include "police/smt_factory.hpp"

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <utility>

namespace police {

NNLPSMT::NNLPSMT(std::unique_ptr<SMT> smt)
    : smt_(std::move(smt))
{
}

void NNLPSMT::add_constraint(const linear_constraint_type& constraint)
{
    // Any installed bounds sit in their own SMT snapshot; drop it first so the
    // new constraint is not discarded together with that snapshot later.
    remove_bounds_from_smt();
    auto expr = constraint.as_expression();
    smt_->add_constraint(std::move(expr));
}

void NNLPSMT::add_constraint(const relu_constraint_type& constraint)
{
    // ReLU constraints are encoded through the max-expression form produced by
    // as_max_expression(). Bounds snapshot is removed first so this constraint
    // lands in the caller's scope.
    remove_bounds_from_smt();
    auto expr = constraint.as_max_expression();
    smt_->add_constraint(std::move(expr));
}

void NNLPSMT::add_constraint(const max_constraint_type& constraint)
{
    // Pop the temporary bounds snapshot (if any) so the max constraint is
    // added below it and survives the next bounds refresh.
    remove_bounds_from_smt();
    auto expr = constraint.as_expression();
    smt_->add_constraint(std::move(expr));
}

void NNLPSMT::add_constraint(
    const linear_constraint_disjunction_type& constraint)
{
    // A disjunction of linear constraints becomes a single SMT expression.
    // The bounds snapshot is dropped first so it cannot swallow the constraint.
    remove_bounds_from_smt();
    auto expr = constraint.as_expression();
    smt_->add_constraint(std::move(expr));
}

NNLP::model_type NNLPSMT::get_model() const
{
    // Forwards the model reported by the SMT backend unchanged.
    auto& solver = *smt_;
    return solver.get_model();
}

void NNLPSMT::do_push_snapshot()
{
    // The bounds snapshot must never end up underneath a user snapshot, so it
    // is removed before the new SMT scope is opened.
    remove_bounds_from_smt();
    auto& solver = *smt_;
    solver.push_snapshot();
}

void NNLPSMT::do_pop_snapshot()
{
    // Remove the bounds snapshot (if installed) so the pop below closes the
    // caller's own snapshot rather than the bounds scope.
    remove_bounds_from_smt();
    auto& solver = *smt_;
    solver.pop_snapshot();
}

void NNLPSMT::do_clear()
{
    // No pop here: the whole SMT is cleared right after, so the flag is simply
    // reset (assumes SMT::clear() also discards any open snapshots — TODO
    // confirm).
    needs_a_clean_ = false;
    auto& solver = *smt_;
    solver.clear();
}

void NNLPSMT::do_add_variable(const VariableType& var_type)
{
    // New variables must be declared outside the bounds snapshot. They are
    // registered without a name; expressions in this file refer to variables
    // by index (expressions::Variable(var)).
    remove_bounds_from_smt();
    auto& solver = *smt_;
    solver.add_variable("", var_type);
}

void NNLPSMT::do_set_variable_upper_bound(size_t /*var*/, real_t /*bound*/)
{
    // Intentionally empty: upper bounds are not pushed to the SMT eagerly.
    // add_bounds_to_smt() reads them via get_variable_upper_bound() when the
    // bounds snapshot is rebuilt before a solve.
}

void NNLPSMT::do_set_variable_lower_bound(size_t /*var*/, real_t /*bound*/)
{
    // Intentionally empty: lower bounds are not pushed to the SMT eagerly.
    // add_bounds_to_smt() reads them via get_variable_lower_bound() when the
    // bounds snapshot is rebuilt before a solve.
}

// Installs the current variable bounds as SMT constraints inside a dedicated
// snapshot, so remove_bounds_from_smt() can drop all of them with one pop.
// Precondition: no bounds snapshot is currently installed.
void NNLPSMT::add_bounds_to_smt() const
{
    assert(!needs_a_clean_);
    smt_->push_snapshot();
    // Index with size_t (the variable-index type used by the bound setters)
    // and read the variable count once; adding SMT constraints does not change
    // the NNLP's variable count.
    const size_t var_count = num_variables();
    for (size_t var = 0; var < var_count; ++var) {
        if (has_lower_bound(var)) {
            smt_->add_constraint(
                expressions::Variable(var) >=
                Value(get_variable_lower_bound(var)));
        }
        if (has_upper_bound(var)) {
            smt_->add_constraint(
                expressions::Variable(var) <=
                Value(get_variable_upper_bound(var)));
        }
    }
    // Remember that a snapshot must be popped before the SMT is modified again.
    needs_a_clean_ = true;
}

// Drops the bounds snapshot pushed by add_bounds_to_smt(), if one is
// currently installed; otherwise does nothing.
void NNLPSMT::remove_bounds_from_smt() const
{
    if (!needs_a_clean_) {
        return;
    }
    smt_->pop_snapshot();
    needs_a_clean_ = false;
}

NNLP::Status NNLPSMT::do_solve()
{
    // Rebuild the bounds snapshot so it reflects the bounds as they are now;
    // they may have changed since they were last installed.
    remove_bounds_from_smt();
    add_bounds_to_smt();
    // The bounds snapshot is left in place after solving (presumably so that
    // get_model()/get_unsolvable_core() still see the solved problem); the
    // next mutating call removes it.
    switch (smt_->solve()) {
    case SMT::Status::SAT: return NNLP::Status::SAT;
    case SMT::Status::UNSAT: return NNLP::Status::UNSAT;
    }
    POLICE_UNREACHABLE();
}

namespace {
// Appends the SMT expression form (as_expression()) of every element of `in`
// to `out`, preserving order. Existing contents of `out` are kept.
template <typename T>
void as_expressions(vector<expressions::Expression>& out, const vector<T>& in)
{
    for (const auto& item : in) {
        out.push_back(item.as_expression());
    }
}
} // namespace

NNLP::Status
NNLPSMT::do_solve(const vector<linear_constraint_type>& ass_constraints)
{
    // Convert the assumption constraints into SMT expressions; they are passed
    // to the solver as assumptions rather than added as permanent constraints.
    vector<expressions::Expression> smt_assumptions;
    smt_assumptions.reserve(ass_constraints.size());
    as_expressions(smt_assumptions, ass_constraints);

    // Reinstall the current variable bounds before querying the solver.
    remove_bounds_from_smt();
    add_bounds_to_smt();

    switch (smt_->solve(smt_assumptions)) {
    case SMT::Status::SAT: return NNLP::Status::SAT;
    case SMT::Status::UNSAT: return NNLP::Status::UNSAT;
    }
    POLICE_UNREACHABLE();
}

// Unsat cores are always available from this backend: they are read from the
// SMT solver's get_unsat_core() in get_unsolvable_core().
bool NNLPSMT::supports_unsolvable_core() const { return true; }

// Translates the SMT solver's unsat core back into a conjunction of linear
// constraints.
LinearConstraintConjunction NNLPSMT::get_unsolvable_core() const
{
    LinearConstraintConjunction conjunction;
    for (const auto& assumption : smt_->get_unsat_core()) {
        auto converted = LinearCondition::from_expression(assumption);
        // Core entries are expected to be the linear assumptions passed to
        // do_solve(assumptions), so each should map back to exactly one
        // linear constraint.
        assert(converted.size() == 1u);
        conjunction &= std::move(converted[0]);
    }
    return conjunction;
}

// Prints the SMT problem to standard output as a solve would see it: the
// current bounds are freshly installed for the dump and removed again
// afterwards.
void NNLPSMT::dump() const
{
    remove_bounds_from_smt();
    add_bounds_to_smt();
    std::ostream& out = std::cout;
    smt_->dump(out);
    remove_bounds_from_smt();
}

// NNLPSMTWithPreprocessing::NNLPSMTWithPreprocessing(std::unique_ptr<SMT> smt)
//     : smt_(std::move(smt))
// {
// }
//
// void NNLPSMTWithPreprocessing::do_push_snapshot()
// {
//     smt_->push_snapshot();
// }
//
// void NNLPSMTWithPreprocessing::do_pop_snapshot()
// {
//     smt_->pop_snapshot();
// }
//
// void NNLPSMTWithPreprocessing::do_clear()
// {
//     smt_->clear();
// }
//
// void NNLPSMTWithPreprocessing::do_add_variable(
//     const VariableType& var_type)
// {
//     smt_->add_variable("", var_type);
// }
//
// void NNLPSMTWithPreprocessing::do_add_constraint(
//     const linear_constraint_type& constraint)
// {
//     smt_->add_constraint(constraint.as_expression());
// }
//
// void NNLPSMTWithPreprocessing::do_add_constraint(
//     const relu_constraint_type& constraint)
// {
//     smt_->add_constraint(constraint.as_max_expression());
// }
//
// void NNLPSMTWithPreprocessing::do_add_constraint(
//     const max_constraint_type& constraint)
// {
//     smt_->add_constraint(constraint.as_expression());
// }
//
// void NNLPSMTWithPreprocessing::do_add_constraint(
//     const linear_constraint_disjunction_type& constraint)
// {
//     smt_->add_constraint(constraint.as_expression());
// }
//
// NNLP::Status NNLPSMTWithPreprocessing::do_solve()
// {
//     const auto status = smt_->solve();
//     switch (status) {
//     case SMT::Status::UNSAT: return NNLP::Status::UNSAT;
//     case SMT::Status::SAT: return NNLP::Status::SAT;
//     }
//     POLICE_UNREACHABLE();
// }
//
// NNLP::Status NNLPSMTWithPreprocessing::do_solve(
//     const vector<linear_constraint_type>& ass_constraints)
// {
//     vector<expressions::Expression> assumptions;
//     assumptions.reserve(ass_constraints.size());
//     as_expressions(assumptions, ass_constraints);
//     const auto status = smt_->solve(assumptions);
//     switch (status) {
//     case SMT::Status::UNSAT: return NNLP::Status::UNSAT;
//     case SMT::Status::SAT: return NNLP::Status::SAT;
//     }
//     POLICE_UNREACHABLE();
// }
//
// NNLP::model_type NNLPSMTWithPreprocessing::get_model() const
// {
//     return smt_->get_model();
// }

namespace {
// Registers an NNLPFactory option named "smt" that builds an NNLPSMTFactory
// from the "smt" (SMTFactory; "z3" presumably the default) and "preprocess"
// (bool; "false" presumably the default) arguments.
PointerOption<NNLPFactory> _option(
    "smt",
    [](const Arguments& args) -> std::shared_ptr<NNLPFactory> {
        // NOTE(review): only the raw SMTFactory* obtained via .get() is handed
        // to NNLPSMTFactory. If args.get<> returns the shared_ptr by value,
        // that temporary is destroyed at the end of this statement, and the
        // pointer stays valid only while some other owner (presumably the
        // parsed Arguments / option registry) keeps the SMTFactory alive.
        // Confirm that owner outlives the returned factory, or make
        // NNLPSMTFactory hold a std::shared_ptr<SMTFactory>.
        return std::make_shared<NNLPSMTFactory>(
            args.get<std::shared_ptr<SMTFactory>>("smt").get(),
            args.get<bool>("preprocess"));
    },
    [](ArgumentsDefinition& defs) {
        // Presumably (name, description, default value) — TODO confirm.
        defs.add_ptr_argument<SMTFactory>("smt", "", "z3");
        defs.add_argument<bool>("preprocess", "", "false");
    });
} // namespace
} // namespace police
