#include "police/variable_substitution.hpp"

#include <algorithm>
#include <cassert>
#include <iterator>

namespace police {

namespace {

/// Remap the variable indices of a linear combination.
///
/// Each term `(var, coef)` of @p combo is inserted into the result as
/// `(vars[var], coef)`, i.e. @p vars maps old variable indices to new ones.
///
/// @param combo Combination whose variable indices are remapped.
/// @param vars  Mapping old index -> new index. It must cover every variable
///              index occurring in @p combo; this is only checked by `assert`
///              (i.e. when NDEBUG is not defined).
/// @return A new combination holding the remapped terms.
///
/// NOTE(review): if two old variables map to the same new index, the result
/// depends on LinearCombination::insert semantics (merge vs. overwrite) —
/// confirm this is intended.
///
/// (Renamed from `__substitute_vars`: identifiers containing a double
/// underscore are reserved to the implementation in C++.)
LinearCombination<size_t, real_t> substitute_combination(
    const LinearCombination<size_t, real_t>& combo,
    const vector<size_t>& vars)
{
    LinearCombination<size_t, real_t> result;
    result.reserve(combo.size());
    for (const auto& elem : combo) {
        assert(elem.first < vars.size());
        result.insert(vars[elem.first], elem.second);
    }
    return result;
}

} // namespace

/// Apply the variable renaming @p vars to a linear constraint.
///
/// The combination part of @p constraint is remapped; its `rhs` and `type`
/// fields are copied unchanged.
LinearConstraint
substitute_vars(const LinearConstraint& constraint, const vector<size_t>& vars)
{
    return {
        substitute_combination(constraint, vars),
        constraint.rhs,
        constraint.type};
}

/// Apply the variable renaming @p vars to a linear expression.
///
/// The combination part of @p expression is remapped; its `bias` is copied
/// unchanged.
LinearExpression
substitute_vars(const LinearExpression& expression, const vector<size_t>& vars)
{
    return {substitute_combination(expression, vars), expression.bias};
}

namespace {
/// Apply `substitute_vars` element-wise to a container.
///
/// @tparam C Container type offering size(), reserve(), iteration and
///           push_back(); each element type needs a `substitute_vars`
///           overload.
/// @param collection Container whose elements are remapped.
/// @param vars       Mapping from old variable index to new variable index.
/// @return A new container of type C whose elements are the remapped
///         elements of @p collection, in the same order.
template <typename C>
C substitute_vars_collection(const C& collection, const vector<size_t>& vars)
{
    C substituted;
    substituted.reserve(collection.size());
    for (const auto& item : collection) {
        substituted.push_back(substitute_vars(item, vars));
    }
    return substituted;
}
} // namespace

/// Apply the variable renaming @p vars to every element of a disjunction,
/// preserving element order.
LinearConstraintDisjunction substitute_vars(
    const LinearConstraintDisjunction& disjunction,
    const vector<size_t>& vars)
{
    return substitute_vars_collection<LinearConstraintDisjunction>(
        disjunction, vars);
}

/// Apply the variable renaming @p vars to every element of a conjunction,
/// preserving element order.
LinearConstraintConjunction substitute_vars(
    const LinearConstraintConjunction& conjunction,
    const vector<size_t>& vars)
{
    return substitute_vars_collection<LinearConstraintConjunction>(
        conjunction, vars);
}

/// Apply the variable renaming @p vars element-wise to a linear condition,
/// preserving element order.
LinearCondition
substitute_vars(const LinearCondition& condition, const vector<size_t>& vars)
{
    return substitute_vars_collection<LinearCondition>(condition, vars);
}

} // namespace police
