#pragma once

#include "police/expressions/boolean_combination.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/expression.hpp"
#include "police/expressions/expression_transformer.hpp"
#include "police/expressions/constant_fetcher.hpp"
#include "police/storage/value.hpp"

#include <optional>
#include <utility>

namespace police::expressions {

/// Transformer that carries a pending constant factor (`factor_`) down the
/// expression tree.
///
/// The NumericOperation and Constant overloads are defined out of line and
/// presumably consume / distribute the pending factor (TODO confirm against
/// the .cpp). Every other node kind is opaque: it is transformed with no
/// pending factor, and afterwards, if a factor was pending, the rewritten
/// node is wrapped as `factor * node` (see fallback()).
class DistributePlusOverMultiply final : public ExpressionTransformer {
public:
    // Node kinds with dedicated, out-of-line handling.
    void visit(Expression& ptr, NumericOperation& expr) override;
    void visit(Expression& ptr, Constant& expr) override;

    // Opaque node kinds: all delegate to fallback().
    void visit(Expression& ptr, IdentifierReference& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, Variable& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, FunctionCall& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, BinaryFunctionCall& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, Comparison& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, Conjunction& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, Disjunction& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, Negation& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, IfThenElse& expr) override
    {
        this->fallback(ptr, expr);
    }

    void visit(Expression& ptr, Derivative& expr) override
    {
        this->fallback(ptr, expr);
    }

private:
    /// Runs the default traversal on `expr` with no pending factor. If a
    /// factor was pending on entry, `ptr` is then replaced by
    /// `NumericOperation(MULTIPLY, <factor constant>, ptr)`. The pending
    /// factor is restored before returning, so sibling nodes see it again.
    template <typename E>
    void fallback(Expression& ptr, E&& expr)
    {
        // Detach the pending factor so the children are not scaled by it.
        std::optional<Value> pending = std::exchange(factor_, std::nullopt);

        ExpressionTransformer::visit(ptr, std::forward<E>(expr));

        if (pending) {
            ptr = NumericOperation(
                NumericOperation::Operand::MULTIPLY,
                MakeConstant()(*pending),
                ptr);
        }

        // Hand the factor back to the caller's context.
        factor_ = std::move(pending);
    }

    // Helper object; not referenced in this header, presumably used by the
    // out-of-line overloads — TODO confirm.
    ConstantFetcher get_const_;
    // Constant factor waiting to be applied to the subtree being visited.
    std::optional<Value> factor_ = std::nullopt;
};

/// Transformer that moves logical negations inwards (per its name).
///
/// The Boolean-structured node kinds (Comparison, Conjunction, Disjunction,
/// Negation, IfThenElse) have dedicated out-of-line overloads. Every other
/// node kind is treated as an atom by fallback(): it is transformed with no
/// negation pending, and if one was pending on entry the result is wrapped
/// as `!node`.
class PushNegationInwards final : public ExpressionTransformer {
public:
    void visit(Expression& ptr, Comparison& expr) override;

    void visit(Expression& ptr, Conjunction& expr) override;

    void visit(Expression& ptr, Disjunction& expr) override;

    void visit(Expression& ptr, Negation& expr) override;

    void visit(Expression& ptr, IfThenElse& expr) override;

    void visit(Expression& ptr, Constant& expr) override
    {
        fallback(ptr, expr);
    }

    void visit(Expression& ptr, IdentifierReference& expr) override
    {
        fallback(ptr, expr);
    }

    void visit(Expression& ptr, Variable& expr) override
    {
        fallback(ptr, expr);
    }

    void visit(Expression& ptr, NumericOperation& expr) override
    {
        fallback(ptr, expr);
    }

    void visit(Expression& ptr, FunctionCall& expr) override
    {
        fallback(ptr, expr);
    }

    void visit(Expression& ptr, BinaryFunctionCall& expr) override
    {
        fallback(ptr, expr);
    }

    // Derivative is an atom with respect to negation, like the other
    // non-Boolean nodes above. Without this override the base-class visit
    // would run instead. It cannot clear the private is_negated_ flag, so
    // any descendants it visits would still see a pending negation, and the
    // node itself would never be wrapped in `!`.
    void visit(Expression& ptr, Derivative& expr) override
    {
        fallback(ptr, expr);
    }

private:
    /// Visits `expr` with no negation pending. Afterwards, if a negation was
    /// pending on entry, replaces `ptr` with `!ptr`. Restores the pending
    /// state before returning so siblings are unaffected.
    template <typename E>
    void fallback(Expression& ptr, E&& expr)
    {
        // Clear the flag while descending: negation must not leak into
        // the children of an atom.
        const bool was_negated = std::exchange(is_negated_, false);
        ExpressionTransformer::visit(ptr, std::forward<E>(expr));
        if (was_negated) {
            ptr = !ptr;
        }
        is_negated_ = was_negated;
    }

    // True while a negation is waiting to be applied to the current node.
    bool is_negated_ = false;
};

/// Transformer specialised on a single Boolean connective type (Conjunction
/// or Disjunction, per the instantiations below). Only the visit overload
/// for `Connector` is overridden; every other node kind uses the default
/// ExpressionTransformer traversal. The name suggests it merges nested
/// connectives of the same kind into one node — TODO confirm against the
/// out-of-line definition of visit().
template <typename Connector>
class CollapseBooleanConnectors final : public ExpressionTransformer {
public:
    // Declared only; the definition is out of line (not visible here).
    void visit(Expression& ptr, Connector& expr) override;
};

// NOTE(review): these are explicit instantiation *definitions* in a header.
// Every translation unit that includes this file repeats them, but
// [temp.spec] permits at most one per program (no diagnostic required).
// Because visit() is only declared at this point, they do not instantiate
// visit() either. The usual pattern is `extern template class ...;` here
// plus `template class ...;` in the .cpp that defines visit(). Left unchanged
// because that .cpp is not visible from this file.
template class CollapseBooleanConnectors<Conjunction>;
using CollapseConjunctions = CollapseBooleanConnectors<Conjunction>;

template class CollapseBooleanConnectors<Disjunction>;
using CollapseDisjunctions = CollapseBooleanConnectors<Disjunction>;

class ExpressionNormalizer final : public ExpressionTransformer {
public:
    void visit(Expression& ptr, Constant&) override { generic_visit(ptr); }
    void visit(Expression& ptr, IdentifierReference&) override
    {
        generic_visit(ptr);
    }
    void visit(Expression& ptr, Variable&) override { generic_visit(ptr); }
    void visit(Expression& ptr, NumericOperation&) override
    {
        generic_visit(ptr);
    }
    void visit(Expression& ptr, FunctionCall&) override { generic_visit(ptr); }
    void visit(Expression& ptr, BinaryFunctionCall&) override
    {
        generic_visit(ptr);
    }
    void visit(Expression& ptr, Comparison&) override { generic_visit(ptr); }
    void visit(Expression& ptr, Conjunction&) override { generic_visit(ptr); }
    void visit(Expression& ptr, Disjunction&) override { generic_visit(ptr); }
    void visit(Expression& ptr, Negation&) override { generic_visit(ptr); }
    void visit(Expression& ptr, IfThenElse&) override { generic_visit(ptr); }
    void visit(Expression& ptr, Derivative&) override { generic_visit(ptr); }

private:
    void generic_visit(Expression& ptr);

    DistributePlusOverMultiply distr_plus_multiply_;
    PushNegationInwards push_negs_;
    CollapseConjunctions collapse_conjs_;
    CollapseDisjunctions collapse_disjs_;
};

} // namespace police::expressions
