#include "police/linear_condition.hpp"
#include "police/expressions/disjunctive_normal_form.hpp"
#include "police/expressions/expression_visitor.hpp"
#include "police/expressions/expressions.hpp"

#include <algorithm>
#include <cassert>
#include <iostream>
#include <iterator>
#include <sstream>
#include <string>
#include <utility>

namespace police {

/// Appends a single constraint to this conjunction (logical AND).
/// @param constraint the constraint to add; taken by value and moved in.
/// @return a reference to this conjunction, so updates can be chained.
LinearConstraintConjunction&
LinearConstraintConjunction::operator&=(LinearConstraint constraint)
{
    this->push_back(std::move(constraint));
    return *this;
}

/// Appends every constraint of @p conj to this conjunction (logical AND).
///
/// Self-conjunction (`c &= c`) is supported. In that case the source is
/// snapshotted first, because range insertion into a vector (presumably the
/// std::vector base of this type) must not be given iterators into the
/// destination container itself.
///
/// @param conj the constraints to append, in order.
/// @return a reference to this conjunction, so updates can be chained.
LinearConstraintConjunction&
LinearConstraintConjunction::operator&=(const LinearConstraintConjunction& conj)
{
    if (&conj == this) {
        const LinearConstraintConjunction snapshot(conj);
        insert(end(), snapshot.begin(), snapshot.end());
    } else {
        insert(end(), conj.begin(), conj.end());
    }
    return *this;
}

/// Builds an expression tree equivalent to this conjunction.
/// @return an expressions::Conjunction with one child per constraint, in
///         order.
expressions::Expression LinearConstraintConjunction::as_expression() const
{
    vector<expressions::Expression> conjuncts;
    conjuncts.reserve(size());
    for (const auto& constraint : *this) {
        conjuncts.push_back(constraint.as_expression());
    }
    return expressions::Conjunction(std::move(conjuncts));
}

std::string LinearConstraintConjunction::to_string(size_t ws) const
{
    if (size() == 0u) {
        return std::string(ws, ' ') + "TRUE";
    } else {
        std::ostringstream oss;
        bool first = true;
        for (const auto& c : *this) {
            oss << std::string(ws, ' ') << (first ? "    " : "AND ") << c
                << "\n";
            first = false;
        }
        return oss.str();
    }
}

/// Appends a single constraint to this disjunction (logical OR).
/// @param constraint the constraint to add; taken by value and moved in.
/// @return a copy of the updated disjunction.
/// NOTE(review): unlike LinearConstraintConjunction::operator&= and
/// LinearCondition::operator|=, this returns by value rather than by
/// reference, so `(d |= x) |= y` would modify a temporary copy. Fixing it
/// requires changing the declaration in the header as well.
LinearConstraintDisjunction
LinearConstraintDisjunction::operator|=(LinearConstraint constraint)
{
    this->push_back(std::move(constraint));
    return *this;
}

/// Appends every constraint of @p disj to this disjunction (logical OR).
///
/// Self-union (`d |= d`) is supported. In that case the source is snapshotted
/// first, because range insertion into a vector (presumably the std::vector
/// base of this type) must not be given iterators into the destination
/// container itself.
///
/// @param disj the disjuncts to append, in order.
/// @return a copy of the updated disjunction.
/// NOTE(review): unlike LinearConstraintConjunction::operator&= and
/// LinearCondition::operator|=, this returns by value rather than by
/// reference. Fixing it requires changing the header declaration too.
LinearConstraintDisjunction
LinearConstraintDisjunction::operator|=(const LinearConstraintDisjunction& disj)
{
    if (&disj == this) {
        const LinearConstraintDisjunction snapshot(disj);
        insert(end(), snapshot.begin(), snapshot.end());
    } else {
        insert(end(), disj.begin(), disj.end());
    }
    return *this;
}

/// Builds an expression tree equivalent to this disjunction.
/// @return an expressions::Disjunction with one child per constraint, in
///         order.
expressions::Expression LinearConstraintDisjunction::as_expression() const
{
    vector<expressions::Expression> disjuncts;
    disjuncts.reserve(size());
    for (const auto& constraint : *this) {
        disjuncts.push_back(constraint.as_expression());
    }
    return expressions::Disjunction(std::move(disjuncts));
}

std::string LinearConstraintDisjunction::to_string(size_t ws) const
{
    if (size() == 0u) {
        return std::string(ws, ' ') + "FALSE";
    } else {
        std::ostringstream oss;
        bool first = true;
        for (const auto& c : *this) {
            oss << std::string(ws, ' ') << (first ? "    " : "OR  ") << c
                << "\n";
            first = false;
        }
        return oss.str();
    }
}

/// Appends every conjunction of @p other to this condition (logical OR of
/// two DNF conditions).
///
/// Self-union (`c |= c`) is supported. In that case the source is snapshotted
/// first, because range insertion into a vector (presumably the std::vector
/// base of this type) must not be given iterators into the destination
/// container itself.
///
/// @param other the conjunctions to append, in order.
/// @return a reference to this condition, so updates can be chained.
LinearCondition& LinearCondition::operator|=(const LinearCondition& other)
{
    if (&other == this) {
        const LinearCondition snapshot(other);
        insert(end(), snapshot.begin(), snapshot.end());
    } else {
        insert(end(), other.begin(), other.end());
    }
    return *this;
}

/// Adds a single constraint as a new one-constraint conjunction (logical OR).
/// @param other the constraint to add; taken by value and moved in.
/// @return a reference to this condition, so updates can be chained.
LinearCondition& LinearCondition::operator|=(LinearConstraint other)
{
    LinearConstraintConjunction singleton;
    singleton.push_back(std::move(other));
    push_back(std::move(singleton));
    return *this;
}

/// Adds every disjunct of @p other to this condition. Each disjunct becomes
/// its own one-constraint conjunction, appended in order.
/// @param other the disjunction whose constraints are added.
/// @return a reference to this condition, so updates can be chained.
LinearCondition&
LinearCondition::operator|=(const LinearConstraintDisjunction& other)
{
    std::for_each(other.begin(), other.end(), [this](const auto& constraint) {
        this->operator|=(constraint);
    });
    return *this;
}

/// Adds a whole conjunction as a new disjunct of this condition.
/// @param other the conjunction to add; taken by value and moved in.
/// @return a reference to this condition, so updates can be chained.
LinearCondition& LinearCondition::operator|=(LinearConstraintConjunction other)
{
    emplace_back(std::move(other));
    return *this;
}

/// Builds an expression tree equivalent to this DNF condition.
/// @return an expressions::Disjunction with one child per conjunction, where
///         each child is that conjunction's own expression.
expressions::Expression LinearCondition::as_expression() const
{
    vector<expressions::Expression> disjuncts;
    disjuncts.reserve(size());
    for (const auto& conjunction : *this) {
        disjuncts.push_back(conjunction.as_expression());
    }
    return expressions::Disjunction(std::move(disjuncts));
}

std::string LinearCondition::to_string(size_t ws) const
{
    if (size() == 0u) {
        return std::string(ws, ' ') + "TRUE";
    } else {
        std::ostringstream oss;
        for (auto i = 0u; i < size(); ++i) {
            if (i == 0) {
                oss << std::string(ws, ' ') << "[" << "\n";
            } else {
                oss << std::string(ws, ' ') << "] OR [\n";
            }
            oss << at(i).to_string(ws + 2);
        }
        oss << "]\n";
        return oss.str();
    }
}

/// Conjunction of two conjunctions: the constraints of @p a followed by
/// those of @p b.
LinearConstraintConjunction operator&&(
    const LinearConstraintConjunction& a,
    const LinearConstraintConjunction& b)
{
    LinearConstraintConjunction combined(a);
    combined &= b;
    return combined;
}

/// Conjunction of a constraint with a conjunction. The result holds the
/// constraints of @p b followed by @p a.
LinearConstraintConjunction
operator&&(const LinearConstraint& a, const LinearConstraintConjunction& b)
{
    LinearConstraintConjunction combined(b);
    combined &= a;
    return combined;
}

/// Conjunction of a conjunction with a constraint: the constraints of @p a
/// followed by @p b.
LinearConstraintConjunction
operator&&(const LinearConstraintConjunction& a, const LinearConstraint& b)
{
    LinearConstraintConjunction combined(a);
    combined &= b;
    return combined;
}

/// Disjunction of two disjunctions: the disjuncts of @p a followed by those
/// of @p b.
LinearConstraintDisjunction operator||(
    const LinearConstraintDisjunction& a,
    const LinearConstraintDisjunction& b)
{
    LinearConstraintDisjunction combined(a);
    combined |= b;
    return combined;
}

/// Disjunction of a constraint with a disjunction. The result holds the
/// disjuncts of @p b followed by @p a.
LinearConstraintDisjunction
operator||(const LinearConstraint& a, const LinearConstraintDisjunction& b)
{
    LinearConstraintDisjunction combined(b);
    combined |= a;
    return combined;
}

/// Disjunction of a disjunction with a constraint: the disjuncts of @p a
/// followed by @p b.
LinearConstraintDisjunction
operator||(const LinearConstraintDisjunction& a, const LinearConstraint& b)
{
    LinearConstraintDisjunction combined(a);
    combined |= b;
    return combined;
}

/// Disjunction of two DNF conditions: the conjunctions of @p a followed by
/// those of @p b.
LinearCondition operator||(const LinearCondition& a, const LinearCondition& b)
{
    LinearCondition combined(a);
    combined |= b;
    return combined;
}

/// Disjunction of a constraint with a DNF condition. The result holds the
/// conjunctions of @p b followed by a one-constraint conjunction containing
/// @p a.
LinearCondition operator||(const LinearConstraint& a, const LinearCondition& b)
{
    LinearCondition combined(b);
    combined |= a;
    return combined;
}

/// Disjunction of a DNF condition with a constraint. The result holds the
/// conjunctions of @p a followed by a one-constraint conjunction containing
/// @p b.
LinearCondition operator||(const LinearCondition& a, const LinearConstraint& b)
{
    LinearCondition combined(a);
    combined |= b;
    return combined;
}

/// Disjunction of a conjunction with a DNF condition. The result holds the
/// conjunctions of @p b followed by @p a.
LinearCondition
operator||(const LinearConstraintConjunction& a, const LinearCondition& b)
{
    LinearCondition combined(b);
    combined |= a;
    return combined;
}

/// Disjunction of a DNF condition with a conjunction: the conjunctions of
/// @p a followed by @p b.
LinearCondition
operator||(const LinearCondition& a, const LinearConstraintConjunction& b)
{
    LinearCondition combined(a);
    combined |= b;
    return combined;
}

/// Disjunction of a disjunction with a DNF condition. The result holds the
/// conjunctions of @p b followed by one one-constraint conjunction per
/// disjunct of @p a.
LinearCondition
operator||(const LinearConstraintDisjunction& a, const LinearCondition& b)
{
    LinearCondition combined(b);
    combined |= a;
    return combined;
}

/// Disjunction of a DNF condition with a disjunction. The result holds the
/// conjunctions of @p a followed by one one-constraint conjunction per
/// disjunct of @p b.
LinearCondition
operator||(const LinearCondition& a, const LinearConstraintDisjunction& b)
{
    LinearCondition combined(a);
    combined |= b;
    return combined;
}

namespace {
/// Appends to @p result disjuncts whose disjunction is the negation of
/// @p constraint.
///
/// Reading a constraint as `lhs <type> rhs` (assumed from the `rhs` field),
/// strict inequalities are expressed with non-strict ones by shifting rhs by
/// one:
///   not (lhs == c)  ->  lhs <= c - 1   or   lhs >= c + 1
///   not (lhs >= c)  ->  lhs <= c - 1
///   otherwise       ->  lhs >= c + 1   (i.e. the negation of lhs <= c)
/// NOTE(review): the +/-1 shift is only exact if lhs can take integer values
/// only. TODO confirm this holds for every caller.
/// NOTE(review): any type other than EQUAL and GREATER_EQUAL is handled as
/// LESS_EQUAL. Revisit if LinearConstraint::Type gains further values.
void negate(LinearConstraintDisjunction& result, LinearConstraint constraint)
{
    if (constraint.type == LinearConstraint::Type::EQUAL) {
        LinearConstraint below(constraint);
        below.type = LinearConstraint::LESS_EQUAL;
        below.rhs -= 1;
        result.push_back(std::move(below));
        constraint.type = LinearConstraint::GREATER_EQUAL;
        constraint.rhs += 1.;
        result.push_back(std::move(constraint));
        return;
    }
    const bool is_lower_bound =
        constraint.type == LinearConstraint::GREATER_EQUAL;
    if (is_lower_bound) {
        constraint.rhs += -1;
        constraint.type = LinearConstraint::LESS_EQUAL;
    } else {
        constraint.rhs += 1;
        constraint.type = LinearConstraint::GREATER_EQUAL;
    }
    result.push_back(std::move(constraint));
}
} // namespace

/// Negates a conjunction via De Morgan's law: not (c1 and ... and cn)
/// becomes (not c1) or ... or (not cn).
///
/// Each constraint is negated by the file-local negate() helper, which emits
/// two disjuncts for an equality and one otherwise. An empty conjunction
/// (TRUE) yields an empty disjunction (FALSE).
///
/// @param conj the conjunction to negate.
/// @return the negated condition as a disjunction of constraints.
LinearConstraintDisjunction operator!(const LinearConstraintConjunction& conj)
{
    LinearConstraintDisjunction result;
    // Lower bound only: equalities expand into two disjuncts.
    result.reserve(conj.size());
    for (const auto& con : conj) {
        // `conj` is const, so std::move on its elements would silently copy.
        // Pass the copy explicitly; negate() takes its argument by value.
        negate(result, con);
    }
    return result;
}

/// Negates a DNF condition. By De Morgan's law, the negation of a disjunction
/// of conjunctions is a conjunction (returned as a vector) of disjunctions.
/// There is one disjunction per conjunction of @p cond, each being that
/// conjunction's negation.
/// @param cond the condition to negate.
/// @return the clauses of the negation, in the order of @p cond's conjunctions.
vector<LinearConstraintDisjunction> operator!(const LinearCondition& cond)
{
    vector<LinearConstraintDisjunction> clauses;
    clauses.reserve(cond.size());
    for (const LinearConstraintConjunction& conj : cond) {
        clauses.push_back(!conj);
    }
    return clauses;
}

namespace {

/// Visitor converting an expression (expected to already be in disjunctive
/// normal form) into a LinearCondition, accumulated in the public member
/// `cond`.
///
/// Accepted node kinds:
///   - Comparison: becomes a single one-constraint conjunction.
///   - Conjunction: each child is converted with
///     LinearConstraint::from_expression into one conjunction.
///   - Disjunction: each child is visited on its own and the results are
///     concatenated.
///   - Constant: true yields one empty conjunction (TRUE); false clears the
///     condition (FALSE).
/// Every other node kind is reported via throw_unsupported().
class LinearConditionConverter final : public expressions::ExpressionVisitor {
public:
    void visit(const expressions::BinaryFunctionCall& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::FunctionCall& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::IdentifierReference& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::IfThenElse& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::Negation& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::Derivative& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::NumericOperation& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::Variable& e) override
    {
        throw_unsupported(e);
    }

