#include "police/expressions/expression_normalizer.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/expression_transformer.hpp"
#include "police/expressions/numeric_operation.hpp"
#include <algorithm>
#include <memory>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

namespace police::expressions {

/// Folds a pending constant factor into a constant leaf.
///
/// With no factor pending the leaf is left untouched; otherwise `ptr` is
/// replaced by a fresh constant holding `factor * value`.
void DistributePlusOverMultiply::visit(Expression& ptr, Constant& expr)
{
    if (!factor_) {
        return;
    }
    ptr = MakeConstant()(*factor_ * Value(expr.value));
}

/// Pushes the pending constant factor (`factor_`, if any) through an
/// arithmetic node.
///
///  - ADD / SUBTRACT: the factor is distributed into both operands.
///  - MULTIPLY with a constant operand `k`: `k` is folded into the factor and
///    the node is replaced by the rewritten non-constant operand.
///  - DIVISION by a constant `k`: the factor becomes `factor / k` (or `1 / k`
///    when nothing was pending) and the node is replaced by the rewritten
///    dividend.
///  - MULTIPLY / DIVISION without such a constant: the pending factor is
///    applied to the left operand only; the right operand is rewritten with
///    no pending factor (c * (a op b) == (c * a) op b).
///  - MODULO: delegated to `fallback`.
///    NOTE(review): with a factor pending, correctness depends on `fallback`
///    applying it to the whole node — confirm.
///
/// Every path that modifies `factor_` restores it to its entry value before
/// returning.
void DistributePlusOverMultiply::visit(Expression& ptr, NumericOperation& expr)
{
    const auto old_factor = factor_;
    switch (expr.operand) {
    case NumericOperation::Operand::MODULO: fallback(ptr, expr); return;
    case NumericOperation::Operand::ADD: [[fallthrough]];
    case NumericOperation::Operand::SUBTRACT:
        // c * (a +/- b) == c * a +/- c * b: both operands inherit the factor.
        expr.left.transform(*this);
        expr.right.transform(*this);
        return;
    case NumericOperation::Operand::MULTIPLY: {
        factor_ = get_const_(expr.left);
        if (factor_.has_value()) {
            // k * b: fold k into the pending factor and splice out the node.
            if (old_factor.has_value()) {
                factor_ = old_factor.value() * factor_.value();
            }
            expr.right.transform(*this);
            ptr = expr.right;
            factor_ = old_factor;
            return;
        } else {
            // a * k: handled below, together with division by a constant.
            factor_ = get_const_(expr.right);
            if (factor_.has_value() && old_factor.has_value()) {
                factor_ = old_factor.value() * factor_.value();
            }
        }
        break;
    }
    case NumericOperation::Operand::DIVISION: {
        factor_ = get_const_(expr.right);
        if (factor_.has_value()) {
            // c * (a / k) == a * (c / k). The numerator is promoted to real_t
            // in BOTH branches. The `1 / k` branch already did this, presumably
            // because `Value` division of two integers truncates — TODO
            // confirm. Without the promotion, `6 * (x / 4)` could fold to a
            // factor of 1 instead of 1.5.
            if (old_factor.has_value()) {
                factor_ = Value(static_cast<real_t>(1.)) * old_factor.value()
                          / factor_.value();
            } else {
                factor_ = Value(static_cast<real_t>(1.)) / factor_.value();
            }
        }
        break;
    }
    }
    if (factor_.has_value()) {
        // The constant operand has been absorbed into the factor; replace the
        // node by its rewritten left operand.
        expr.left.transform(*this);
        ptr = expr.left;
    } else {
        // No constant operand: only the left side receives the outer factor.
        factor_ = old_factor;
        expr.left.transform(*this);
        factor_ = std::nullopt;
        expr.right.transform(*this);
    }
    factor_ = old_factor;
}

/// Absorbs a pending negation into a comparison by rewriting its operator.
///
/// Mappings applied when a negation is pending:
///   !(a == b) -> a != b
///   !(a != b) -> a == b
///   !(a <  b) -> b <= a   (operands swapped)
///   !(a <= b) -> b <  a   (operands swapped)
/// The operands themselves are then visited with no negation pending, and the
/// caller's negation state is restored afterwards.
void PushNegationInwards::visit(Expression& ptr, Comparison& expr)
{
    using Op = Comparison::Operator;
    const bool negate_here = is_negated_;
    if (negate_here) {
        switch (expr.op) {
        case Op::EQUAL: expr.op = Op::NOT_EQUAL; break;
        case Op::NOT_EQUAL: expr.op = Op::EQUAL; break;
        case Op::LESS:
            std::swap(expr.left, expr.right);
            expr.op = Op::LESS_EQUAL;
            break;
        case Op::LESS_EQUAL:
            std::swap(expr.left, expr.right);
            expr.op = Op::LESS;
            break;
        }
    }
    is_negated_ = false;
    ExpressionTransformer::visit(ptr, expr);
    is_negated_ = negate_here;
}

/// Applies De Morgan's law to a conjunction under a pending negation.
///
/// The base-class visit runs first with `is_negated_` unchanged, so
/// (assuming it recurses into the children) every child is negated in place.
/// If a negation is pending, the node is then rebuilt as a Disjunction of
/// those children.
void PushNegationInwards::visit(Expression& ptr, Conjunction& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    if (!is_negated_) {
        return;
    }
    ptr = Disjunction(std::move(expr.children));
}

/// Applies De Morgan's law to a disjunction under a pending negation.
///
/// The base-class visit runs first with `is_negated_` unchanged, so
/// (assuming it recurses into the children) every child is negated in place.
/// If a negation is pending, the node is then rebuilt as a Conjunction of
/// those children.
void PushNegationInwards::visit(Expression& ptr, Disjunction& expr)
{
    ExpressionTransformer::visit(ptr, expr);
    if (!is_negated_) {
        return;
    }
    ptr = Conjunction(std::move(expr.children));
}

/// Eliminates a Negation node by toggling the pending-negation flag.
///
/// The operand is visited with the flag flipped (so a double negation
/// cancels out). The caller's flag is then restored, and the Negation node
/// is spliced out: `ptr` is replaced by its rewritten operand.
void PushNegationInwards::visit(Expression& ptr, Negation& expr)
{
    const bool outer_negated = is_negated_;
    is_negated_ = !outer_negated;
    ExpressionTransformer::visit(ptr, expr);
    is_negated_ = outer_negated;
    ptr = expr.expr;
}

/// Pushes a pending negation into both branches of an if-then-else.
///
/// !(c ? a : b) == (c ? !a : !b). The branches are therefore visited with the
/// current flag, while the condition is visited with no negation pending. The
/// caller's flag is restored at the end.
void PushNegationInwards::visit(Expression&, IfThenElse& expr)
{
    const bool outer_negated = is_negated_;

    // Branches inherit whatever negation is pending.
    expr.consequence.transform(*this);
    expr.alternative.transform(*this);

    // The condition keeps its polarity.
    is_negated_ = false;
    expr.condition.transform(*this);

    is_negated_ = outer_negated;
}

/// Flattens nested connectors of the same kind and folds boolean constants.
///
/// The children are rewritten first via the base-class visit. Then each child
/// is handled as follows:
///  - a child that is itself a `Connector` is spliced in: its children are
///    hoisted into this node;
///  - a constant child equal to the absorbing element (`false` for a
///    Conjunction, `true` otherwise) replaces the whole node via `ptr`;
///  - any other constant child (the neutral element) is dropped;
///  - every other child is kept.
///
/// `expr.children` is only replaced after the scan completes. On the
/// absorbing-constant shortcut the node is therefore left intact, rather than
/// with partially moved-from children, which matters when the node is shared
/// through another `Expression` handle.
template <typename Connector>
void CollapseBooleanConnectors<Connector>::visit(
    Expression& ptr,
    Connector& expr)
{
    // The constant value that decides the whole connector on its own.
    constexpr bool absorbing = !std::is_same_v<Connector, Conjunction>;

    ExpressionTransformer::visit(ptr, expr);
    std::vector<Expression> children;
    children.reserve(expr.children.size());
    for (auto& child : expr.children) {
        if (auto con = std::dynamic_pointer_cast<Connector>(child.base())) {
            children.insert(
                children.end(),
                con->children.begin(),
                con->children.end());
        } else if (
            auto constant = std::dynamic_pointer_cast<Constant>(child.base())) {
            if (static_cast<bool>(constant->value) == absorbing) {
                ptr = child;
                return;
            }
            // Neutral element: contributes nothing, drop it.
        } else {
            // Copy the handle (not move) so `expr` stays valid if we bail out
            // on an absorbing constant later in the scan.
            children.push_back(child);
        }
    }
    expr.children.swap(children);
}

/// Normalizes `ptr` by running the member passes in a fixed order
/// (presumably instances of the transformers defined in this file — the
/// member types are declared in the header):
///  1. distr_plus_multiply_ — distribute constant factors over +/-,
///  2. push_negs_           — push negations down to the leaves,
///  3. collapse_conjs_      — flatten conjunctions and fold their constants,
///  4. collapse_disjs_      — flatten disjunctions and fold their constants.
/// The order matters: negation pushing can turn conjunctions into
/// disjunctions (and vice versa), so collapsing runs afterwards.
/// NOTE(review): each collapse pass runs once. A constant produced by the
/// disjunction pass inside a conjunction (e.g. `y && (true || x)`) appears
/// to be left unfolded as `y && true` — confirm whether a fixed-point loop
/// is wanted.
void ExpressionNormalizer::generic_visit(Expression& ptr)
{
    ptr.transform(distr_plus_multiply_);
    ptr.transform(push_negs_);
    ptr.transform(collapse_conjs_);
    ptr.transform(collapse_disjs_);
}

} // namespace police::expressions
