#include "police/_bits/z3_solver.hpp"

#if POLICE_Z3

#include "police/_bits/z3_model.hpp"
#include "police/macros.hpp"

#include <algorithm>
#include <iomanip>
#include <iterator>
#include <limits>
#include <locale>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <variant>

#include <z3++.h>

namespace police {

Z3Solver::Z3Solver()
    : env_()
    , last_model_(nullptr)
    , solver_(std::make_unique<z3::solver>(*env_.context()))
{
}

/// Tears the solver down in an explicit order: the cached model first,
/// then the z3::solver. Both are released here, in the destructor body,
/// i.e. before any member (including env_) is destroyed.
///
/// NOTE(review): presumably this is required because env_ owns the
/// z3::context that the solver and model were created on (see the
/// constructor's use of env_.context()) — confirm.
Z3Solver::~Z3Solver()
{
    last_model_.reset();
    solver_.reset();
}

/// Mutable access to the environment that maps variables and expressions
/// to Z3 terms. The environment is a member of this solver and lives as
/// long as it does.
Z3Environment& Z3Solver::get_environment()
{
    return this->env_;
}

/// Read-only access to the environment owned by this solver.
const Z3Environment& Z3Solver::get_environment() const
{
    return this->env_;
}

namespace {
void add_variable_bounds(
    z3::solver& solver,
    z3::expr var_expr,
    const VariableType& type)
{
    std::visit(
        [&](auto&& t) {
            using T = std::decay_t<decltype(t)>;
            if constexpr (std::is_same_v<T, BoundedIntType>) {
                solver.add(var_expr >= t.lower_bound);
                solver.add(var_expr <= t.upper_bound);
            } else if constexpr (std::is_same_v<T, BoundedRealType>) {
                solver.add(
                    var_expr >=
                    solver.ctx().real_val(
                        std::to_string(static_cast<real_t>(t.lower_bound))
                            .c_str()));
                solver.add(
                    var_expr <=
                    solver.ctx().real_val(
                        std::to_string(static_cast<real_t>(t.upper_bound))
                            .c_str()));
            }
        },
        type);
}
} // namespace

/// Declares a new variable of the given type and immediately asserts the
/// domain bounds implied by that type on the solver.
void Z3Solver::add_variable(const VariableType& type)
{
    add_variable_bounds(*solver_, env_.add_variable(type), type);
}

/// Translates @p expr through the environment and asserts the resulting
/// Z3 term on the solver.
void Z3Solver::add_constraint(const expressions::Expression& expr)
{
    solver_->add(env_.to_z3_expression(expr));
}

void Z3Solver::push_snapshot()
{
    env_.push_snapshot();
    solver_->push();
}

void Z3Solver::pop_snapshot()
{
    solver_->pop();
    env_.pop_snapshot();
}

bool Z3Solver::check()
{
    const auto result = solver_->check();
    switch (result) {
    case z3::check_result::sat: return true;
    case z3::check_result::unsat: return false;
    default: POLICE_RUNTIME_ERROR("unknown z3 check status " << result);
    }
    POLICE_UNREACHABLE();
}

bool Z3Solver::check(const vector<expressions::Expression>& assumptions)
{
    z3::expr_vector vec(*env_.context());
    for (const auto& expr : assumptions) {
        vec.push_back(env_.to_z3_expression(expr));
    }
    const auto result = solver_->check(vec);
    switch (result) {
    case z3::check_result::sat: return true;
    case z3::check_result::unsat: return false;
    default: POLICE_RUNTIME_ERROR("unknown z3 check status " << result);
    }
    POLICE_UNREACHABLE();
}

/// Returns a view of the model found by the most recent successful check().
///
/// The z3::model is cached in last_model_, and Z3Model only receives a raw
/// pointer to it (plus a pointer to env_). Each call replaces the cache.
/// NOTE(review): this presumably invalidates Z3Model objects returned by
/// earlier calls — confirm against Z3Model's contract.
Z3Model Z3Solver::get_model() const
{
    last_model_ = std::make_unique<z3::model>(solver_->get_model());
    return Z3Model(&env_, last_model_.get());
}

/// Returns the unsat core reported by Z3 after an unsat check(), with each
/// core literal translated back into a police expression via the
/// environment. The result preserves the order of Z3's core vector.
vector<expressions::Expression> Z3Solver::get_unsat_core() const
{
    const z3::expr_vector core = solver_->unsat_core();
    vector<expressions::Expression> literals;
    literals.reserve(core.size());
    for (unsigned i = 0; i < core.size(); ++i) {
        literals.push_back(env_.from_z3_expression(core[i]));
    }
    return literals;
}

/// Writes the solver's current assertions to @p out in SMT-LIB2 form,
/// followed by a newline and a flush.
void Z3Solver::dump(std::ostream& out)
{
    const auto smt2 = solver_->to_smt2();
    out << smt2 << std::endl;
}

} // namespace police

#endif
