#pragma once

#include "police/storage/value.hpp"
#include "police/utils/hash.hpp"

#include <cstddef>
#include <iosfwd>
#include <memory>
#include <type_traits>
#include <utility>

namespace police::expressions {

class Expression;
class ExpressionTransformer;
class ExpressionVisitor;

/// Abstract base for every node of an expression tree.
///
/// Nodes are normally handled through the `Expression` wrapper, which owns
/// them via `std::shared_ptr`. The `__POLICE_IMPLEMENT_EXPR_VISITORS` macro
/// at the end of this header provides `accept`, `transform` and `is_same`
/// for a concrete subclass.
class RawExpression {
public:
    virtual ~RawExpression() = default;

    /// Dispatches to `visitor` (the macro implementation calls
    /// `visitor.visit(*this)` with the node's concrete type).
    virtual void accept(ExpressionVisitor& visitor) const = 0;

    /// Hands this node to `transform`; `ref` is presumably the Expression
    /// handle holding this node, so the transformer can replace it.
    virtual void transform(
        Expression& ref, ExpressionTransformer& transform) = 0;

    /// Structural comparison with the node held by `expr`.
    [[nodiscard]] virtual bool is_same(const Expression& expr) const = 0;

    /// Hash of this node (consumed by `police::hash<Expression>`).
    [[nodiscard]] virtual std::size_t hash() const = 0;

    /// Writes a textual representation of this node to `out`.
    virtual void dump(std::ostream& out) const = 0;

    /// Whether the node is a constant; default implementation lives
    /// outside this header.
    [[nodiscard]] virtual bool is_constant() const;

    /// Value of the node; presumably only meaningful when is_constant()
    /// holds — TODO confirm against the out-of-line definition.
    [[nodiscard]] virtual Value get_value() const;
};

/// Handle to an expression-tree node.
///
/// The node is held by `std::shared_ptr`, so copying an Expression shares
/// the same underlying RawExpression. Values, identifier names and concrete
/// RawExpression subclasses convert implicitly, which lets the operator
/// overloads declared below be applied to plain operands.
class Expression {
    using RawExpressionPtr = std::shared_ptr<RawExpression>;

public:
    /// Empty handle: base() returns a null pointer.
    Expression() = default;

    // Deliberately implicit conversions from values and names.
    Expression(Value value);
    Expression(identifier_name_t name);

    /// Wraps a concrete node by copying/moving `expr` into a freshly
    /// allocated object of its decayed type. Only participates in overload
    /// resolution when that type derives from RawExpression.
    template <
        typename T,
        std::enable_if_t<
            std::is_base_of_v<RawExpression, std::decay_t<T>>,
            int> = 0>
    Expression(T&& expr)
        : expr_{std::make_shared<std::decay_t<T>>(std::forward<T>(expr))}
    {}

    void accept(ExpressionVisitor& visitor) const;
    void transform(ExpressionTransformer& transform);

    /// Underlying node; null for a default-constructed handle.
    [[nodiscard]] const RawExpressionPtr& base() const { return expr_; }

