#include "police/linear_expression.hpp"
#include "police/base_types.hpp"
#include "police/expressions/expression_visitor.hpp"
#include "police/expressions/numeric_operation.hpp"

namespace police {

namespace {

class LinearExpressionConverter final : public expressions::ExpressionVisitor {
public:
    void visit(const expressions::BinaryFunctionCall&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::Conjunction&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::Disjunction&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::Comparison&) override { throw_unsupported(); }

    void visit(const expressions::FunctionCall&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::IdentifierReference&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::IfThenElse&) override { throw_unsupported(); }

    void visit(const expressions::Negation&) override { throw_unsupported(); }

    void visit(const expressions::Derivative&) override { throw_unsupported(); }

    void visit(const expressions::Constant& expr) override
    {
        linex.bias = static_cast<real_t>(expr.value);
    }

    void visit(const expressions::NumericOperation& expr) override
    {
        if (expr.operand == expressions::NumericOperation::Operand::MODULO) {
            throw_unsupported();
        }

        expr.left.accept(*this);
        LinearExpression left = std::move(linex);
        linex.clear();

        expr.right.accept(*this);
        LinearExpression right = std::move(linex);

        switch (expr.operand) {
        case expressions::NumericOperation::Operand::DIVISION:
            if (!linex.empty()) {
                throw_unsupported();
            }
            if (right.bias == 0.) {
                POLICE_RUNTIME_ERROR("division by 0");
            }
            left /= right.bias;
            break;
        case police::expressions::NumericOperation::Operand::MULTIPLY:
            if (!left.empty() && !right.empty()) {
                throw_unsupported();
            }
            if (left.empty()) {
                std::swap(left, right);
            }
            left *= right.bias;
            break;
        case police::expressions::NumericOperation::Operand::ADD:
            left += right;
            break;
        case police::expressions::NumericOperation::Operand::SUBTRACT:
            left -= right;
            break;
        default: POLICE_UNREACHABLE();
        }

        linex = std::move(left);
    }

    void visit(const expressions::Variable& expr) override
    {
        linex.insert(expr.var_id, 1.);
    }

    LinearExpression linex;

private:
    void throw_unsupported()
    {
        POLICE_RUNTIME_ERROR(
            "Cannot convert non linear expression to a linear expression");
    }
};

} // namespace

/// Builds a LinearExpression from an expression tree by running
/// LinearExpressionConverter over it.
///
/// Constructs the converter rejects as non-linear are reported through
/// POLICE_RUNTIME_ERROR. Terms whose coefficient ends up as zero are pruned
/// from the returned expression.
LinearExpression
LinearExpression::from_expression(const expressions::Expression& expr)
{
    LinearExpressionConverter converter;
    expr.accept(converter);

    // `linex` is a member, so it must be moved out explicitly (no NRVO).
    LinearExpression result = std::move(converter.linex);
    result.remove_zero_coefficients();
    return result;
}

/// Returns an expression with no variable terms whose bias is `constant`.
LinearExpression LinearExpression::constant(real_t constant)
{
    LinearExpression result;
    result.bias = constant;
    return result;
}

/// Returns the single-term expression `coef * x[var_id] + constant`.
LinearExpression
LinearExpression::unit(size_t var_id, real_t coef, real_t constant)
{
    // Start from the constant-only expression, then add the one variable term.
    LinearExpression result = LinearExpression::constant(constant);
    result.insert(var_id, coef);
    return result;
}

/// Writes the textual form of `expr` (as produced by to_string()) to `out`
/// and returns the same stream to allow chaining.
std::ostream& operator<<(std::ostream& out, const police::LinearExpression& expr)
{
    const auto text = expr.to_string();
    out << text;
    return out;
}

} // namespace police
