#include "police/nnlp_encoders.hpp"

#include "police/base_types.hpp"
#include "police/nnlp.hpp"
#include "police/storage/variable_space.hpp"

#include <algorithm>

namespace police {

namespace {

// Numerical slack that compare_to_rhs() adds to / subtracts from the bounds
// before comparing them with a right-hand side. It is currently 0, so all
// comparisons are exact; 1e-8 is kept in the trailing comment as an
// alternative value.
constexpr real_t TOLERANCE = 0; // 1e-8;

/**
 * Bounds the value of a linear combination using the variable bounds stored
 * in `lp`.
 *
 * @tparam Ub  true to compute an upper bound, false to compute a lower bound.
 * @param lp          problem providing the per-variable bounds.
 * @param constraint  linear combination (variable index -> coefficient).
 * @return the accumulated bound, or NNLP::NO_UB (Ub == true) / NNLP::NO_LB
 *         (Ub == false) as soon as a variable lacks the bound its term needs.
 *
 * Note: a term whose coefficient is exactly 0 takes the lower-bound branch and
 * therefore still requires the variable to have a lower bound.
 */
template <bool Ub>
real_t compute_bound(
    const NNLP& lp,
    const LinearCombination<size_t, real_t>& constraint)
{
    // Value reported when the requested bound does not exist.
    const real_t unbounded = Ub ? NNLP::NO_UB : NNLP::NO_LB;

    real_t total = 0.;
    for (const auto& [var, coef] : constraint) {
        // The variable's upper bound is the relevant one when maximising a
        // positively weighted term or minimising a negatively weighted one.
        const bool take_upper = Ub ? (coef > 0.) : (coef < 0.);
        if (take_upper) {
            if (!lp.has_upper_bound(var)) {
                return unbounded;
            }
            total += lp.get_variable_upper_bound(var) * coef;
            continue;
        }
        if (!lp.has_lower_bound(var)) {
            return unbounded;
        }
        total += lp.get_variable_lower_bound(var) * coef;
    }
    return total;
}

} // namespace

/**
 * Lower bound of `expr` implied by the variable bounds of `lp`.
 *
 * @return the bound, or NNLP::NO_LB if a required variable bound is missing.
 */
real_t compute_lb(const NNLP& lp, const LinearCombination<size_t, real_t>& expr)
{
    constexpr bool want_upper_bound = false;
    return compute_bound<want_upper_bound>(lp, expr);
}

/**
 * Upper bound of `expr` implied by the variable bounds of `lp`.
 *
 * @return the bound, or NNLP::NO_UB if a required variable bound is missing.
 */
real_t compute_ub(const NNLP& lp, const LinearCombination<size_t, real_t>& expr)
{
    constexpr bool want_upper_bound = true;
    return compute_bound<want_upper_bound>(lp, expr);
}

namespace {

/**
 * Smallest (most negative) slack that may be needed to pull the left-hand
 * side of `constraint` down to its right-hand side.
 *
 * @return NNLP::NO_LB if the left-hand side has no upper bound, 0 if its
 *         upper bound does not exceed the right-hand side, and otherwise the
 *         negated excess with its magnitude rounded up to an integer.
 */
real_t calculate_min_negative_relaxation(
    const NNLP& lp,
    const LinearConstraint& constraint)
{
    const auto upper = compute_ub(lp, constraint);
    if (upper == NNLP::NO_UB) {
        // No finite excess can be derived.
        return NNLP::NO_LB;
    }
    if (upper <= constraint.rhs) {
        // The left-hand side can never exceed the right-hand side.
        return 0;
    }
    // Round the magnitude up to avoid precision issues.
    const auto overshoot = upper - constraint.rhs;
    return -std::ceil(overshoot);
}

/**
 * Largest positive slack that may be needed to lift the left-hand side of
 * `constraint` up to its right-hand side.
 *
 * @return NNLP::NO_UB if the left-hand side has no lower bound, 0 if its
 *         lower bound already reaches the right-hand side, and otherwise the
 *         shortfall rounded up to an integer.
 */
real_t calculate_max_positive_relaxation(
    const NNLP& lp,
    const LinearConstraint& constraint)
{
    const auto lower = compute_lb(lp, constraint);
    if (lower == NNLP::NO_LB) {
        // No finite shortfall can be derived.
        return NNLP::NO_UB;
    }
    if (lower >= constraint.rhs) {
        // The left-hand side can never drop below the right-hand side.
        return 0;
    }
    // Round up to avoid precision issues.
    const auto shortfall = constraint.rhs - lower;
    return std::ceil(shortfall);
}

/**
 * Adds a relaxed copy of `constraint` to `lp` and records the slack variables
 * that relax it in `conj`.
 *
 * Depending on the constraint type, up to two fresh bounded variables are
 * created in `lp`:
 *  - a non-positive slack in [min_neg, 0] (LESS_EQUAL, EQUAL), inserted into
 *    `constraint` with coefficient 1 and into `conj` with coefficient -1;
 *  - a non-negative slack in [0, max_pos] (GREATER_EQUAL, EQUAL), inserted
 *    into both `constraint` and `conj` with coefficient 1.
 * The relaxed constraint is finally moved into `lp`.
 *
 * The slack ranges are always derived from the constraint exactly as it was
 * passed in, i.e. before any slack variable has been inserted. Computing the
 * positive range after inserting the negative slack would shift the lower
 * bound of the left-hand side by min_neg and yield a needlessly large (or
 * entirely superfluous) positive slack variable.
 *
 * @param lp          problem receiving the slack variables and the relaxed
 *                    constraint.
 * @param conj        constraint accumulating the slack variables.
 * @param constraint  constraint to relax; taken by value because it is
 *                    modified and then moved into `lp`.
 *
 * For LESS_EQUAL / GREATER_EQUAL the function asserts that a non-trivial
 * relaxation exists, i.e. that the constraint is not already implied by the
 * variable bounds.
 */
void encode_conjunct(
    NNLP& lp,
    LinearConstraint& conj,
    LinearConstraint constraint)
{
    switch (constraint.type) {
    case LinearConstraint::LESS_EQUAL: {
        // Compute the bound once; it is needed by both the check and the
        // variable's domain.
        const real_t min_neg =
            calculate_min_negative_relaxation(lp, constraint);
        assert(min_neg < 0.);
        const auto neg = lp.add_variable(BoundedRealType(min_neg, 0));
        constraint.insert(neg, 1.);
        conj.insert(neg, -1.);
        break;
    }
    case LinearConstraint::GREATER_EQUAL: {
        const real_t max_pos =
            calculate_max_positive_relaxation(lp, constraint);
        assert(max_pos > 0.);
        const auto pos = lp.add_variable(BoundedRealType(0, max_pos));
        constraint.insert(pos, 1.);
        conj.insert(pos, 1.);
        break;
    }
    case LinearConstraint::EQUAL: {
        // Both ranges must be taken from the unmodified constraint, so they
        // are computed before the first insert() below.
        const real_t min_neg =
            calculate_min_negative_relaxation(lp, constraint);
        const real_t max_pos =
            calculate_max_positive_relaxation(lp, constraint);
        if (min_neg < 0.) {
            const auto neg = lp.add_variable(BoundedRealType(min_neg, 0));
            constraint.insert(neg, 1.);
            conj.insert(neg, -1.);
        }
        if (max_pos > 0.) {
            const auto pos = lp.add_variable(BoundedRealType(0, max_pos));
            constraint.insert(pos, 1.);
            conj.insert(pos, 1.);
        }
        break;
    }
    }
    lp.add_constraint(std::move(constraint));
}

/**
 * Encodes every constraint of `conj` into `lp` via encode_conjunct() and
 * appends one LESS_EQUAL constraint, which collects the slack variables
 * introduced along the way, to `disj`.
 *
 * @param lp    problem receiving the relaxed constraints.
 * @param disj  disjunction that gets the collected slack constraint appended.
 * @param conj  non-empty conjunction to encode (checked by assertion).
 */
void encode_conjunction(
    NNLP& lp,
    LinearConstraintDisjunction& disj,
    const LinearConstraintConjunction& conj)
{
    assert(!conj.empty());
    LinearConstraint slack_sum(LinearConstraint::Type::LESS_EQUAL);
    for (const auto& conjunct : conj) {
        // encode_conjunct takes its third argument by value, so the conjunct
        // is copied and the input conjunction stays untouched.
        encode_conjunct(lp, slack_sum, conjunct);
    }
    disj.push_back(std::move(slack_sum));
}

} // namespace

/**
 * Decides from bounds on a left-hand side whether a comparison against `rhs`
 * is always satisfied, never satisfied, or undetermined.
 *
 * @param type  comparison operator of the constraint.
 * @param lb    lower bound of the left-hand side, or NNLP::NO_LB if absent.
 * @param ub    upper bound of the left-hand side, or NNLP::NO_UB if absent.
 * @param rhs   right-hand side value.
 * @return PresolveStatus::SAT, PresolveStatus::UNSAT, or
 *         PresolveStatus::UNKNOWN when neither can be established (also for
 *         any constraint type not handled below).
 */
PresolveStatus
compare_to_rhs(LinearConstraint::Type type, real_t lb, real_t ub, real_t rhs)
{
    const bool has_lb = lb != NNLP::NO_LB;
    const bool has_ub = ub != NNLP::NO_UB;

    if (type == LinearConstraint::LESS_EQUAL) {
        // Always true once even the largest value stays within rhs.
        if (has_ub && ub + TOLERANCE <= rhs) {
            return PresolveStatus::SAT;
        }
        // Never true once even the smallest value exceeds rhs.
        if (has_lb && lb - TOLERANCE > rhs) {
            return PresolveStatus::UNSAT;
        }
    } else if (type == LinearConstraint::GREATER_EQUAL) {
        if (has_lb && lb - TOLERANCE >= rhs) {
            return PresolveStatus::SAT;
        }
        if (has_ub && ub + TOLERANCE < rhs) {
            return PresolveStatus::UNSAT;
        }
    } else if (type == LinearConstraint::EQUAL) {
        // Satisfied only if both bounds coincide with rhs (within tolerance).
        if (has_lb && has_ub && std::abs(rhs - lb) <= TOLERANCE &&
            std::abs(rhs - ub) <= TOLERANCE) {
            return PresolveStatus::SAT;
        }
        // Violated if rhs lies strictly outside [lb, ub].
        if (has_lb && lb - TOLERANCE > rhs) {
            return PresolveStatus::UNSAT;
        }
        if (has_ub && ub + TOLERANCE < rhs) {
            return PresolveStatus::UNSAT;
        }
    }
    return PresolveStatus::UNKNOWN;
}

/**
 * Presolves a single linear constraint: bounds its left-hand side using the
 * variable bounds of `lp` and compares those bounds with the right-hand side.
 *
 * @return the verdict of compare_to_rhs() for the constraint's type.
 */
PresolveStatus presolve(const NNLP& lp, const LinearConstraint& constraint)
{
    const real_t lower = compute_lb(lp, constraint);
    const real_t upper = compute_ub(lp, constraint);
    return compare_to_rhs(constraint.type, lower, upper, constraint.rhs);
}

/**
 * Presolves a disjunction of linear constraints in place.
 *
 * Each disjunct is presolved individually. Disjuncts that are not UNKNOWN are
 * dropped; the remaining ones are compacted to the front, preserving order.
 *
 * @return SAT as soon as one disjunct is SAT, UNSAT if no disjunct remains,
 *         UNKNOWN otherwise.
 *
 * Note: on the early SAT return and on the UNSAT return `disj` is not
 * truncated, and elements may already have been moved from.
 */
PresolveStatus presolve(const NNLP& lp, LinearConstraintDisjunction& disj)
{
    unsigned kept = 0u; // number of disjuncts retained so far
    for (unsigned idx = 0u; idx < disj.size(); ++idx) {
        const PresolveStatus status = presolve(lp, disj[idx]);
        if (status == PresolveStatus::SAT) {
            return PresolveStatus::SAT;
        }
        if (status != PresolveStatus::UNKNOWN) {
            // Decided disjunct: it contributes nothing, drop it.
            continue;
        }
        if (idx != kept) {
            disj[kept] = std::move(disj[idx]);
        }
        ++kept;
    }
    if (kept == 0u) {
        return PresolveStatus::UNSAT;
    }
    disj.erase(disj.begin() + kept, disj.end());
    return PresolveStatus::UNKNOWN;
}

/**
 * Presolves a conjunction of linear constraints in place.
 *
 * Each conjunct is presolved individually. Conjuncts that are not UNKNOWN are
 * dropped; the remaining ones are compacted to the front, preserving order.
 *
 * @return UNSAT as soon as one conjunct is UNSAT, SAT if no conjunct remains,
 *         UNKNOWN otherwise.
 *
 * Note: on the early UNSAT return and on the SAT return `conj` is not
 * truncated, and elements may already have been moved from.
 */
PresolveStatus presolve(const NNLP& lp, LinearConstraintConjunction& conj)
{
    unsigned kept = 0u; // number of conjuncts retained so far
    for (unsigned idx = 0u; idx < conj.size(); ++idx) {
        const PresolveStatus status = presolve(lp, conj[idx]);
        if (status == PresolveStatus::UNSAT) {
            return PresolveStatus::UNSAT;
        }
        if (status != PresolveStatus::UNKNOWN) {
            // Decided conjunct: it imposes no restriction, drop it.
            continue;
        }
        if (idx != kept) {
            conj[kept] = std::move(conj[idx]);
        }
        ++kept;
    }
    if (kept == 0u) {
        return PresolveStatus::SAT;
    }
    conj.erase(conj.begin() + kept, conj.end());
    return PresolveStatus::UNKNOWN;
}

/**
 * Presolves a linear condition in place, treating its elements as
 * alternatives.
 *
 * Each element is presolved through the presolve() overload matching its
 * type. Elements that are not UNKNOWN are dropped; the remaining ones are
 * compacted to the front, preserving order.
 *
 * @return SAT as soon as one element is SAT, UNSAT if no element remains,
 *         UNKNOWN otherwise.
 *
 * Note: on the early SAT return and on the UNSAT return `cond` is not
 * truncated, and elements may already have been moved from.
 */
PresolveStatus presolve(const NNLP& lp, LinearCondition& cond)
{
    unsigned kept = 0u; // number of elements retained so far
    for (unsigned idx = 0u; idx < cond.size(); ++idx) {
        const PresolveStatus status = presolve(lp, cond[idx]);
        if (status == PresolveStatus::SAT) {
            return PresolveStatus::SAT;
        }
        if (status != PresolveStatus::UNKNOWN) {
            // Decided element: it can never hold, drop it.
            continue;
        }
        if (idx != kept) {
            cond[kept] = std::move(cond[idx]);
        }
        ++kept;
    }
    if (kept == 0u) {
        return PresolveStatus::UNSAT;
    }
    cond.erase(cond.begin() + kept, cond.end());
    return PresolveStatus::UNKNOWN;
}

/**
 * Adds the constraints needed to enforce `cond` to `lp`.
 *
 * A copy of the condition is presolved first. If it is already satisfied,
 * nothing is added. If a single alternative remains, its constraints are
 * added directly; otherwise every alternative is encoded with
 * encode_conjunction() and the resulting disjunction is added.
 *
 * @param lp    problem to extend.
 * @param cond  non-empty condition to encode (checked by assertion).
 * @return true if the condition is satisfied or was encoded; false if
 *         presolving found it unsatisfiable. The latter is also flagged by an
 *         assertion, so `false` is only returned when assertions are disabled.
 */
bool encode_linear_condition(NNLP& lp, const LinearCondition& cond)
{
    assert(!cond.empty());

    // Work on a copy: presolving prunes the condition in place.
    LinearCondition remaining(cond);
    const auto verdict = presolve(lp, remaining);
    if (verdict == PresolveStatus::SAT) {
        return true;
    }
    assert(verdict != PresolveStatus::UNSAT);
    if (verdict == PresolveStatus::UNSAT) {
        return false;
    }
    assert(!remaining.empty());

    if (remaining.size() == 1u) {
        // Only one alternative left: its constraints must hold as they are.
        for (const LinearConstraint& c : remaining[0]) {
            lp.add_constraint(c);
        }
        return true;
    }

    LinearConstraintDisjunction disj;
    disj.reserve(remaining.size());
    for (const auto& conj : remaining) {
        encode_conjunction(lp, disj, conj);
    }
    lp.add_constraint(std::move(disj));
    return true;
}

} // namespace police