    [[nodiscard]] bool is_same(const Expression& other) const;
    [[nodiscard]] bool is_constant() const;
    [[nodiscard]] Value get_value() const;

private:
    RawExpressionPtr expr_{};
};

// Expression builders. Each function combines its operands into a new
// Expression; all definitions live outside this header. Operands are taken
// by value, so anything implicitly convertible to Expression (a Value, an
// identifier name, a concrete RawExpression) can be passed directly.
//
// NOTE: the comparison operators (==, !=, <=, >=, <, >) build Expression
// nodes; they do NOT return bool.

// Arithmetic.
[[nodiscard]]
Expression operator+(Expression left, Expression right);
[[nodiscard]]
Expression operator-(Expression left, Expression right);
[[nodiscard]]
Expression operator*(Expression left, Expression right);
[[nodiscard]]
Expression operator/(Expression left, Expression right);
[[nodiscard]]
Expression operator%(Expression left, Expression right);

// Logical connectives.
[[nodiscard]]
Expression operator&&(Expression left, Expression right);
[[nodiscard]]
Expression operator||(Expression left, Expression right);
[[nodiscard]]
Expression operator!(Expression operand);

// Comparisons (operator form).
[[nodiscard]]
Expression operator==(Expression left, Expression right);
[[nodiscard]]
Expression operator!=(Expression left, Expression right);
[[nodiscard]]
Expression operator<=(Expression left, Expression right);
[[nodiscard]]
Expression operator>=(Expression left, Expression right);
[[nodiscard]]
Expression operator<(Expression left, Expression right);
[[nodiscard]]
Expression operator>(Expression left, Expression right);

// Comparisons (named form); presumably equivalent to the operators above.
[[nodiscard]]
Expression equal(Expression left, Expression right);
[[nodiscard]]
Expression not_equal(Expression left, Expression right);
[[nodiscard]]
Expression less(Expression left, Expression right);
[[nodiscard]]
Expression less_equal(Expression left, Expression right);
[[nodiscard]]
Expression greater(Expression left, Expression right);
[[nodiscard]]
Expression greater_equal(Expression left, Expression right);

// Conditional and extrema.
[[nodiscard]]
Expression ite(Expression cond, Expression then, Expression otherwise);
[[nodiscard]]
Expression min(Expression x, Expression y);
[[nodiscard]]
Expression max(Expression x, Expression y);

} // namespace police::expressions

namespace police {

/// Stream insertion for expressions; writes a textual form of `expr` to
/// `out` and returns `out` (definition is outside this header).
std::ostream&
operator<<(std::ostream& out, const expressions::Expression& expr);

/// Hashes an Expression by delegating to its node's RawExpression::hash().
///
/// A default-constructed Expression holds no node (base() is null).
/// Such handles hash to 0 instead of dereferencing a null pointer,
/// which was undefined behaviour. All empty handles share the same hash,
/// so the result stays consistent with any equality that treats them alike.
template <>
struct hash<expressions::Expression> {
    [[nodiscard]]
    std::size_t operator()(const expressions::Expression& expr) const
    {
        const auto& node = expr.base();
        if (node == nullptr) {
            return 0;
        }
        return node->hash();
    }
};
} // namespace police

// Defines accept(), transform() and is_same() for the RawExpression subclass
// `ClassName`. Expand it at namespace scope in a source file where
// ExpressionVisitor, ExpressionTransformer and Expression are complete.
//   - accept():    forwards to visitor.visit(*this).
//   - transform(): forwards to transform.visit(ref, *this).
//   - is_same():   true iff `expr` holds a node whose dynamic type is
//                  ClassName (or derives from it) and that compares equal
//                  via ClassName's own operator==, which must be provided.
// is_same() inspects the node through a raw pointer. Going through
// std::dynamic_pointer_cast would copy the shared_ptr and cost an atomic
// reference-count increment/decrement just to do a type check.
// NOTE(review): identifiers beginning with "__" are reserved for the C++
// implementation; the name is kept only for compatibility with callers.
#define __POLICE_IMPLEMENT_EXPR_VISITORS(ClassName)                             \
    void ClassName::accept(ExpressionVisitor& visitor) const                   \
    {                                                                          \
        visitor.visit(*this);                                                  \
    }                                                                          \
                                                                               \
    void ClassName::transform(                                                 \
        Expression& ref,                                                       \
        ExpressionTransformer& transform)                                      \
    {                                                                          \
        transform.visit(ref, *this);                                           \
    }                                                                          \
                                                                               \
    bool ClassName::is_same(const Expression& expr) const                      \
    {                                                                          \
        const auto* other =                                                    \
            dynamic_cast<const ClassName*>(expr.base().get());                 \
        return other != nullptr && (*this == *other);                          \
    }
