#include "police/nnlp_bnb.hpp"
#include "police/arguments.hpp"
#include "police/nnlp_factory.hpp"
#include "police/nnlp_wrapper.hpp"
#include "police/option.hpp"
#include "police/option_parser.hpp"
#include "police/storage/variable_space.hpp"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <utility>

namespace police {

namespace {

// Model adapter: values come from a model of the relaxed problem, but each
// one is re-wrapped with the value type of the corresponding variable
// declaration in *vspace_ (the unrelaxed declarations).
class ModelWrapper {
public:
    // model:  model obtained from the underlying (relaxed) solver
    // vspace: variable declarations used to type the values; only the
    //         pointer is stored, so *vspace must outlive this wrapper
    ModelWrapper(NNLP::model_type model, const VariableSpace* vspace)
        : vspace_(vspace)
        , model(std::move(model))
    {
    }

    // Value of variable `var_idx`: the wrapped model's value converted to
    // real_t, tagged with the value type declared for that variable.
    [[nodiscard]]
    Value get_value(size_t var_idx) const
    {
        const real_t raw = static_cast<real_t>(model.get_value(var_idx));
        return Value(raw, vspace_->at(var_idx).type.value_type());
    }

    // Number of entries, as reported by the wrapped model.
    [[nodiscard]]
    size_t size() const
    {
        const size_t count = model.size();
        return count;
    }

