#pragma once

#include "police/expressions/binary_function_call.hpp"
#include "police/expressions/boolean_combination.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/expression_visitor.hpp"
#include "police/expressions/expressions.hpp"
#include "police/expressions/numeric_operation.hpp"
#include <cassert>
#include <cmath>
#include <optional>
#include <type_traits>
#include <utility>

namespace police::expressions {

template <
    typename IdentifierLookup = void*,
    typename VariableLookup = void*,
    bool Optional = true>
class ExpressionEvaluator final : public ExpressionVisitor {
public:
    using value_t = police::ite_t<Optional, std::optional<Value>, Value>;

    explicit ExpressionEvaluator(
        IdentifierLookup identifiers = IdentifierLookup(),
        VariableLookup variables = VariableLookup())
        : identifiers(std::move(identifiers))
        , variables(std::move(variables))
        , value(false)
    {
        clear();
    }

    void visit(const Constant& expr) override { value = expr.value; }

    void visit(const IdentifierReference& expr) override
    {
        lookup(identifiers, expr.identifier);
    }

    void visit(const Variable& expr) override
    {
        lookup(variables, expr.var_id);
    }

    void visit(const IfThenElse& expr) override
    {
        expr.condition.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) {
                return;
            }
        }
        if (static_cast<bool>(get_value())) {
            expr.consequence.accept(*this);
        } else {
            expr.alternative.accept(*this);
        }
    }

    void visit(const Conjunction& expr) override
    {
        value = true;
        for (const auto& child : expr.children) {
            child.accept(*this);
            if constexpr (Optional) {
                if (!value.has_value() || !static_cast<bool>(get_value()))
                    return;
            } else if (!static_cast<bool>(get_value()))
                return;
        }
    }

    void visit(const Disjunction& expr) override
    {
        bool undef = false;
        value = false;
        for (const auto& child : expr.children) {
            child.accept(*this);
            if constexpr (Optional) {
                if (!value.has_value())
                    undef = true;
                else if (static_cast<bool>(get_value()))
                    return;
            } else if (static_cast<bool>(get_value()))
                return;
        }
        if constexpr (Optional) {
            if (undef) clear();
        }
    }

    void visit(const Negation& expr) override
    {
        expr.expr.accept(*this);
        if constexpr (Optional) {
            if (value.has_value()) {
                value = !static_cast<bool>(get_value());
            }
        } else {
            value = !static_cast<bool>(get_value());
        }
    }

    void visit(const Comparison& expr) override
    {
        expr.left.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) return;
        }
        const Value left = get_value();
        expr.right.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) return;
        }
        value = evaluate(expr.op, left, get_value());
    }

    void visit(const NumericOperation& expr) override
    {
        expr.left.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) return;
        }
        const Value left = get_value();
        expr.right.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) return;
        }
        value = evaluate(expr.operand, left, get_value());
    }

    void visit(const BinaryFunctionCall& expr) override
    {
        expr.left.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) return;
        }
        const Value left = get_value();
        expr.right.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) return;
        }
        value = evaluate(expr.function, left, get_value());
    }

    void visit(const FunctionCall& expr) override
    {
        expr.expr.accept(*this);
        if constexpr (Optional) {
            if (!value.has_value()) {
                return;
            }
        }
        value = evaluate(expr.function, get_value());
    }

    void visit(const Derivative&) override
    {
        // runtime error: this operation is not
        // support yet
        assert(false);
    }

    IdentifierLookup identifiers;
    VariableLookup variables;
    value_t value{};

private:
    void clear(std::true_type) { value = std::nullopt; }

    void clear(std::false_type) {}

    void clear() { clear(std::integral_constant<bool, Optional>{}); }

    Value& get_value(std::true_type) { return value.value(); }
    Value& get_value(std::false_type) { return value; }
    Value& get_value()
    {
        return get_value(std::integral_constant<bool, Optional>{});
    }

    template <typename T, typename... Args>
    void lookup(T& fn, Args&&... args)
    {
        clear();
        lookup_(
            std::integral_constant<bool, !std::is_same_v<T, void*>>{},
            fn,
            std::forward<Args>(args)...);
    }

    template <typename T, typename... Args>
    void lookup_(std::true_type, T& fn, Args&&... args)
    {
        value = fn(std::forward<Args>(args)...);
    }

    template <typename T, typename... Args>
    void lookup_(std::false_type, T&, Args&&...)
    {
        // runtime error: expression tree must not contain
        // identifier/variable references
        assert(Optional);
    }
};

/// Evaluates `expr` in non-optional mode.
///
/// @param expr   expression tree to evaluate.
/// @param lookup callable used to resolve Variable nodes (invoked as
///               `lookup(var_id)`); it is copied or moved into the evaluator.
/// @return the resulting Value.
///
/// The tree must not contain IdentifierReference nodes (no identifier lookup
/// is supplied) or Derivative nodes; both conditions are only checked via
/// `assert`.
template <typename Lookup>
Value evaluate(const expressions::Expression& expr, Lookup&& lookup)
{
    ExpressionEvaluator<void*, std::remove_cvref_t<Lookup>, false> eval(
        nullptr,
        std::forward<Lookup>(lookup));
    expr.accept(eval);
    // `eval` is about to be destroyed; returning one of its members does not
    // trigger an implicit move, so move the result out explicitly.
    return std::move(eval.value);
}

} // namespace police::expressions
