#include "police/expressions/disjunctive_normal_form.hpp"
#include "police/expressions/boolean_combination.hpp"
#include "police/expressions/constants.hpp"
#include "police/expressions/expression_transformer.hpp"
#include "police/expressions/expression_normalizer.hpp"
#include "police/macros.hpp"
#include "police/utils/algorithms.hpp"

#include <memory>

namespace police::expressions {

namespace {

/// Expression transformer that distributes conjunctions over disjunctions and
/// leaves the rewritten expression in `result`.
///
/// Constants, identifier references, variables, numeric operations, function
/// calls, binary function calls, comparisons, negations and derivatives are
/// treated as atoms: they are copied into `result` unchanged. If-then-else
/// expressions are rejected with a runtime error.
///
/// NOTE(review): negations are not pushed inwards, so the output is only a
/// DNF if the input is already in negation normal form. Confirm that callers
/// guarantee this.
class DisjunctionPuller final : public ExpressionTransformer {
public:
    void visit(Expression& ptr, Constant&) override { result = ptr; }

    void visit(Expression& ptr, IdentifierReference&) override { result = ptr; }

    void visit(Expression& ptr, Variable&) override { result = ptr; }

    void visit(Expression& ptr, NumericOperation&) override { result = ptr; }

    void visit(Expression& ptr, FunctionCall&) override { result = ptr; }

    void visit(Expression& ptr, BinaryFunctionCall&) override { result = ptr; }

    void visit(Expression& ptr, Comparison&) override { result = ptr; }

    /// Stores in `result` the DNF of
    ///     base_conjuncts && D_1 && ... && D_k,
    /// where D_i is the disjunction of the expressions in `disjuncts[i]`.
    ///
    /// If there are no disjunctions, the result is the plain Conjunction of
    /// `base_conjuncts`, which is moved from. Otherwise each combination that
    /// picks one disjunct from every D_i yields one term
    ///     base_conjuncts && pick_1 && ... && pick_k.
    /// A single term is stored as is. Several terms are wrapped in a
    /// Disjunction.
    ///
    /// Every `disjuncts[i]` is expected to be non-empty. The Conjunction
    /// visitor short-circuits empty disjunctions to `false` before calling
    /// this.
    void multiply(
        vector<Expression>& base_conjuncts,
        vector<vector<Expression>>& disjuncts)
    {
        if (disjuncts.empty()) {
            result = Conjunction(std::move(base_conjuncts));
            return;
        }
        vector<Expression> terms;
        // `product` presumably enumerates the Cartesian product of
        // `disjuncts`, passing one pointer per inner list (see
        // police/utils/algorithms.hpp).
        product(
            disjuncts,
            [&base_conjuncts, &terms](const vector<const Expression*>& choice) {
                // The non-disjunctive conjuncts belong in every term:
                // a && (b || c) == (a && b) || (a && c).
                vector<const Expression*> parts;
                for (const Expression& c : base_conjuncts) {
                    parts.push_back(&c);
                }
                for (const Expression* d : choice) {
                    parts.push_back(d);
                }
                assert(!parts.empty());
                Expression term = *parts.front();
                for (auto i = 1u; i < parts.size(); ++i) {
                    term = term && *parts[i];
                }
                terms.push_back(std::move(term));
            });
        if (terms.size() == 1) {
            result = std::move(terms[0]);
        } else {
            result = Disjunction(std::move(terms));
        }
    }

    /// Rewrites each child, then flattens nested conjunctions and distributes
    /// the conjunction over any child that became a disjunction.
    void visit(Expression&, Conjunction& expr) override
    {
        vector<Expression> conjuncts;
        vector<vector<Expression>> disjuncts;
        for (auto& x : expr.children) {
            x.transform(*this);
            auto dis = std::dynamic_pointer_cast<Disjunction>(result.base());
            if (dis != nullptr) {
                // An empty disjunction is false, so the whole conjunction is.
                if (dis->children.empty()) {
                    result = Constant(Value(false));
                    return;
                }
                disjuncts.push_back(dis->children);
            } else {
                auto con =
                    std::dynamic_pointer_cast<Conjunction>(result.base());
                if (con != nullptr) {
                    // Flatten nested conjunctions into this one.
                    conjuncts.insert(
                        conjuncts.end(),
                        con->children.begin(),
                        con->children.end());
                } else {
                    conjuncts.push_back(std::move(result));
                }
            }
        }
        multiply(conjuncts, disjuncts);
    }

    /// Rewrites each child and flattens nested disjunctions into a single
    /// Disjunction.
    void visit(Expression&, Disjunction& expr) override
    {
        vector<Expression> disjs;
        for (auto& x : expr.children) {
            x.transform(*this);
            auto dis = std::dynamic_pointer_cast<Disjunction>(result.base());
            if (dis != nullptr) {
                disjs.insert(
                    disjs.end(),
                    dis->children.begin(),
                    dis->children.end());
            } else {
                disjs.push_back(std::move(result));
            }
        }
        result = Disjunction(std::move(disjs));
    }

    void visit(Expression& ptr, Negation&) override { result = ptr; }

    void visit(Expression&, IfThenElse&) override
    {
        POLICE_RUNTIME_ERROR("if-then-else not supported by DNF transformer");
    }

    void visit(Expression& ptr, Derivative&) override { result = ptr; }

    /// Output of the most recent visit.
    Expression result;
};

} // namespace

/// Converts `expr` into disjunctive normal form.
///
/// First distributes conjunctions over disjunctions (DisjunctionPuller), then
/// runs ExpressionNormalizer over the rewritten expression before returning
/// it. If-then-else sub-expressions cause a runtime error.
expressions::Expression to_disjunctive_normal_form(expressions::Expression expr)
{
    // Phase 1: pull disjunctions to the top of the expression tree.
    DisjunctionPuller puller;
    expr.transform(puller);

    // Phase 2: normalize the distributed expression.
    ExpressionNormalizer normalizer;
    puller.result.transform(normalizer);

    return puller.result;
}

} // namespace police::expressions
