#include "police/_bits/z3_expr_adapter.hpp"

#ifdef POLICE_Z3

#include "police/_bits/z3_env.hpp"
#include "police/expressions/binary_function_call.hpp"
#include "police/expressions/expressions.hpp"
#include "police/macros.hpp"

#include <algorithm>
#include <cassert>
#include <z3++.h>

namespace police {

Z3ExpressionAdapter::Z3ExpressionAdapter(Z3Environment* env)
    : env_(env)
{
}

// Translates one sub-expression: dispatches the visitor on `expr`, takes the
// z3::expr that the matching visit() stored in `expression`, and leaves the
// slot empty again so the caller's own visit() can fill it afterwards.
z3::expr Z3ExpressionAdapter::recurse(const expressions::Expression& expr)
{
    expr.accept(*this);
    assert(expression != nullptr);
    // Move the pointer out of the member first; the member is then reset
    // explicitly before the translated term is handed back by value.
    auto produced = std::move(expression);
    expression = nullptr;
    return std::move(*produced);
}

// Plain binary function calls have no Z3 translation here; they are reported
// through POLICE_NOT_SUPPORTED. (The min/max variant is handled by the
// BinaryFunctionCallGeneric overload.)
void Z3ExpressionAdapter::visit(const expressions::BinaryFunctionCall& /*call*/)
{
    POLICE_NOT_SUPPORTED("function calls");
}

// Translates the built-in binary MIN / MAX functions to z3::min / z3::max.
// Both operands are translated (left first) before the function is inspected;
// any other function value is treated as unreachable.
void Z3ExpressionAdapter::visit(
    const expressions::BinaryFunctionCallGeneric& expr)
{
    const z3::expr lhs = recurse(expr.left);
    const z3::expr rhs = recurse(expr.right);
    switch (expr.function) {
    case expressions::BinaryFunctionCallGeneric::MAX:
        expression = std::make_shared<z3::expr>(z3::max(lhs, rhs));
        return;
    case expressions::BinaryFunctionCallGeneric::MIN:
        expression = std::make_shared<z3::expr>(z3::min(lhs, rhs));
        return;
    default: POLICE_UNREACHABLE();
    }
}

// Folds the children left-to-right with z3 `&&`, i.e. c0 && c1 && ... && cn.
// Children are translated strictly in order. At least one child is required
// (checked only by assert).
void Z3ExpressionAdapter::visit(const expressions::Conjunction& expr)
{
    assert(expr.children.size() > 0);
    auto conj = recurse(expr.children[0]);
    for (auto it = expr.children.begin() + 1; it != expr.children.end();
         ++it) {
        conj = conj && recurse(*it);
    }
    expression = std::make_shared<z3::expr>(std::move(conj));
}

// Folds the children left-to-right with z3 `||`, i.e. c0 || c1 || ... || cn.
// Children are translated strictly in order. At least one child is required
// (checked only by assert).
void Z3ExpressionAdapter::visit(const expressions::Disjunction& expr)
{
    assert(expr.children.size() > 0);
    auto disj = recurse(expr.children[0]);
    for (auto it = expr.children.begin() + 1; it != expr.children.end();
         ++it) {
        disj = disj || recurse(*it);
    }
    expression = std::make_shared<z3::expr>(std::move(disj));
}

// Translates a binary comparison into the corresponding z3 relational term.
// Only LESS, LESS_EQUAL, EQUAL and NOT_EQUAL are handled here.
// NOTE(review): presumably greater-than forms are normalized away before this
// point — confirm.
//
// Every handled case returns directly. Control reaches the code after the
// switch only for an operator value without a case. Previously that left
// `expression` null, and recurse() would dereference it in release builds
// because it only asserts. Omitting a `default` keeps -Wswitch warnings for
// newly added enumerators.
void Z3ExpressionAdapter::visit(const expressions::Comparison& expr)
{
    auto left = recurse(expr.left);
    auto right = recurse(expr.right);
    switch (expr.op) {
    case expressions::Comparison::Operator::LESS:
        expression = std::make_shared<z3::expr>(left < right);
        return;
    case expressions::Comparison::Operator::EQUAL:
        expression = std::make_shared<z3::expr>(left == right);
        return;
    case expressions::Comparison::Operator::NOT_EQUAL:
        expression = std::make_shared<z3::expr>(left != right);
        return;
    case expressions::Comparison::Operator::LESS_EQUAL:
        expression = std::make_shared<z3::expr>(left <= right);
        return;
    }
    POLICE_UNREACHABLE();
}

// Literal values are materialized through the environment's add_constant()
// and the resulting term becomes this node's translation.
void Z3ExpressionAdapter::visit(const expressions::Constant& expr)
{
    auto constant = env_->add_constant(expr.value);
    expression = std::make_shared<z3::expr>(std::move(constant));
}

// Derivative expressions have no Z3 translation; they are reported through
// POLICE_NOT_SUPPORTED.
void Z3ExpressionAdapter::visit(const expressions::Derivative& /*derivative*/)
{
    POLICE_NOT_SUPPORTED("derivatives");
}

// General function calls have no Z3 translation; they are reported through
// POLICE_NOT_SUPPORTED.
void Z3ExpressionAdapter::visit(const expressions::FunctionCall& /*call*/)
{
    POLICE_NOT_SUPPORTED("function calls");
}

// Encountering an unresolved identifier here is an error reported through
// POLICE_RUNTIME_ERROR.
// NOTE(review): presumably identifiers are expected to be resolved (e.g. to
// Variable nodes) by an earlier pass — confirm.
void Z3ExpressionAdapter::visit(
    const expressions::IdentifierReference& /*identifier*/)
{
    POLICE_RUNTIME_ERROR("Z3 adapter found identifier expression");
}

// Translates `condition ? consequence : alternative` into z3::ite. The three
// sub-expressions are translated in exactly that order.
void Z3ExpressionAdapter::visit(const expressions::IfThenElse& expr)
{
    const z3::expr guard = recurse(expr.condition);
    const z3::expr when_true = recurse(expr.consequence);
    const z3::expr when_false = recurse(expr.alternative);
    expression =
        std::make_shared<z3::expr>(z3::ite(guard, when_true, when_false));
}

// Logical negation: translate the operand, then apply z3's `!`.
void Z3ExpressionAdapter::visit(const expressions::Negation& expr)
{
    const z3::expr operand = recurse(expr.expr);
    expression = std::make_shared<z3::expr>(!operand);
}

// Translates a binary arithmetic operation using the z3++ operator overloads
// (+, -, *, /, %). Both operands are translated (left first) before the
// operation is inspected.
// NOTE(review): for integer sorts z3's `/` and `%` are Euclidean-style
// div/mod, which can differ from C-style truncation on negative operands —
// confirm this matches the source language's semantics.
//
// Every handled case returns directly. Control reaches the code after the
// switch only for an operand value without a case. Previously that left
// `expression` null, and recurse() would dereference it in release builds
// because it only asserts. Omitting a `default` keeps -Wswitch warnings for
// newly added enumerators.
void Z3ExpressionAdapter::visit(const expressions::NumericOperation& expr)
{
    auto left = recurse(expr.left);
    auto right = recurse(expr.right);
    switch (expr.operand) {
    case police::expressions::NumericOperation::Operand::ADD:
        expression = std::make_shared<z3::expr>(left + right);
        return;
    case police::expressions::NumericOperation::Operand::SUBTRACT:
        expression = std::make_shared<z3::expr>(left - right);
        return;
    case police::expressions::NumericOperation::Operand::MULTIPLY:
        expression = std::make_shared<z3::expr>(left * right);
        return;
    case police::expressions::NumericOperation::Operand::DIVISION:
        expression = std::make_shared<z3::expr>(left / right);
        return;
    case police::expressions::NumericOperation::Operand::MODULO:
        expression = std::make_shared<z3::expr>(left % right);
        return;
    }
    POLICE_UNREACHABLE();
}

// Variables are looked up in the environment by their id. The translation is
// a copy of the term returned by get_variable().
void Z3ExpressionAdapter::visit(const expressions::Variable& expr)
{
    const auto& variable = env_->get_variable(expr.var_id);
    expression = std::make_shared<z3::expr>(variable);
}

} // namespace police

#endif
