#include "police/linear_constraint.hpp"
#include "police/base_types.hpp"
#include "police/expressions/comparison.hpp"
#include "police/expressions/expression_visitor.hpp"
#include "police/linear_expression.hpp"
#include "police/macros.hpp"

#include <sstream>

namespace police {

/// Creates a constraint without any terms, using relation @p t.
///
/// `rhs` is zero-initialised here, so the empty constraint reads as
/// `0 <op> 0`. Without this, correctness would depend on the header providing
/// a default member initializer for `rhs`; otherwise reading it (e.g. in
/// to_string() or as_expression()) would be undefined behaviour.
LinearConstraint::LinearConstraint(Type t)
    : rhs(0.)
    , type(t)
{
}

LinearConstraint::LinearConstraint(
    LinearCombination<size_t, real_t> lhs,
    real_t rhs,
    Type type)
    : LinearCombination<size_t, real_t>(std::move(lhs))
    , rhs(rhs)
    , type(type)
{
}

/// Converts the constraint back into a comparison expression of the form
/// `<linear combination> <op> <constant rhs>`.
expressions::Expression LinearConstraint::as_expression() const
{
    using Base = LinearCombination<size_t, real_t>;
    if (type == LESS_EQUAL) {
        return less_equal(
            Base::as_expression(),
            expressions::Constant(Value(rhs)));
    }
    if (type == GREATER_EQUAL) {
        return greater_equal(
            Base::as_expression(),
            expressions::Constant(Value(rhs)));
    }
    if (type == EQUAL) {
        return equal(
            Base::as_expression(),
            expressions::Constant(Value(rhs)));
    }
    // Any other value of `type` is treated as unreachable.
    POLICE_UNREACHABLE();
}

/// Builds `lhs <= rhs`. The bias (constant term) of @p lhs is moved to the
/// right-hand side, giving `<terms of lhs> <= rhs - lhs.bias`.
LinearConstraint less_equal(const LinearExpression& lhs, real_t rhs)
{
    const real_t shifted_rhs = rhs - lhs.bias;
    return LinearConstraint(lhs, shifted_rhs, LinearConstraint::LESS_EQUAL);
}

/// Builds `lhs <= rhs` by rewriting it as `lhs - rhs <= 0` and delegating to
/// the scalar overload.
LinearConstraint
less_equal(const LinearExpression& lhs, const LinearExpression& rhs)
{
    const LinearExpression difference = lhs - rhs;
    return less_equal(difference, 0.);
}

/// Builds `lhs >= rhs`. The bias of @p lhs is moved to the right-hand side,
/// giving `<terms of lhs> >= rhs - lhs.bias`.
LinearConstraint greater_equal(const LinearExpression& lhs, real_t rhs)
{
    const real_t shifted_rhs = rhs - lhs.bias;
    return LinearConstraint(lhs, shifted_rhs, LinearConstraint::GREATER_EQUAL);
}

/// Builds `lhs == rhs` by rewriting it as `lhs - rhs == 0`.
///
/// Delegates to the scalar overload, mirroring
/// `less_equal(const LinearExpression&, const LinearExpression&)`. The
/// resulting right-hand side is therefore `0. - bias` instead of `-bias`.
/// With a floating-point real_t and a zero bias, this yields `+0` rather
/// than `-0`, which to_string() would otherwise print as "-0".
LinearConstraint equal(const LinearExpression& lhs, const LinearExpression& rhs)
{
    return equal(lhs - rhs, 0.);
}

/// Builds `lhs == rhs`. The bias of @p lhs is moved to the right-hand side,
/// giving `<terms of lhs> == rhs - lhs.bias`.
LinearConstraint equal(const LinearExpression& lhs, real_t rhs)
{
    const real_t shifted_rhs = rhs - lhs.bias;
    return LinearConstraint(lhs, shifted_rhs, LinearConstraint::EQUAL);
}

/// Renders the constraint as `<terms> <op> <rhs>`. A constraint with no
/// terms prints "0" in place of its (empty) left-hand side.
std::string LinearConstraint::to_string() const
{
    const char* relation = "";
    switch (type) {
    case LESS_EQUAL: relation = " <= "; break;
    case GREATER_EQUAL: relation = " >= "; break;
    case EQUAL: relation = " == "; break;
    }
    std::ostringstream out;
    out << LinearCombination<size_t, real_t>::to_string()
        << (size() == 0u ? "0" : "")
        << relation
        << rhs;
    return out.str();
}

/// Returns this constraint multiplied by -1. Coefficients and rhs are
/// negated, `<=` and `>=` are swapped, and `==` is left unchanged.
LinearConstraint LinearConstraint::operator-() const
{
    LinearConstraint negated(*this);
    negated.scale_coefs(-1.);
    negated.rhs *= -1.;
    if (negated.type == LinearConstraint::LESS_EQUAL) {
        negated.type = LinearConstraint::GREATER_EQUAL;
    } else if (negated.type == LinearConstraint::GREATER_EQUAL) {
        negated.type = LinearConstraint::LESS_EQUAL;
    }
    return negated;
}

namespace {

class LinearConstraintConverter final : public expressions::ExpressionVisitor {
public:
    void visit(const expressions::BinaryFunctionCall&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::Conjunction&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::Disjunction&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::FunctionCall&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::IdentifierReference&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::IfThenElse&) override { throw_unsupported(); }

