#include "police/_bits/z3_env.hpp"

#ifdef POLICE_Z3

#include "police/_bits/z3_expr_adapter.hpp"
#include "police/expressions/boolean_combination.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/expressions.hpp"
#include "police/expressions/negation.hpp"
#include "police/macros.hpp"
#include "police/storage/variable_space.hpp"

#include <cassert>
#include <iomanip>
#include <limits>
#include <locale>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <variant>
#include <z3++.h>

namespace police {

namespace {

// Returns the Z3 symbol name used for the environment variable with the
// given creation index: the letter 'x' followed by the decimal index
// (e.g. 3 -> "x3").
std::string get_z3_var_name(size_t idx)
{
    std::string name(1, 'x');
    name.append(std::to_string(idx));
    return name;
}

// Creates a fresh Z3 constant named after `var` (see get_z3_var_name) whose
// sort follows `type`: Bool for BoolType, Int for IntegerType and
// BoundedIntType, and Real for every other alternative.
z3::expr
create_variable(z3::context& context, size_t var, const VariableType& type)
{
    const std::string symbol = get_z3_var_name(var);
    const auto make_constant = [&context, &symbol](const auto& alternative) {
        using Alternative = std::decay_t<decltype(alternative)>;
        constexpr bool is_bool = std::is_same_v<Alternative, BoolType>;
        constexpr bool is_int = std::is_same_v<Alternative, IntegerType> ||
                                std::is_same_v<Alternative, BoundedIntType>;
        if constexpr (is_bool) {
            return context.bool_const(symbol.c_str());
        } else if constexpr (is_int) {
            return context.int_const(symbol.c_str());
        } else {
            return context.real_const(symbol.c_str());
        }
    };
    return std::visit(make_constant, type);
}

} // namespace

Z3Environment::Z3Environment()
    : context_(std::make_unique<z3::context>())
{
}

// Tears the environment down explicitly instead of relying on implicit
// member destruction order.
Z3Environment::~Z3Environment()
{
    // destruct objects in this specific order to avoid use after free:
    // the z3::expr objects stored in vars_ and exprs_ were created from
    // *context_, so they are released first, and only then is the owning
    // context itself destroyed.
    vars_.clear();
    exprs_.clear();
    context_ = nullptr;
}

// Translates a police expression into a Z3 expression, memoizing the result.
//
// Each distinct expression gets an id from expr_id_; exprs_[id] caches its
// translation. A cache miss runs a Z3ExpressionAdapter over the expression
// and stores the result. Returns (a copy of) the cached Z3 expression.
z3::expr Z3Environment::to_z3_expression(const expressions::Expression& expr)
{
    assert(expr_id_.size() == exprs_.size());
    const auto pos = expr_id_.insert(expr);
    // Copy the id out right away. The adapter below is handed `this` and may
    // call back into the environment (e.g. add_constant or a nested
    // to_z3_expression), which inserts further entries into expr_id_ and
    // may invalidate the iterator `pos.first`.
    const auto id = pos.first->second;
    if (pos.second) {
        // Reserve the cache slot before converting so that nested
        // translations keep expr_id_ and exprs_ the same size.
        exprs_.push_back(z3::expr(*context_));
        assert(id < exprs_.size());
        Z3ExpressionAdapter converter(this);
        expr.base()->accept(converter);
        assert(converter.expression != nullptr);
        exprs_[id] = std::move(*converter.expression);
    }
    return exprs_[id];
}

// Returns the Z3 constant registered for environment variable `var_id`.
// Requires var_id < number of registered variables (checked in debug builds).
const z3::expr& Z3Environment::get_variable(size_t var_id) const
{
    const auto& variables = vars_;
    assert(var_id < variables.size());
    return variables[var_id];
}

namespace {
// Converts every argument of the Z3 application `expr` back into a police
// expression, preserving argument order.
vector<expressions::Expression>
from_z3_expression_args(const z3::expr& expr, const Z3Environment& env)
{
    const unsigned arity = expr.num_args();
    vector<expressions::Expression> converted;
    converted.reserve(arity);
    for (unsigned k = 0; k != arity; ++k) {
        converted.push_back(env.from_z3_expression(expr.arg(k)));
    }
    return converted;
}

// Left-folds the first `count` arguments of the arithmetic application `expr`
// into nested binary additions (or subtractions when `subtract` is set), so
// (+ a b c) becomes (a + b) + c and (- a b c) becomes (a - b) - c.
expressions::Expression fold_arith_args(
    const z3::expr& expr,
    const Z3Environment& env,
    unsigned count,
    bool subtract)
{
    assert(count >= 1u);
    if (count == 1u) {
        return env.from_z3_expression(expr.arg(0));
    }
    auto lhs = fold_arith_args(expr, env, count - 1u, subtract);
    auto rhs = env.from_z3_expression(expr.arg(count - 1u));
    if (subtract) {
        return std::move(lhs) - std::move(rhs);
    }
    return std::move(lhs) + std::move(rhs);
}

// Converts a Z3 function application (Z3_APP_AST) into a police expression,
// dispatching on the application's declaration kind. Throws a runtime error
// for declaration kinds that have no police counterpart.
expressions::Expression
from_z3_fun_appl(const z3::expr& expr, const Z3Environment& env)
{
    switch (expr.decl().decl_kind()) {
    case Z3_decl_kind::Z3_OP_TRUE: return Value(true);
    case Z3_decl_kind::Z3_OP_FALSE: return Value(false);
    case Z3_decl_kind::Z3_OP_IFF: [[fallthrough]];
    case Z3_decl_kind::Z3_OP_EQ:
        assert(expr.num_args() == 2u);
        return expressions::equal(
            env.from_z3_expression(expr.arg(0)),
            env.from_z3_expression(expr.arg(1)));
    case Z3_decl_kind::Z3_OP_DISTINCT: {
        assert(expr.num_args() >= 2u);
        if (expr.num_args() == 2u) {
            return expressions::not_equal(
                env.from_z3_expression(expr.arg(0)),
                env.from_z3_expression(expr.arg(1)));
        }
        // (distinct a b c ...) is n-ary: it holds iff all arguments are
        // pairwise different. Previously only the first two were compared.
        const auto args = from_z3_expression_args(expr, env);
        vector<expressions::Expression> pairwise;
        for (size_t i = 0; i < args.size(); ++i) {
            for (size_t j = i + 1; j < args.size(); ++j) {
                pairwise.push_back(expressions::not_equal(
                    expressions::Expression(args[i]),
                    expressions::Expression(args[j])));
            }
        }
        return expressions::Conjunction(std::move(pairwise));
    }
    case Z3_decl_kind::Z3_OP_ITE:
        assert(expr.num_args() == 3u);
        return expressions::ite(
            env.from_z3_expression(expr.arg(0)),
            env.from_z3_expression(expr.arg(1)),
            env.from_z3_expression(expr.arg(2)));
    case Z3_decl_kind::Z3_OP_AND:
        return expressions::Conjunction(from_z3_expression_args(expr, env));
    case Z3_decl_kind::Z3_OP_OR:
        return expressions::Disjunction(from_z3_expression_args(expr, env));
    case Z3_decl_kind::Z3_OP_NOT:
        assert(expr.num_args() == 1u);
        return expressions::Negation(env.from_z3_expression(expr.arg(0)));
    case Z3_decl_kind::Z3_OP_LE:
        assert(expr.num_args() == 2u);
        return expressions::less_equal(
            env.from_z3_expression(expr.arg(0)),
            env.from_z3_expression(expr.arg(1)));
    case Z3_decl_kind::Z3_OP_GE:
        assert(expr.num_args() == 2u);
        return expressions::greater_equal(
            env.from_z3_expression(expr.arg(0)),
            env.from_z3_expression(expr.arg(1)));
    case Z3_decl_kind::Z3_OP_LT:
        assert(expr.num_args() == 2u);
        return expressions::less(
            env.from_z3_expression(expr.arg(0)),
            env.from_z3_expression(expr.arg(1)));
    case Z3_decl_kind::Z3_OP_GT:
        assert(expr.num_args() == 2u);
        return expressions::greater(
            env.from_z3_expression(expr.arg(0)),
            env.from_z3_expression(expr.arg(1)));
    case Z3_decl_kind::Z3_OP_ADD:
        // Z3's rewriter produces flattened n-ary sums such as (+ x y z);
        // fold every argument instead of only the first two.
        assert(expr.num_args() >= 2u);
        return fold_arith_args(expr, env, expr.num_args(), false);
    case Z3_decl_kind::Z3_OP_SUB:
        // Subtraction is left-associative when it has more than two
        // arguments: (- a b c) == (a - b) - c.
        assert(expr.num_args() >= 2u);
        return fold_arith_args(expr, env, expr.num_args(), true);
    case Z3_decl_kind::Z3_OP_UMINUS:
        assert(expr.num_args() == 1u);
        return Value(0) - env.from_z3_expression(expr.arg(0));
    case Z3_decl_kind::Z3_OP_TO_REAL:
        // Int-to-real coercion has no police counterpart; the operand is
        // used as-is.
        assert(expr.num_args() == 1u);
        return env.from_z3_expression(expr.arg(0));
    case Z3_decl_kind::Z3_OP_UNINTERPRETED:
        // Constants created by add_variable are uninterpreted 0-ary apps.
        return expressions::Variable(env.get_variable_id(expr));
    default:
        // Report the declaration kind: expr.kind() is always Z3_APP_AST here
        // and would not identify the unsupported operator.
        POLICE_RUNTIME_ERROR(
            "z3 function application "
            << expr.to_string() << " of unknown kind "
            << static_cast<int>(expr.decl().decl_kind()));
    }
}
} // namespace

// Converts a Z3 expression back into a police expression.
//   - numerals become Values: real_t for real-sorted numerals, int_t
//     otherwise;
//   - function applications are handled by from_z3_fun_appl;
//   - Z3_VAR_AST nodes are looked up as environment variables.
// Any other AST kind raises a runtime error.
expressions::Expression
Z3Environment::from_z3_expression(const z3::expr& expr) const
{
    const Z3_ast_kind ast_kind = expr.kind();
    if (ast_kind == Z3_ast_kind::Z3_APP_AST) {
        return from_z3_fun_appl(expr, *this);
    }
    if (ast_kind == Z3_ast_kind::Z3_NUMERAL_AST) {
        if (expr.is_real()) {
            return Value(static_cast<real_t>(expr.as_double()));
        }
        return Value(static_cast<int_t>(expr.as_int64()));
    }
    if (ast_kind == Z3_ast_kind::Z3_VAR_AST) {
        // NOTE(review): Z3_VAR_AST denotes a quantifier-bound variable, while
        // add_variable registers 0-ary constants (Z3_APP_AST); confirm this
        // lookup can actually succeed.
        return expressions::Variable(get_variable_id(expr));
    }
    POLICE_RUNTIME_ERROR(
        "z3 expression " << expr.to_string() << " of unknown kind "
                         << static_cast<int>(ast_kind));
}

// Registers a new variable of the given type. This does three things:
//   - creates a fresh, uniquely named Z3 constant for it;
//   - records the Z3-id -> index mapping used by get_variable_id;
//   - adds an unnamed entry to vspace_.
// Returns a reference to the stored Z3 constant.
const z3::expr& Z3Environment::add_variable(const VariableType& type)
{
    const size_t index = vars_.size();
    vars_.push_back(create_variable(*context_, var_counter_++, type));
    const z3::expr& created = vars_.back();
    assert(!z3_id_to_var_.count(created.id()));
    z3_id_to_var_[created.id()] = index;
    vspace_.add_variable("", type);
    return created;
}

// Maps a Z3 constant created by add_variable back to its environment index.
// Raises a runtime error if `expr` is not a registered variable: without the
// check, dereferencing find()'s end iterator would be undefined behavior in
// builds where assertions are disabled.
size_t Z3Environment::get_variable_id(const z3::expr& expr) const
{
    const auto it = z3_id_to_var_.find(expr.id());
    if (it == z3_id_to_var_.end()) {
        POLICE_RUNTIME_ERROR(
            "z3 expression " << expr.to_string()
                             << " is not a registered variable");
    }
    return it->second;
}

// Returns the Z3 numeral/boolean for `value`, memoized via expr_id_/exprs_
// exactly like translated expressions. BOOL values map to Z3 Bool values, INT
// to Int numerals, and REAL to Real numerals.
z3::expr Z3Environment::add_constant(const Value& value)
{
    assert(expr_id_.size() == exprs_.size());
    const auto pos = expr_id_.insert(expressions::Constant(value));
    const auto id = pos.first->second;
    if (pos.second) {
        exprs_.push_back(z3::expr(*context_));
        assert(id < exprs_.size());
        switch (value.get_type()) {
        case police::Value::Type::BOOL:
            exprs_[id] = context_->bool_val(static_cast<bool>(value));
            break;
        case police::Value::Type::INT:
            exprs_[id] = context_->int_val(static_cast<int_t>(value));
            break;
        case police::Value::Type::REAL: {
            // std::to_string formats with "%f": it keeps only six fractional
            // digits (turning e.g. 1e-7 into "0.000000") and follows the C
            // locale's decimal separator. Emit a locale-independent numeral
            // with enough significant digits to round-trip real_t instead;
            // Z3_mk_numeral accepts the "E[+|-]digits" exponent this may use.
            // NOTE(review): non-finite values still make Z3 reject the string.
            std::ostringstream numeral;
            numeral.imbue(std::locale::classic());
            numeral << std::uppercase
                    << std::setprecision(
                           std::numeric_limits<real_t>::max_digits10)
                    << static_cast<real_t>(value);
            exprs_[id] = context_->real_val(numeral.str().c_str());
            break;
        }
        default: POLICE_UNREACHABLE();
        }
    }
    return exprs_[id];
}

// Number of variables currently registered (variables removed by
// pop_snapshot are no longer counted).
size_t Z3Environment::num_variables() const
{
    return std::size(vars_);
}

// Records the current number of variables so that a matching pop_snapshot
// can discard every variable added afterwards.
void Z3Environment::push_snapshot()
{
    const auto mark = vars_.size();
    var_snapshots_.push_back(mark);
}

// Discards every variable added since the most recent push_snapshot. For each
// such variable it removes the Z3 constant, its Z3-id mapping and its vspace_
// entry, then drops the snapshot itself.
// Precondition: at least one snapshot has been pushed (calling back() on an
// empty snapshot stack would be undefined behavior).
// NOTE(review): cached translations in exprs_ are not invalidated here; if
// they can reference removed variables, confirm that is intended.
void Z3Environment::pop_snapshot()
{
    assert(!var_snapshots_.empty());
    const auto first_removed = var_snapshots_.back();
    assert(first_removed <= vars_.size());
    for (auto it = vars_.begin() + first_removed; it != vars_.end(); ++it) {
        assert(z3_id_to_var_.count(it->id()));
        z3_id_to_var_.erase(z3_id_to_var_.find(it->id()));
    }
    vars_.erase(vars_.begin() + first_removed, vars_.end());
    vspace_.erase(vspace_.begin() + first_removed, vspace_.end());
    var_snapshots_.pop_back();
}

} // namespace police

#endif
