#include "police/verifiers/ic3/start_generator.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/variable.hpp"
#include "police/verifiers/ic3/cube.hpp"

namespace police::ic3 {

// Builds a start-state generator on top of an already prepared SMT solver.
//
// num_vars   - length of every flat_state produced by operator().
// base_smt   - solver handle; ownership is shared with the caller.
// aux_values - (state index, value) pairs that operator() writes into every
//              produced state after copying the solver model.
StartGenerator::StartGenerator(
    size_t num_vars,
    std::shared_ptr<SMT> base_smt,
    vector<std::pair<size_t, Value>> aux_values)
    : smt_(std::move(base_smt))
    , aux_values_(std::move(aux_values))
    , num_vars_(num_vars)
{
    // Open a snapshot right away: clear() pops one snapshot and pushes a new
    // one, so there must always be exactly one open above the base solver.
    // NOTE(review): presumably constraints added by set_blocked() live inside
    // this snapshot and are dropped by the pop -- confirm against SMT.
    smt_->push_snapshot();
}

namespace {

// Builds an expression stating that variable `var_id` lies outside `iv`.
//
// The bounds are shifted by one (lb - 1, ub + 1) and compared with <= / >=,
// which is why the lower bound is asserted to be of integer type. An interval
// with neither bound is not supported; this is only checked by an assertion.
expressions::Expression not_interval(size_t var_id, const Interval& iv)
{
    assert(iv.lb.get_type() == Value::Type::INT);

    expressions::Variable variable(var_id);

    // variable <= lb - 1
    const auto left_of_interval = [&]() {
        return expressions::less_equal(variable, iv.lb - Value(1));
    };
    // variable >= ub + 1
    const auto right_of_interval = [&]() {
        return expressions::greater_equal(variable, iv.ub + Value(1));
    };

    // No lower bound: only the upper bound can be violated.
    if (!iv.has_lb()) {
        assert(iv.has_ub());
        return right_of_interval();
    }

    // Lower bound only: the value has to drop below it.
    if (!iv.has_ub()) {
        return left_of_interval();
    }

    // Both bounds present; a single-point interval reduces to a disequality.
    if (iv.lb == iv.ub) {
        return expressions::not_equal(variable, iv.lb);
    }

    return left_of_interval() || right_of_interval();
}

// Builds the disjunction of not_interval() over every (variable, interval)
// entry of `cube`, i.e. an expression that holds when at least one entry's
// variable falls outside its interval. The cube must be non-empty.
expressions::Expression not_cube(const Cube& cube)
{
    assert(cube.size() > 0u);

    // Seed the disjunction with the first entry, then fold in the remainder.
    auto entry = cube.begin();
    expressions::Expression clause =
        not_interval(entry->first, entry->second);

    while (++entry != cube.end()) {
        clause = clause || not_interval(entry->first, entry->second);
    }

    return clause;
}

} // namespace

// Adds the negation of `cube` (see not_cube) as a constraint to the solver,
// so that subsequent solver models cannot lie inside the cube.
void StartGenerator::set_blocked(const Cube& cube)
{
    auto blocking_clause = not_cube(cube);
    smt_->add_constraint(std::move(blocking_clause));
}

// Pops the currently open solver snapshot and immediately opens a new one.
// NOTE(review): presumably this discards every constraint added through
// set_blocked() since the last push -- confirm against the SMT interface.
void StartGenerator::clear()
{
    auto& solver = *smt_;
    solver.pop_snapshot();
    solver.push_snapshot();
}

// Queries the solver for a state.
//
// Returns std::nullopt unless the solver reports SAT. Otherwise returns a
// flat_state of length num_vars_ whose leading entries are copied from the
// solver model and whose auxiliary entries are then overwritten with the
// values passed at construction.
std::optional<flat_state> StartGenerator::operator()() const
{
    auto& solver = *smt_;

    const bool satisfiable = solver.solve() == SMT::Status::SAT;
    if (!satisfiable) {
        return std::nullopt;
    }

    flat_state start(num_vars_);
    const auto assignment = solver.get_model();
    assert(num_vars_ >= assignment.size());

    // Copy the model values into the leading positions of the state.
    const auto assigned_count = assignment.size();
    for (size_t idx = 0; idx < assigned_count; ++idx) {
        start[idx] = assignment.get_value(idx);
    }

    // Auxiliary values are written last, so they win over model values.
    for (const auto& aux : aux_values_) {
        start[aux.first] = aux.second;
    }

    return std::optional<flat_state>(std::move(start));
}

} // namespace police::ic3
