#include "police/nnlp.hpp"
#include "police/storage/variable_space.hpp"

#include <algorithm>
#include <cassert>
#include <type_traits>
#include <utility>
#include <variant>

namespace police {

/// Tightens the bounds of variable @p var_ref towards [@p lb, @p ub].
/// The lower bound is raised to @p lb only if @p lb is larger than the current
/// one. The upper bound is lowered to @p ub only if @p ub is smaller. Looser
/// values are ignored. No check is made that the resulting lb <= ub.
void NNLP::tighten_variable_bounds(size_t var_ref, real_t lb, real_t ub)
{
    // Read both current bounds before writing either, so neither decision
    // depends on the other write.
    const auto old_lb = get_variable_lower_bound(var_ref);
    const auto old_ub = get_variable_upper_bound(var_ref);
    const bool raises_lb = old_lb < lb;
    const bool lowers_ub = old_ub > ub;

    if (raises_lb) {
        set_variable_lower_bound(var_ref, lb);
    }
    if (lowers_ub) {
        set_variable_upper_bound(var_ref, ub);
    }
}

namespace {
/// Returns true if @p type is one of the integral variable kinds:
/// boolean, unbounded integer or bounded integer.
bool is_integer_type(const VariableType& type)
{
    // Logical OR: the original `0u | a | b | c` relied on bitwise OR of
    // bools promoted to unsigned and then narrowed back to bool.
    return type.is_bool() || type.is_int() || type.is_bounded_int();
}

} // namespace

/// Appends one variable per entry of @p vspace, using each entry's `type`,
/// and returns the index of the first appended variable. If @p vspace is
/// empty, nothing is added and the returned value equals num_variables().
size_t NNLP::add_variables(const VariableSpace& vspace)
{
    const auto first_new = num_variables();
    for (auto&& entry : vspace) {
        add_variable(entry.type);
    }
    return first_new;
}

// The change log starts with a single root entry. Bound changes made before
// any push_snapshot() are recorded there (notify_* writes to chg_log_.back()).
NNLPBase::NNLPBase() : chg_log_(1) {}

/// Opens a new snapshot. The number of variables and integer variables is
/// recorded in a new change-log entry so that pop_snapshot() can truncate
/// back to them. Later bound changes are logged into that entry.
void NNLPBase::push_snapshot()
{
    const auto vars_at_snapshot = num_variables();
    chg_log_.emplace_back(vars_at_snapshot, num_integer_vars_);
    // Let the concrete backend take its own snapshot.
    do_push_snapshot();
}

/// Rolls back to the state captured by the matching push_snapshot():
/// - restores every bound recorded in the snapshot's change log;
/// - truncates the cached bounds and the variable space to the snapshot's
///   variable count;
/// - restores the integer-variable counter;
/// - finally lets the backend roll back via do_pop_snapshot().
///
/// Precondition: a snapshot is open. The first chg_log_ entry is the root log
/// created by the constructor / clear(). Popping it would leave chg_log_ empty,
/// and every later bound change would call back() on an empty container.
void NNLPBase::pop_snapshot()
{
    assert(
        chg_log_.size() > 1 &&
        "pop_snapshot() called without a matching push_snapshot()");
    // Take ownership of the log entry before removing it, so no reference
    // into chg_log_ outlives pop_back().
    const auto chgs = std::move(chg_log_.back());
    chg_log_.pop_back();

    num_integer_vars_ = chgs.num_int_vars;
    // Restore the recorded bounds first. Entries for variables created after
    // the snapshot are written here and then dropped by the resize.
    for (const auto& [var, bound] : chgs.lbs) {
        var_lbs_[var] = bound;
    }
    var_lbs_.resize(chgs.num_vars);
    for (const auto& [var, bound] : chgs.ubs) {
        var_ubs_[var] = bound;
    }
    var_ubs_.resize(chgs.num_vars);
    vspace_.erase(vspace_.begin() + chgs.num_vars, vspace_.end());
    do_pop_snapshot();
    // NOTE(review): disabled in the original; backend bounds are presumably
    // restored by do_pop_snapshot() — confirm before re-enabling.
    // for (const auto& [var, bound] : chgs.lbs) {
    //     do_set_variable_lower_bound(var, bound);
    // }
    // for (const auto& [var, bound] : chgs.ubs) {
    //     do_set_variable_upper_bound(var, bound);
    // }
}

void NNLPBase::clear()
{
    var_lbs_.clear();
    var_ubs_.clear();
    chg_log_.clear();
    chg_log_.resize(1);
    vspace_.clear();
    num_integer_vars_ = 0;
    do_clear();
}

/// Sets the upper bound of @p var_ref to @p ub. The change is first recorded
/// in the cached bounds and the current snapshot's log, then forwarded to the
/// backend.
void NNLPBase::set_variable_upper_bound(size_t var_ref, real_t ub)
{
    this->notify_variable_upper_bound(var_ref, ub);
    this->do_set_variable_upper_bound(var_ref, ub);
}

/// Sets the lower bound of @p var_ref to @p lb. The change is first recorded
/// in the cached bounds and the current snapshot's log, then forwarded to the
/// backend.
void NNLPBase::set_variable_lower_bound(size_t var_ref, real_t lb)
{
    this->notify_variable_lower_bound(var_ref, lb);
    this->do_set_variable_lower_bound(var_ref, lb);
}

/// Caches @p ub as the new upper bound of @p var_ref. The change is recorded
/// in the innermost snapshot's log so that pop_snapshot() can undo it.
/// Does not touch the backend; see set_variable_upper_bound().
void NNLPBase::notify_variable_upper_bound(size_t var_ref, real_t ub)
{
    // Same precondition as get_variable_upper_bound(); an invalid index
    // would otherwise be an unchecked out-of-bounds write below.
    assert(var_ref < var_ubs_.size());
    auto& changes = chg_log_.back().ubs;
    // emplace() only inserts when var_ref has no entry yet. The stored value
    // is therefore the oldest bound recorded within the current snapshot.
    const auto entry = changes.emplace(var_ref, var_ubs_[var_ref]).first;
    // The new bound equals the recorded one: nothing to undo on pop.
    if (entry->second == ub) {
        changes.erase(entry);
    }
    var_ubs_[var_ref] = ub;
}

/// Caches @p lb as the new lower bound of @p var_ref. The change is recorded
/// in the innermost snapshot's log so that pop_snapshot() can undo it.
/// Does not touch the backend; see set_variable_lower_bound().
void NNLPBase::notify_variable_lower_bound(size_t var_ref, real_t lb)
{
    // Same precondition as get_variable_lower_bound(); an invalid index
    // would otherwise be an unchecked out-of-bounds write below.
    assert(var_ref < var_lbs_.size());
    auto& changes = chg_log_.back().lbs;
    // emplace() only inserts when var_ref has no entry yet. The stored value
    // is therefore the oldest bound recorded within the current snapshot.
    const auto entry = changes.emplace(var_ref, var_lbs_[var_ref]).first;
    // The new bound equals the recorded one: nothing to undo on pop.
    if (entry->second == lb) {
        changes.erase(entry);
    }
    var_lbs_[var_ref] = lb;
}

/// Returns the variable space. Each variable is stored with its *unbounded*
/// type; bounds are tracked separately (see get_variable_lower_bound()).
const VariableSpace& NNLPBase::get_variable_space() const
{
    return this->vspace_;
}

/// True if at least one boolean or integer variable is currently present.
bool NNLPBase::has_integer_variable() const
{
    return 0 < num_integer_vars_;
}

/// Number of variables, i.e. the size of the variable space.
size_t NNLPBase::num_variables() const
{
    return this->vspace_.size();
}

/// Appends a variable of type @p var_type and returns its index.
/// The variable space and the backend receive the unbounded version of the
/// type. Any bounds implied by @p var_type are cached in var_lbs_/var_ubs_
/// and then applied through set_variable_{lower,upper}_bound():
/// - bool: [0, 1];
/// - unbounded int/real: no bounds;
/// - other (bounded) types: whichever explicit bounds they carry.
size_t NNLPBase::add_variable(const VariableType& var_type)
{
    // Variables are appended, so the new index is the current size.
    const auto new_id = vspace_.size();
    const auto base_type = var_type.unbounded();
    num_integer_vars_ += is_integer_type(base_type);
    vspace_.add_variable("", base_type);

    // Derive the initial bounds from the original (possibly bounded) type.
    real_t lower = NO_LB;
    real_t upper = NO_UB;
    std::visit(
        [&lower, &upper](auto&& kind) {
            using K = std::decay_t<decltype(kind)>;
            if constexpr (std::is_same_v<K, BoolType>) {
                lower = 0;
                upper = 1;
            } else if constexpr (
                !std::is_same_v<K, IntegerType> &&
                !std::is_same_v<K, RealType>) {
                // Remaining alternatives carry optional explicit bounds.
                if (kind.is_lower_bounded()) {
                    lower = kind.lower_bound;
                }
                if (kind.is_upper_bounded()) {
                    upper = kind.upper_bound;
                }
            }
        },
        var_type);
    var_lbs_.push_back(lower);
    var_ubs_.push_back(upper);

    do_add_variable(base_type);
    // Push the finite bounds to the backend.
    if (has_lower_bound(new_id)) {
        set_variable_lower_bound(new_id, var_lbs_[new_id]);
    }
    if (has_upper_bound(new_id)) {
        set_variable_upper_bound(new_id, var_ubs_[new_id]);
    }
    return new_id;
}

/// Returns the cached lower bound of @p var_ref. add_variable() initialises
/// it to NO_LB for variables without a lower bound.
real_t NNLPBase::get_variable_lower_bound(size_t var_ref) const
{
    const auto& lbs = var_lbs_;
    assert(var_ref < lbs.size());
    return lbs[var_ref];
}

/// Returns the cached upper bound of @p var_ref. add_variable() initialises
/// it to NO_UB for variables without an upper bound.
real_t NNLPBase::get_variable_upper_bound(size_t var_ref) const
{
    const auto& ubs = var_ubs_;
    assert(var_ref < ubs.size());
    return ubs[var_ref];
}

/// Queues @p constraint as an assumption. solve() passes all queued
/// assumptions to do_solve(ass_constraints) and then discards them.
void NNLPBase::add_assumption(const linear_constraint_type& constraint)
{
    has_assumptions_ = true;
    ass_constraints_.emplace_back(constraint);
}

/// Solves the problem. Assumptions queued via add_assumption() are passed to
/// do_solve(assumptions) and apply to this call only. Without queued
/// assumptions, the plain do_solve() is used.
NNLP::Status NNLPBase::solve()
{
    if (!has_assumptions_) {
        return do_solve();
    }
    // Consume the queued assumptions *before* solving. Then an exception
    // thrown by the backend cannot leave them queued for the next solve(),
    // and the solver is not handed a reference to a member it might modify.
    auto assumptions = std::move(ass_constraints_);
    ass_constraints_.clear(); // moved-from: make the empty state explicit
    has_assumptions_ = false;
    return do_solve(assumptions);
}

/// Solves with @p ass_constraints as temporary constraints. They are added
/// inside a fresh snapshot, the problem is solved, and the snapshot is
/// popped. Presumably the backend's do_pop_snapshot() removes the added
/// constraints again — confirm in the concrete backends.
NNLP::Status
NNLPBase::do_solve(const vector<linear_constraint_type>& ass_constraints)
{
    push_snapshot();
    std::for_each(
        ass_constraints.begin(),
        ass_constraints.end(),
        [this](const linear_constraint_type& assumption) {
            add_constraint(assumption);
        });
    const auto status = do_solve();
    pop_snapshot();
    return status;
}

} // namespace police
