#pragma once

#include "police/base_types.hpp"
#include "police/macros.hpp"
#include "police/expressions/expression.hpp"
#include <algorithm>
#include <cassert>
#include <cmath>

namespace police::expressions {

/// Expression node applying a real-valued binary function to two operand
/// sub-expressions. The numeric meaning of each function is the one
/// implemented by the `evaluate` overload taking
/// `BinaryFunctionCall::Function` in this header.
class BinaryFunctionCall final : public RawExpression {
public:
    /// Supported functions:
    /// - POW: `left` raised to the power `right`.
    /// - LOG: logarithm of `left` to base `right`.
    enum struct Function { POW, LOG };

    BinaryFunctionCall(Function function, Expression left, Expression right);

    // --- RawExpression interface -------------------------------------------
    void accept(ExpressionVisitor& visitor) const override;
    void transform(
        Expression& myptr,
        ExpressionTransformer& transformer) override;
    bool is_same(const Expression& expr) const override;
    std::size_t hash() const override;
    void dump(std::ostream& out) const override;

    // --- Comparison ----------------------------------------------------------
    bool operator==(const BinaryFunctionCall& other) const;

    // Data members; declaration order is intentionally unchanged.
    Function function; ///< Which function this node applies.
    Expression left;   ///< First operand (base for POW, argument for LOG).
    Expression right;  ///< Second operand (exponent for POW, base for LOG).
};

/// Expression node applying a type-generic binary function (minimum or
/// maximum) to two operand sub-expressions. The numeric meaning is the one
/// implemented by the `evaluate` overload taking
/// `BinaryFunctionCallGeneric::Function` in this header.
class BinaryFunctionCallGeneric final : public RawExpression {
public:
    /// Supported functions. This is an unscoped enum, so the enumerators are
    /// reachable both as `BinaryFunctionCallGeneric::MIN` and as
    /// `BinaryFunctionCallGeneric::Function::MIN`.
    enum Function { MIN, MAX };

    BinaryFunctionCallGeneric(
        Function function,
        Expression left,
        Expression right);

    // --- RawExpression interface -------------------------------------------
    void accept(ExpressionVisitor& visitor) const override;
    void transform(
        Expression& myptr,
        ExpressionTransformer& transformer) override;
    bool is_same(const Expression& expr) const override;
    std::size_t hash() const override;
    void dump(std::ostream& out) const override;

    // --- Comparison ----------------------------------------------------------
    bool operator==(const BinaryFunctionCallGeneric& other) const;

    // Data members; declaration order is intentionally unchanged.
    Function function; ///< Which function this node applies.
    Expression left;   ///< First operand.
    Expression right;  ///< Second operand.
};

/// Evaluates a real-valued binary function on two operands.
///
/// Both operands are converted to `police::real_t` before the function is
/// applied:
/// - POW: `left` raised to the power `right` (via `std::pow`).
/// - LOG: logarithm of `left` to base `right`, computed as
///   `std::log(left) / std::log(right)`.
///
/// NOTE(review): assumes `police::real_t` is a floating-point type, so that
/// degenerate inputs (e.g. base 1, non-positive arguments) produce inf/NaN
/// rather than trapping — confirm against base_types.hpp.
///
/// @param fn    the function to apply
/// @param left  first operand; must be convertible to `police::real_t`
/// @param right second operand; must be convertible to `police::real_t`
/// @return the function value as `police::real_t`
template <typename Left, typename Right>
police::real_t
evaluate(BinaryFunctionCall::Function fn, Left&& left, Right&& right)
{
    const police::real_t lhs = static_cast<police::real_t>(left);
    const police::real_t rhs = static_cast<police::real_t>(right);

    // A switch (rather than an if-chain) keeps -Wswitch exhaustiveness checks.
    switch (fn) {
    case BinaryFunctionCall::Function::POW: {
        return std::pow(lhs, rhs);
    }
    case BinaryFunctionCall::Function::LOG: {
        const police::real_t numerator = std::log(lhs);
        const police::real_t denominator = std::log(rhs);
        return numerator / denominator;
    }
    }
    POLICE_UNREACHABLE();
}

/// Evaluates a type-generic binary function (minimum or maximum).
///
/// Delegates to `std::min` / `std::max` (declared in `<algorithm>`), so only
/// `operator<` on `T` is used. When neither operand is less than the other,
/// `left` is returned for both MIN and MAX. In particular, for floating-point
/// `T` with a NaN operand every comparison is false, so `left` is returned
/// and the result depends on argument order.
///
/// @param fn    MIN or MAX
/// @param left  first operand
/// @param right second operand
/// @return a copy of the smaller (MIN) or larger (MAX) operand
template <typename T>
T evaluate(
    BinaryFunctionCallGeneric::Function fn,
    const T& left,
    const T& right)
{
    switch (fn) {
    case BinaryFunctionCallGeneric::Function::MIN:
        return std::min(left, right);
    case BinaryFunctionCallGeneric::Function::MAX:
        return std::max(left, right);
    }
    POLICE_UNREACHABLE();
}

} // namespace police::expressions
