#include "police/nnlp_lp.hpp"
#include "police/arguments.hpp"
#include "police/lp.hpp"
#include "police/lp_factory.hpp"
#include "police/macros.hpp"
#include "police/nnlp.hpp"
#include "police/nnlp_encoders.hpp"
#include "police/nnlp_factory.hpp"
#include "police/option.hpp"
#include "police/option_parser.hpp"
#include "police/storage/variable_space.hpp"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <type_traits>
#include <variant>

namespace police {

NNLPLP::NNLPLP(std::unique_ptr<LP> lp)
    : lp_(std::move(lp))
{
    lp_->set_sense(LPOptimizationKind::MAXIMIZE);
}

void NNLPLP::add_constraint(const linear_constraint_type& constraint)
{
    cleanup_last_solver_call();
    lp_->add_constraint(constraint);
}

void NNLPLP::add_constraint(const relu_constraint_type& constraint)
{
    cleanup_last_solver_call();
    lp_->add_constraint(MaxConstraint({constraint.x}, constraint.y, 0.));
}

void NNLPLP::add_constraint(const max_constraint_type& constraint)
{
    cleanup_last_solver_call();
    lp_->add_constraint(constraint);
}

void NNLPLP::add_constraint(
    const linear_constraint_disjunction_type& constraint)
{
    cleanup_last_solver_call();
    lp_->add_constraint(constraint);
}

/// Returns the model reported by the wrapped LP.
NNLP::model_type NNLPLP::get_model() const
{
    auto& solver = *lp_;
    return solver.get_model();
}

/// Lets the wrapped LP print itself to standard output.
void NNLPLP::dump() const
{
    auto& solver = *lp_;
    solver.dump(std::cout);
}

/// Pushes a snapshot on the wrapped LP. A snapshot left behind by the last
/// solve is popped first, so it does not interleave with caller snapshots.
void NNLPLP::do_push_snapshot()
{
    cleanup_last_solver_call();
    auto& solver = *lp_;
    solver.push_snapshot();
}

/// Pops a snapshot from the wrapped LP, after discarding any solve snapshot.
void NNLPLP::do_pop_snapshot()
{
    cleanup_last_solver_call();
    auto& solver = *lp_;
    solver.pop_snapshot();
}

/// Resetting is not supported by this backend; reports a runtime error.
void NNLPLP::do_clear()
{
    POLICE_RUNTIME_ERROR("lp doesn't support resetting");
}

void NNLPLP::do_add_variable(const VariableType& var_type)
{
    cleanup_last_solver_call();
    LPVariable var;
    std::visit(
        [&](auto&& t) {
            using T = std::decay_t<decltype(t)>;
            if constexpr (std::is_same_v<T, BoolType>) {
                var.type = LPVariable::Type::BOOL;
                var.lower_bound = 0;
                var.upper_bound = 1;
            } else if constexpr (std::is_same_v<T, IntegerType>) {
                var.type = LPVariable::Type::INT;
            } else if constexpr (std::is_same_v<T, BoundedIntType>) {
                var.type = LPVariable::Type::INT;
                if (t.is_lower_bounded()) {
                    var.lower_bound = t.lower_bound;
                }
                if (t.is_upper_bounded()) {
                    var.upper_bound = t.upper_bound;
                }
            } else if constexpr (std::is_same_v<T, BoundedRealType>) {
                if (t.is_lower_bounded()) {
                    var.lower_bound = t.lower_bound;
                }
                if (t.is_upper_bounded()) {
                    var.upper_bound = t.upper_bound;
                }
            }
        },
        var_type);
    lp_->add_variable(var);
}

// Variable bounds are deliberately NOT forwarded to the LP when they are set.
// They are presumably recorded by the NNLP base class: set_variable_bounds()
// reads them back via has_*_bound()/get_variable_*_bound() and adds them as
// explicit linear constraints at the start of every do_solve(), inside the
// snapshot that cleanup_last_solver_call() later pops. The bounds therefore
// never become permanent constraints of the wrapped LP.
void NNLPLP::do_set_variable_upper_bound(size_t, real_t)
{
}

// See the note on do_set_variable_upper_bound(): lower bounds are likewise
// applied lazily at solve time by set_variable_bounds().
void NNLPLP::do_set_variable_lower_bound(size_t, real_t)
{
}

void NNLPLP::set_variable_bounds()
{
    assert(!dirty_state_);
    for (auto var = 0u; var < num_variables(); ++var) {
        if (has_lower_bound(var)) {
            LinearConstraint c(LinearConstraint::GREATER_EQUAL);
            c.insert(var, 1.);
            c.rhs = get_variable_lower_bound(var);
            lp_->add_constraint(c);
        }
        if (has_upper_bound(var)) {
            LinearConstraint c(LinearConstraint::LESS_EQUAL);
            c.insert(var, 1.);
            c.rhs = get_variable_upper_bound(var);
            lp_->add_constraint(c);
        }
    }
}

/// Undoes the state left by the previous do_solve(), if any: pops its LP
/// snapshot, forgets its unsat core and marks the wrapper clean again.
void NNLPLP::cleanup_last_solver_call()
{
    if (!dirty_state_) {
        return;
    }
    lp_->pop_snapshot();
    unsat_core_.clear();
    dirty_state_ = false;
}

/// Solves without any additional per-call constraints.
NNLP::Status NNLPLP::do_solve()
{
    const vector<linear_constraint_type> no_extra_constraints;
    return do_solve(no_extra_constraints);
}

/// Solves the LP with `constraints` passed to the solver for this call only.
///
/// Pushes a snapshot, adds the variable-bound constraints, and leaves the
/// snapshot in place (dirty_state_ = true). The next mutating call pops it
/// via cleanup_last_solver_call(). On an infeasible (or unbounded) result, the
/// LP's unsat core is moved into unsat_core_.
///
/// @return NNLP::SAT for OPTIMAL/SOLVABLE, NNLP::UNSAT for INFEASIBLE/UNBOUNDED.
NNLP::Status NNLPLP::do_solve(const vector<linear_constraint_type>& constraints)
{
    cleanup_last_solver_call();
    lp_->push_snapshot();
    set_variable_bounds();
    dirty_state_ = true;
    const auto status = lp_->solve(constraints);
    switch (status) {
    case LPStatus::OPTIMAL: [[fallthrough]];
    case LPStatus::SOLVABLE: return NNLP::SAT;
    // NOTE(review): an unbounded LP is feasible in principle. Mapping it to
    // UNSAT presumably relies on the backend reporting "infeasible or
    // unbounded" as UNBOUNDED (no objective is set in this file) -- confirm.
    case LPStatus::UNBOUNDED: [[fallthrough]];
    case LPStatus::INFEASIBLE: {
        vector<linear_constraint_type> core = lp_->get_unsat_core();
        // Append each core constraint, moving it out of the local copy.
        for (auto& c : core) {
            unsat_core_.push_back(std::move(c));
        }
        return NNLP::UNSAT;
    }
    }
    POLICE_UNREACHABLE();
}

namespace {
PointerOption<NNLPFactory> _opt(
    "lp",
    [](const Arguments& args) -> std::shared_ptr<NNLPFactory> {
        return std::make_shared<NNLPLPFactory>(
            args.get<std::shared_ptr<LPFactory>>("lp").get(),
            args.get<bool>("preprocess"));
    },
    [](ArgumentsDefinition& defs) {
        defs.add_ptr_argument<LPFactory>("lp", "", "gurobi");
        defs.add_argument<bool>("preprocess", "", "false");
    });
} // namespace

} // namespace police

#undef USE_INDICATOR