    void visit(const expressions::Negation&) override { throw_unsupported(); }

    void visit(const expressions::Derivative&) override { throw_unsupported(); }

    void visit(const expressions::Constant&) override { throw_unsupported(); }

    void visit(const expressions::NumericOperation&) override
    {
        throw_unsupported();
    }

    void visit(const expressions::Variable&) override { throw_unsupported(); }

    void visit(const expressions::Comparison& expr) override
    {
        LinearExpression left = LinearExpression::from_expression(expr.left);
        LinearExpression right = LinearExpression::from_expression(expr.right);
        left -= right;
        switch (expr.op) {
        case expressions::Comparison::Operator::EQUAL:
            constraint = equal(left, 0.);
            break;
        case expressions::Comparison::Operator::LESS_EQUAL:
            constraint = less_equal(left, 0.);
            break;
        default:
            POLICE_RUNTIME_ERROR("Linear constraint doesn't support strict "
                                "inequality and unequal.");
        }
    }

    LinearConstraintConverter()
        : constraint(LinearConstraint::EQUAL)
    {
    }

    LinearConstraint constraint;

private:
    void throw_unsupported()
    {
        POLICE_RUNTIME_ERROR(
            "Cannot convert non linear constraint to a linear constraint");
    }
};

} // namespace

/// Converts @p expr into a LinearConstraint using LinearConstraintConverter.
/// Only `==` and `<=` comparisons whose sides LinearExpression can parse are
/// accepted; the converter raises a runtime error for anything else.
LinearConstraint
LinearConstraint::from_expression(const expressions::Expression& expr)
{
    LinearConstraintConverter converter;
    expr.accept(converter);
    // `constraint` is a member of a local object rather than a local itself,
    // so implicit move-on-return does not apply; move explicitly.
    return std::move(converter.constraint);
}

/// Builds the single-variable constraint `1 * var(var_id) <type> value`.
LinearConstraint
LinearConstraint::unit_constraint(size_t var_id, Type type, real_t value)
{
    LinearConstraint unit(type);
    unit.rhs = value;
    unit.insert(var_id, 1.);
    return unit;
}

/// Writes the textual form produced by LinearConstraint::to_string() to
/// @p out and returns the same stream.
std::ostream&
operator<<(std::ostream& out, const police::LinearConstraint& constraint)
{
    const std::string text = constraint.to_string();
    out << text;
    return out;
}

} // namespace police
