#include "police/verifiers/ic3/syntactic/variable_classification.hpp"

#include <algorithm>
#include <iterator>

namespace police::ic3::syntactic {

/**
 * Syntactically classify every state variable of @p model by how it occurs in
 * action guards and assignment effects.
 *
 * Every linear constraint in every action guard is scanned, and for each of
 * its terms `a*x` the variable `x` records:
 *  - an upper-bound occurrence for `a*x <= c` with a > 0, or `a*x >= c` with
 *    a < 0;
 *  - a lower-bound occurrence for `a*x <= c` with a < 0, or `a*x >= c` with
 *    a > 0;
 *  - a multi-variable occurrence if the constraint has more than one term.
 * Equality constraints and zero coefficients count toward neither bound
 * direction. Every assignment to `x` whose value expression has a non-zero
 * `size()` is counted as a non-constant assignment.
 * NOTE(review): this presumably means "the right-hand side mentions at least
 * one variable"; confirm against the assignment type.
 *
 * @param model  Model whose variables are classified; every variable id used
 *               in guards/assignments must index into `model.variables`.
 * @return One category per variable, in the order of `model.variables`:
 *  - CATEGORIAL: never assigned a non-constant value, only appears in
 *    single-variable constraints, and has no bound occurrence at all;
 *  - LOWER_BOUNDED: otherwise, if no guard bounds it from above;
 *  - UPPER_BOUNDED: otherwise, if no guard bounds it from below;
 *  - GENERIC: bounded from both directions.
 */
vector<VariableCategory> classify_variables(const Model& model)
{
    // Occurrence counters collected for a single variable over the model.
    struct Info {
        size_t leq = 0;              // guard occurrences bounding it from above
        size_t geq = 0;              // guard occurrences bounding it from below
        size_t non_unit = 0;         // occurrences in multi-variable constraints
        size_t non_const_assign = 0; // assignments with a non-constant value
    };

    vector<Info> infos(model.variables.size());
    for (const auto& action : model.actions) {
        for (const auto& cond : action.guard) {
            // Properties of the whole constraint, identical for each term.
            const bool is_leq = cond.type == LinearConstraint::LESS_EQUAL;
            const bool is_geq = cond.type == LinearConstraint::GREATER_EQUAL;
            const bool multi_var = cond.size() != 1u;
            for (const auto& [var, coef] : cond) {
                Info& info = infos[var];
                // The sign of the coefficient flips which side the inequality
                // bounds the variable on.
                info.leq += (is_leq && coef > 0.) || (is_geq && coef < 0.);
                info.geq += (is_leq && coef < 0.) || (is_geq && coef > 0.);
                info.non_unit += multi_var;
            }
        }
        for (const auto& outcome : action.outcomes) {
            for (const auto& assign : outcome.assignments) {
                infos[assign.var_id].non_const_assign +=
                    assign.value.size() > 0;
            }
        }
    }

    vector<VariableCategory> classes;
    classes.reserve(infos.size());
    std::transform(
        infos.begin(),
        infos.end(),
        std::back_inserter(classes),
        [](const Info& info) {
            if (info.non_const_assign == 0u && info.non_unit == 0u &&
                info.leq == 0u && info.geq == 0u) {
                return VariableCategory::CATEGORIAL;
            }
            if (info.leq == 0u) {
                return VariableCategory::LOWER_BOUNDED;
            }
            if (info.geq == 0u) {
                return VariableCategory::UPPER_BOUNDED;
            }
            return VariableCategory::GENERIC;
        });
    return classes;
}

} // namespace police::ic3::syntactic