    void visit(const expressions::Comparison& expr) override
    {
        assert(cond.empty());
        cond |= LinearConstraint::from_expression(expr);
    }

    void visit(const expressions::Conjunction& expr) override
    {
        assert(cond.empty());
        LinearConstraintConjunction conj;
        for (const auto& child : expr.children) {
            conj.push_back(LinearConstraint::from_expression(child));
        }
        cond |= std::move(conj);
    }

    void visit(const expressions::Disjunction& expr) override
    {
        assert(cond.empty());
        LinearCondition nf;
        for (const auto& child : expr.children) {
            // Convert the child into the (empty) `cond`, append the result
            // to the accumulated normal form, then reset for the next child.
            // operator|=(const LinearCondition&) copies, so moving `cond`
            // here would be misleading.
            child.accept(*this);
            nf |= cond;
            cond.clear();
        }
        cond = std::move(nf);
    }

    void visit(const expressions::Constant& e) override
    {
        if (static_cast<bool>(e.value)) {
            cond |= LinearConstraintConjunction();
        } else {
            cond.clear();
        }
    }

    /// The condition built by the most recent accept() on this visitor.
    LinearCondition cond;

private:
    /// Logs @p expr to stderr and raises a runtime error via
    /// POLICE_RUNTIME_ERROR.
    void throw_unsupported(const expressions::Expression& expr)
    {
        // std::cerr is unbuffered, so an explicit flush via std::endl is
        // redundant.
        std::cerr << expr << " is not a linear expression" << '\n';
        POLICE_RUNTIME_ERROR(
            "Cannot convert non linear condition to a linear condition");
    }
};

} // namespace

/// Converts an arbitrary boolean expression into a LinearCondition.
///
/// The expression is first rewritten with
/// expressions::to_disjunctive_normal_form. The result is then translated by
/// the file-local LinearConditionConverter, which raises a runtime error for
/// any sub-expression it cannot represent linearly.
///
/// @param expr the expression to convert.
/// @return the equivalent disjunction of linear-constraint conjunctions.
LinearCondition
LinearCondition::from_expression(const expressions::Expression& expr)
{
    const expressions::Expression in_dnf =
        expressions::to_disjunctive_normal_form(expr);
    LinearConditionConverter converter;
    in_dnf.accept(converter);
    // `cond` is a data member, so `return converter.cond;` would copy it.
    // The converter is about to be destroyed, so moving is safe.
    return std::move(converter.cond);
}

/// Streams the default-indented text form of a disjunction (see to_string).
std::ostream&
operator<<(std::ostream& out, const police::LinearConstraintDisjunction& disj)
{
    const std::string text = disj.to_string();
    out << text;
    return out;
}

/// Streams the default-indented text form of a conjunction (see to_string).
std::ostream&
operator<<(std::ostream& out, const police::LinearConstraintConjunction& conj)
{
    const std::string text = conj.to_string();
    out << text;
    return out;
}

/// Streams the default-indented text form of a DNF condition (see to_string).
std::ostream& operator<<(std::ostream& out, const police::LinearCondition& cond)
{
    const std::string text = cond.to_string();
    out << text;
    return out;
}

} // namespace police