    const VariableSpace* vspace_;
    NNLP::model_type model;
};

} // namespace

// Number of integer values in the current bounds of `var`, i.e. ub - lb + 1.
// If either bound is missing, the maximum of int_t is returned as a stand-in
// for "infinite" (find_split_variable() initialises its best candidate with
// the same value).
size_t NNLPBranchNBound::get_domain_size(size_t var) const
{
    if (!has_lower_bound(var) || !has_upper_bound(var)) {
        return std::numeric_limits<int_t>::max();
    }
    const int_t cur_lb = get_variable_lower_bound(var);
    const int_t cur_ub = get_variable_upper_bound(var);
    assert(cur_ub >= cur_lb);
    // Subtract in unsigned (modular) arithmetic: the signed difference
    // cur_ub - cur_lb can overflow int_t for very wide domains, which is
    // undefined behaviour, whereas the unsigned difference is exact whenever
    // the result fits into size_t.
    return static_cast<size_t>(cur_ub) - static_cast<size_t>(cur_lb) + 1u;
}

// Chooses the integer variable to branch on.
//
// Every integer variable whose current domain is not a single value is ranked
// by the key (value_is_integral, domain_size). Pairs compare
// lexicographically, so any variable with a fractional model value beats
// every variable whose model value is integral (within error_margin_), and
// among equal flags the smaller domain wins.
//
// Returns {true, var} if some variable has a fractional value (var is the
// best such variable), or {false, <unspecified>} if the model is integral on
// all branchable integer variables.
std::pair<bool, size_t>
NNLPBranchNBound::find_split_variable(const NNLP::model_type& model) const
{
    // Initial key acts as "nothing found yet": flag true, "infinite" domain.
    std::pair<bool, size_t> best_value(
        true, static_cast<size_t>(std::numeric_limits<int_t>::max()));
    size_t split_var = 0;
    for (const size_t var : int_vars_) {
        assert(
            !has_lower_bound(var) ||
            model.get_value(var) >=
                static_cast<real_t>(get_variable_lower_bound(var)) -
                    error_margin_);
        assert(
            !has_upper_bound(var) ||
            model.get_value(var) <=
                static_cast<real_t>(get_variable_upper_bound(var)) +
                    error_margin_);
        const size_t domain_size = get_domain_size(var);
        if (domain_size == 1u) {
            // fixed variable: nothing to branch on
            continue;
        }
        const real_t val = static_cast<real_t>(model.get_value(var));
        const bool is_int_sol = std::abs(val - std::round(val)) < error_margin_;
        const std::pair<bool, size_t> value(is_int_sol, domain_size);
        if (value < best_value) {
            split_var = var;
            best_value = value;
        }
    }
    // A split exists iff the best candidate has a fractional value.
    return {!best_value.first, split_var};
}

// Pushes the two child branches obtained by splitting `var` around its
// (fractional) value in `model`: a "down" child with bounds
// [cur_lb, floor(val)] and an "up" child with bounds [ceil(val), cur_ub].
// Both children record `back_ref`. The child with the smaller interval is
// pushed last, so it is explored first.
void NNLPBranchNBound::push_branches(
    vector<Branch>& branches,
    size_t var,
    const model_type& model,
    size_t back_ref) const
{
    const real_t cur_lb = get_variable_lower_bound(var);
    const real_t cur_ub = get_variable_upper_bound(var);
    const real_t val = static_cast<real_t>(model.get_value(var));
    // note: taking min to take into account possible rounding errors
    const real_t new_lb = std::min(std::ceil(val), cur_ub);
    assert(new_lb > cur_lb);
    // as for new_lb: taking max to take into account possible rounding errors
    const real_t new_ub = std::max(std::floor(val), cur_lb);
    assert(new_ub < cur_ub);
    // Interval sizes of the two children; a missing bound counts as
    // "infinite". Both directions use the same int_t sentinel, and each
    // ternary's operands are int_t, so the sentinel is never converted to
    // real_t and back (real_t may not represent it exactly, making the
    // conversion back to int_t out of range).
    const int_t unbounded = std::numeric_limits<int_t>::max();
    const int_t up_size =
        cur_ub == NO_UB ? unbounded : static_cast<int_t>(cur_ub - new_lb);
    const int_t down_size =
        cur_lb == NO_LB ? unbounded : static_cast<int_t>(new_ub - cur_lb);
    // branches is a stack, i.e., push branch to consider next *last*
    if (up_size <= down_size) {
        branches.emplace_back(cur_lb, cur_ub, cur_lb, new_ub, var, back_ref);
        branches.emplace_back(cur_lb, cur_ub, new_lb, cur_ub, var, back_ref);
    } else {
        branches.emplace_back(cur_lb, cur_ub, new_lb, cur_ub, var, back_ref);
        branches.emplace_back(cur_lb, cur_ub, cur_lb, new_ub, var, back_ref);
    }
}

// Returns the model found by the last successful solve(), with each value
// typed according to the unrelaxed variable declarations.
// NOTE: the returned model holds a pointer to non_relaxed_vspace_, so it must
// not be used after this solver is destroyed.
NNLP::model_type NNLPBranchNBound::get_model() const
{
    const VariableSpace* const declarations = &non_relaxed_vspace_;
    return ModelWrapper(last_model_, declarations);
}

// Solves the problem with integrality enforced by depth-first branch and bound
// on top of the relaxed solver (NNLPWrapper::solve()).
//
// Whenever the relaxation's model has a fractional value on a branchable
// integer variable (see find_split_variable()), two child branches are pushed
// onto a stack and explored depth-first by tightening that variable's bounds.
// All variable bounds changed during the search are restored before
// returning.
//
// Returns SAT as soon as a model is found that is integral on all branchable
// integer variables; that model is left in last_model_. Otherwise returns the
// status of the last relaxation solve (the initial one if it was not SAT).
//
// NOTE(review): every non-SAT status of a sub-problem (not just UNSAT) is
// treated as "prune this branch". If Status has further values (e.g. unknown
// or timeout), the returned status may not reflect them — confirm intended.
NNLP::Status NNLPBranchNBound::solve()
{
    auto status = NNLPWrapper::solve();
    if (status != NNLP::Status::SAT) {
        return status;
    }
    // Stack of pending (flag == false) and active (flag == true) branches.
    vector<Branch> branches;
    branches.reserve(2 * int_vars_.size());
    // Loop invariant: the relaxation under the current bounds is SAT.
    for (; status == NNLP::Status::SAT;) {
        last_model_ = NNLPWrapper::get_model();
        auto split = find_split_variable(last_model_);
        if (!split.first) {
            // integer solution
            // go branch stack back up in order to reset all variable bounds
            while (!branches.empty()) {
                // The top is always the branch whose bounds are active at
                // this level. resize(b.back_ref) drops it together with its
                // (possibly unexpanded) sibling; the new top is the active
                // branch of the parent level. b.back_ref is read before
                // resize() runs.
                const auto& b = branches.back();
                assert(b.flag);
                set_variable_upper_bound(b.var, b.old_ub);
                set_variable_lower_bound(b.var, b.old_lb);
                branches.resize(b.back_ref);
            }
            return NNLP::Status::SAT;
        }
        // Both children record the current stack size, i.e. the index where
        // this level's pair of branches starts.
        push_branches(branches, split.second, last_model_, branches.size());
        // try to find next model
        for (; !branches.empty();) {
            // depth-first: always expand the top of the stack
            auto& b = branches.back();
            // mark as expanded
            b.flag = true;
            // set var bounds
            set_variable_lower_bound(b.var, b.new_lb);
            set_variable_upper_bound(b.var, b.new_ub);
            // check if solvable
            status = NNLPWrapper::solve();
            if (status == NNLP::Status::SAT) {
                // found next model
                break;
            }
            // backtrack while resetting enforced var bounds
            // (pops the failed branch and every exhausted ancestor on top of
            // the stack until an unexpanded sibling becomes the top)
            for (; !branches.empty() && branches.back().flag;
                 branches.pop_back()) {
                const auto& b = branches.back();
                set_variable_upper_bound(b.var, b.old_ub);
                set_variable_lower_bound(b.var, b.old_lb);
            }
        }
    }
    return status;
}

void NNLPBranchNBound::clear()
{
    int_vars_.clear();
    NNLPWrapper::clear();
}

void NNLPBranchNBound::pop_snapshot()
{
    NNLPWrapper::pop_snapshot();
    const auto num_vars = num_variables();
    while (!int_vars_.empty() && int_vars_.back() >= num_vars) {
        int_vars_.pop_back();
    }
    non_relaxed_vspace_.erase(
        non_relaxed_vspace_.begin() + num_vars,
        non_relaxed_vspace_.end());
}

// Adds a variable of type `var_type`. The unrelaxed declaration is recorded
// locally, while the underlying solver receives var_type.relax(). Variables
// that are neither real nor bounded-real are remembered as integer variables
// (candidates for branching). Returns the new variable's index.
size_t NNLPBranchNBound::add_variable(const VariableType& var_type)
{
    non_relaxed_vspace_.add_variable("", var_type);
    const size_t idx = NNLPWrapper::add_variable(var_type.relax());
    const bool is_continuous = var_type.is_real() || var_type.is_bounded_real();
    if (is_continuous) {
        return idx;
    }
    int_vars_.push_back(idx);
    return idx;
}

// True iff at least one variable was added as an integer variable.
bool NNLPBranchNBound::has_integer_variable() const
{
    return !int_vars_.empty();
}

NNLPBranchNBoundFactory::NNLPBranchNBoundFactory(NNLPFactory* sub_factory)
    : sub_(sub_factory)
{
}

// Creates a branch-and-bound solver around a fresh solver from the
// sub-factory. The result is allocated with `new`; the caller is responsible
// for deleting it.
NNLP* NNLPBranchNBoundFactory::make() const
{
    auto inner = sub_->make_unique();
    return new NNLPBranchNBound(std::move(inner));
}

namespace {
// Option "bnb": builds an NNLPBranchNBoundFactory from the pointer argument
// "lp", the factory of the solvers that branch and bound is run on top of.
//
// NOTE(review): only the raw pointer (.get()) of the "lp" shared_ptr is handed
// to NNLPBranchNBoundFactory; the shared_ptr itself is not retained here. This
// is only safe if something else (presumably the parsed Arguments) keeps the
// sub-factory alive for as long as the returned factory is used — confirm.
PointerOption<NNLPFactory> _option(
    "bnb",
    [](const Arguments& args) -> std::shared_ptr<NNLPFactory> {
        return std::make_shared<NNLPBranchNBoundFactory>(
            args.get<std::shared_ptr<NNLPFactory>>("lp").get());
    },
    [](ArgumentsDefinition& args) {
        args.add_ptr_argument<NNLPFactory>("lp");
    });
} // namespace

} // namespace police
