#include "police/verifiers/ic3/sat_based/sat_interface/nnlp.hpp"

#include "police/linear_condition.hpp"
#include "police/linear_constraint.hpp"
#include "police/variable_substitution.hpp"
#include "police/nnlp.hpp"
#include "police/verifiers/ic3/cube_utils.hpp"

#include <algorithm>
#include <cassert>
#include <limits>
#include <utility>

namespace police::ic3 {

namespace {

// Builds the single-variable constraint `1 * var_id <type> bound`.
//
// var_id: index of the only variable that appears in the constraint.
// bound:  value stored as the right-hand side.
// type:   relation between the variable and the bound.
LinearConstraint
get_constraint(size_t var_id, real_t bound, LinearConstraint::Type type)
{
    LinearConstraint constraint(type);
    constraint.insert(var_id, 1); // unit coefficient on the variable
    constraint.rhs = bound;
    return constraint;
}

// Lower-bound constraint: `var_id >= bound`.
LinearConstraint get_lb_constraint(size_t var_id, real_t bound)
{
    const auto relation = LinearConstraint::Type::GREATER_EQUAL;
    return get_constraint(var_id, bound, relation);
}

// Upper-bound constraint: `var_id <= bound`.
LinearConstraint get_ub_constraint(size_t var_id, real_t bound)
{
    const auto relation = LinearConstraint::Type::LESS_EQUAL;
    return get_constraint(var_id, bound, relation);
}

// Appends to `disj` the disjuncts stating that `out_var` lies outside the
// interval `i`: `out_var <= lb - 1` and/or `out_var >= ub + 1`.
//
// Both interval ends must be non-real (asserted), which is what makes the
// "+/- 1" encoding of the strict inequalities valid. A side is skipped when
// the bound currently registered for `out_var` in `lp` already excludes it,
// so the function may append zero, one or two constraints.
//
// disj:    disjunction that receives the new disjuncts (lower side first).
// lp:      problem queried for the current bounds of `out_var`.
// i:       interval to be negated.
// out_var: variable the emitted constraints are expressed over.
void get_negated_interval(
    LinearConstraintDisjunction& disj,
    const NNLP* lp,
    const Interval& i,
    size_t out_var)
{
    assert(
        i.lb.get_type() != Value::Type::REAL &&
        i.ub.get_type() != Value::Type::REAL);

    // Values strictly below the interval: only reachable if the variable's
    // own lower bound is smaller than the interval's lower end.
    const auto var_min = lp->get_variable_lower_bound(out_var);
    const bool can_fall_below = static_cast<real_t>(i.lb) > var_min;
    if (can_fall_below) {
        const auto below_interval = static_cast<int_t>(i.lb) - 1;
        disj.push_back(get_ub_constraint(out_var, below_interval));
    }

    // Values strictly above the interval: only reachable if the variable's
    // own upper bound is larger than the interval's upper end.
    const auto var_max = lp->get_variable_upper_bound(out_var);
    const bool can_rise_above = static_cast<real_t>(i.ub) < var_max;
    if (can_rise_above) {
        const auto above_interval = static_cast<int_t>(i.ub) + 1;
        disj.push_back(get_lb_constraint(out_var, above_interval));
    }
}

// Builds the disjunction that negates `cube`, expressed over `out_vars`.
//
// Each cube entry is a (position, interval) pair; the position indexes
// `out_vars` to find the variable the interval's negation is written over.
// The result may be empty if no entry contributes a disjunct.
LinearConstraintDisjunction get_negated_cube(
    const NNLP* lp,
    const Cube& cube,
    const vector<size_t>& out_vars)
{
    LinearConstraintDisjunction negation;
    for (const auto& entry : cube) {
        const size_t target_var = out_vars[entry.first];
        get_negated_interval(negation, lp, entry.second, target_var);
    }
    return negation;
}

} // namespace

// Sets up the interface on top of an existing NNLP instance.
//
// lp:          solver instance; stored, not owned by anything visible here.
//              Must not be null.
// frame_var:   variable whose value selects the frame in later queries.
// input_vars:  variables the cubes passed to is_blocked() are asserted over;
//              must be sorted ascending.
// output_vars: variables the blocking clauses of set_blocked() are written
//              over; must have the same length as input_vars.
//
// Also builds var_refs_, which maps an input variable id back to its
// position in input_vars_, and pushes the snapshot that clear_frames()
// later pops.
SatInterfaceNNLP::SatInterfaceNNLP(
    NNLP* lp,
    size_t frame_var,
    vector<size_t> input_vars,
    vector<size_t> output_vars)
    : input_vars_(std::move(input_vars))
    , output_vars_(std::move(output_vars))
    , lp_(lp)
    , frame_var_(frame_var)
    , unsolvable_frame_(std::numeric_limits<size_t>::max())
{
    assert(lp != nullptr);
    assert(input_vars_.size() == output_vars_.size());
    assert(std::is_sorted(input_vars_.begin(), input_vars_.end()));
    // back() on an empty vector is undefined behaviour, so only build the
    // reverse map when there is at least one input variable.
    if (!input_vars_.empty()) {
        // Sorted input => the last element is the largest variable id.
        var_refs_.resize(input_vars_.back() + 1);
        // Walk backwards so that, should an id occur more than once, the
        // smallest position is the one that remains stored.
        for (size_t i = input_vars_.size(); i-- > 0;) {
            var_refs_[input_vars_[i]] = i;
        }
    }
    lp->push_snapshot();
}

// Checks whether `cube` is blocked at frame `frame_id` (must be > 0).
//
// Returns {blocked, frame_id}; the second component is always the frame
// that was passed in. If `frame_id` is at or beyond a frame already found
// to be unsolvable, the cube is reported as blocked without consulting the
// solver and the stored unsat core is emptied.
std::pair<bool, size_t>
SatInterfaceNNLP::is_blocked(const Cube& cube, size_t frame_id)
{
    assert(frame_id > 0u);
    if (frame_id < unsolvable_frame_) {
        const bool blocked = is_blocked_by_frame(cube, frame_id);
        return {blocked, frame_id};
    }
    // Frame is known to be unsolvable: no solver call is made.
    unsat_core_.clear();
    return {true, frame_id};
}

void SatInterfaceNNLP::set_blocked(const Cube& cube, size_t frame_id)
{
    assert(frame_id > 0u);
    LinearConstraintDisjunction clause =
        get_negated_cube(lp_, cube, output_vars_);
    clause.push_back(get_lb_constraint(frame_var_, frame_id + 2));
    lp_->add_constraint(clause);
}

// Intentionally empty. Frame membership is expressed through constraints on
// frame_var_ (see set_blocked), so presumably no per-frame setup is needed
// in the solver.
void SatInterfaceNNLP::add_frame()
{
    // noop
}

// Resets the solver to the snapshot taken in the constructor and opens a
// new one in its place.
// NOTE(review): presumably pop_snapshot() discards every clause added via
// set_blocked() since the matching push - confirm against NNLP.
void SatInterfaceNNLP::clear_frames()
{
    lp_->pop_snapshot();
    lp_->push_snapshot();
}

namespace {
// Builds `frame_var >= frame_val` with a unit coefficient.
// Despite the parameter names, callers in this file also pass ordinary
// input variables and interval bounds.
LinearConstraint is_geq_constraint(size_t frame_var, int_t frame_val)
{
    LinearConstraint geq(LinearConstraint::GREATER_EQUAL);
    geq.rhs = frame_val;
    geq.insert(frame_var, 1.);
    return geq;
}
// Builds `frame_var <= frame_val` with a unit coefficient.
// Despite the parameter names, callers in this file also pass ordinary
// input variables and interval bounds.
LinearConstraint is_leq_constraint(size_t frame_var, int_t frame_val)
{
    LinearConstraint leq(LinearConstraint::LESS_EQUAL);
    leq.rhs = frame_val;
    leq.insert(frame_var, 1.);
    return leq;
}
// Builds `frame_var == frame_val` with a unit coefficient.
LinearConstraint is_equal_constraint(size_t frame_var, int_t frame_val)
{
    LinearConstraint eq(LinearConstraint::EQUAL);
    eq.rhs = frame_val;
    eq.insert(frame_var, 1.);
    return eq;
}
} // namespace

bool SatInterfaceNNLP::is_blocked_by_frame(const Cube& cube, size_t frame_id)
{
    // if (!negated_cube(lp_, cube, output_vars_)) {
    //     lp_->pop_snapshot();
    //     return true;
    // }
    lp_->push_snapshot();
    lp_->add_constraint(is_equal_constraint(frame_var_, frame_id));
    // lp_->add_constraint(get_negated_cube(lp_, cube, output_vars_));
    for (auto it = cube.begin(); it != cube.end(); ++it) {
        const auto& d = it->second;
        lp_->add_assumption(is_geq_constraint(input_vars_[it->first], d.lb));
        lp_->add_assumption(is_leq_constraint(input_vars_[it->first], d.ub));
    }
    const auto status = lp_->solve();
    if (status == NNLP::Status::UNSAT && supports_unsat_core()) {
        unsat_core_.clear();
        auto core = lp_->get_unsolvable_core();
        if (core.empty()) {
            assert(unsolvable_frame_ > frame_id);
            unsolvable_frame_ = frame_id;
        } else {
            insert_into_cube(
                unsat_core_,
                substitute_vars(std::move(core), var_refs_));
        }
    }
    lp_->pop_snapshot();
#ifndef NDEBUG
    if (unsolvable_frame_ == frame_id) {
        lp_->push_snapshot();
        lp_->add_constraint(is_equal_constraint(frame_var_, frame_id));
        assert(lp_->solve() == NNLP::Status::UNSAT);
        lp_->pop_snapshot();
    }
#endif
    return status == NNLP::Status::UNSAT;
}

// Reports whether unsat cores are available; forwards the answer of the
// underlying solver's supports_unsolvable_core().
bool SatInterfaceNNLP::supports_unsat_core() const
{
    return lp_->supports_unsolvable_core();
}

} // namespace police::ic3
