#include "police/expressions/const_expression_folder.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/expression.hpp"

#include <algorithm>

namespace police::expressions {

// Folds a binary numeric operation once both of its operands are constants.
// The base transformer runs first so that the operands are already folded.
void ConstExpressionFolder::visit(Expression& ptr, NumericOperation& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    const auto lhs = get_const_(expr.left);
    const auto rhs = get_const_(expr.right);
    // Leave the node untouched unless both sides are known constants.
    if (!lhs.has_value() || !rhs.has_value()) {
        return;
    }
    ptr = MakeConstant()(evaluate(expr.operand, lhs.value(), rhs.value()));
}

// Folds a unary function call whose single argument is a constant.
// The argument is simplified by the base transformer first.
void ConstExpressionFolder::visit(Expression& ptr, FunctionCall& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    const auto arg = get_const_(expr.expr);
    if (!arg.has_value()) {
        return;
    }
    // Replace the whole call by the constant result of applying the function.
    ptr = MakeConstant()(evaluate(expr.function, arg.value()));
}

// Folds a two-argument function call when both arguments are constants.
// Arguments are simplified by the base transformer before inspection.
void ConstExpressionFolder::visit(Expression& ptr, BinaryFunctionCall& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    const auto first = get_const_(expr.left);
    const auto second = get_const_(expr.right);
    const bool both_known = first.has_value() && second.has_value();
    if (!both_known) {
        return;
    }
    ptr = MakeConstant()(
        evaluate(expr.function, first.value(), second.value()));
}

// Folds a comparison into a constant when both compared operands are
// constants (after the base transformer has simplified them).
void ConstExpressionFolder::visit(Expression& ptr, Comparison& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    const auto lhs = get_const_(expr.left);
    const auto rhs = get_const_(expr.right);
    if (!(lhs.has_value() && rhs.has_value())) {
        return;
    }
    ptr = MakeConstant()(evaluate(expr.op, lhs.value(), rhs.value()));
}

// Simplifies an n-ary conjunction after its children have been folded:
//  - any constant-false child makes the whole node the constant false;
//  - if every child is constant true (including the empty case), the node
//    becomes the constant true;
//  - otherwise the constant-true children (neutral for AND) are dropped and
//    the remaining children keep their relative order.
void ConstExpressionFolder::visit(Expression& ptr, Conjunction& expr)
{
    ExpressionTransformer::visit(ptr, expr);

    auto& kids = expr.children;
    bool has_false_literal = false;

    // Children for which this returns true stay in front (order preserved):
    // non-constant children and constant-false children. Constant-true
    // children are moved to the tail.
    const auto keep = [&](auto&& child) {
        const auto folded = get_const_(child);
        if (!folded.has_value()) {
            return true;
        }
        const bool is_false = !static_cast<bool>(folded.value());
        if (is_false) {
            has_false_literal = true;
        }
        return is_false;
    };
    const auto tail = std::stable_partition(kids.begin(), kids.end(), keep);

    if (has_false_literal) {
        ptr = MakeConstant()(false);
        return;
    }
    if (tail == kids.begin()) {
        // Nothing kept: every child was the constant true, or there were none.
        ptr = MakeConstant()(true);
        return;
    }
    kids.erase(tail, kids.end());
}

// Simplifies an n-ary disjunction after its children have been folded:
//  - any constant-true child makes the whole node the constant true;
//  - if every child is constant false (including the empty case), the node
//    becomes the constant false;
//  - otherwise the constant-false children (neutral for OR) are dropped and
//    the remaining children keep their relative order.
void ConstExpressionFolder::visit(Expression& ptr, Disjunction& expr)
{
    ExpressionTransformer::visit(ptr, expr);

    auto& kids = expr.children;
    bool has_true_literal = false;

    // Children for which this returns true stay in front (order preserved):
    // non-constant children and constant-true children. Constant-false
    // children are moved to the tail.
    const auto keep = [&](auto&& child) {
        const auto folded = get_const_(child);
        if (!folded.has_value()) {
            return true;
        }
        const bool is_true = static_cast<bool>(folded.value());
        if (is_true) {
            has_true_literal = true;
        }
        return is_true;
    };
    const auto tail = std::stable_partition(kids.begin(), kids.end(), keep);

    if (has_true_literal) {
        ptr = MakeConstant()(true);
        return;
    }
    if (tail == kids.begin()) {
        // Nothing kept: every child was the constant false, or there were none.
        ptr = MakeConstant()(false);
        return;
    }
    kids.erase(tail, kids.end());
}

// Resolves an if-then-else whose condition is a constant by replacing the
// whole node with the selected branch. Condition and branches are folded by
// the base transformer first.
void ConstExpressionFolder::visit(Expression& ptr, IfThenElse& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    const auto guard = get_const_(expr.condition);
    if (!guard.has_value()) {
        return;
    }
    // `expr` belongs to the node held by `ptr`; it must not be used after
    // `ptr` has been reassigned below.
    if (static_cast<bool>(guard.value())) {
        ptr = expr.consequence;
    } else {
        ptr = expr.alternative;
    }
}

// Folds a logical negation whose operand is a constant: after the base
// transformer has simplified the operand, a constant operand `c` is replaced
// by the boolean constant `!c`.
void ConstExpressionFolder::visit(Expression& ptr, Negation& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    const auto val = get_const_(expr.expr);
    if (val.has_value()) {
        // The folded value must invert the operand's truth value; folding to
        // static_cast<bool>(val) would silently drop the negation.
        ptr = MakeConstant()(!static_cast<bool>(val.value()));
    }
}

} // namespace police::expressions
