#include "police/nnlp_model_encoding.hpp"

#include "police/action.hpp"
#include "police/constraint_factories.hpp"
#include "police/nnlp.hpp"
#include "police/nnlp_encoders.hpp"
#include "police/variable_substitution.hpp"

#include <algorithm>

namespace police {

/// Reorders state-variable LP ids according to an input permutation.
///
/// For every position i, the result holds state_vars[input_order[i]].
/// input_order presumably lists state-variable indices in the order the
/// network expects its inputs (TODO confirm against callers).
///
/// @param state_vars   LP variable ids, indexed by state variable.
/// @param input_order  Indices into state_vars; every entry must be
///                     < state_vars.size() (checked by assert only).
/// @return A vector of input_order.size() entries.
vector<size_t> get_vars_in_input_order(
    const vector<size_t>& state_vars,
    const vector<size_t>& input_order)
{
    vector<size_t> ordered(input_order.size());
    // Use an unsigned index: the previous `int i = size() - 1` mixed signed and
    // unsigned arithmetic and truncated for sizes above INT_MAX.
    for (size_t i = 0; i < input_order.size(); ++i) {
        assert(input_order[i] < state_vars.size());
        ordered[i] = state_vars[input_order[i]];
    }
    return ordered;
}

/// Inverts the nn_output mapping into a lookup of LP variable ids.
///
/// For every position i, sets result[nn_output[i]] = offset + i. offset is
/// presumably the LP id of the first network-output variable (TODO confirm).
///
/// @param nn_output  Assumed to be a permutation of 0..nn_output.size()-1;
///                   only the upper bound of each entry is asserted.
/// @param offset     Value added to each position index.
/// @return A vector of nn_output.size() entries. Slots not hit by any
///         nn_output entry keep their initial value.
vector<size_t> get_action_vars(const vector<size_t>& nn_output, size_t offset)
{
    const size_t n = nn_output.size();
    vector<size_t> action_vars(n);
    // Walk back-to-front with an unsigned index. The previous int index mixed
    // signed and unsigned arithmetic. The reverse order is kept so that, if
    // nn_output contained duplicates, the smallest position still wins.
    for (size_t i = n; i-- > 0;) {
        assert(nn_output[i] < n);
        action_vars[nn_output[i]] = offset + i;
    }
    return action_vars;
}

/// Creates a fresh encoding for lp and registers the state variables.
///
/// Delegates to add_state_variables, which adds one input and one output LP
/// variable per entry of vspace and records their ids in the returned object.
EncodingInformation initialize_nnlp(NNLP& lp, const VariableSpace& vspace)
{
    EncodingInformation encoding;
    add_state_variables(lp, encoding, vspace);
    return encoding;
}

/// Adds every disjunction in not_terminal to lp as a constraint.
///
/// Judging by the name, these disjunctions presumably exclude terminal
/// states (TODO confirm with callers).
[[maybe_unused]]
void add_no_terminals(
    NNLP& lp,
    const vector<LinearConstraintDisjunction>& not_terminal)
{
    for (const auto& disjunction : not_terminal) {
        lp.add_constraint(disjunction);
    }
}

/// Registers a pair of LP variables (input, output) for every state variable.
///
/// For each index of vspace, one variable of that entry's type is appended to
/// info.in_vars and a second one to info.out_vars, in that order.
///
/// @pre info.in_vars and info.out_vars are empty (asserted).
void add_state_variables(
    NNLP& lp,
    EncodingInformation& info,
    const VariableSpace& vspace)
{
    assert(info.in_vars.empty());
    assert(info.out_vars.empty());

    const auto var_count = vspace.size();
    info.in_vars.reserve(var_count);
    info.out_vars.reserve(var_count);

    for (auto var_idx = 0u; var_idx < var_count; ++var_idx) {
        const auto in_id = lp.add_variable(vspace[var_idx].type);
        info.in_vars.push_back(in_id);
        const auto out_id = lp.add_variable(vspace[var_idx].type);
        info.out_vars.push_back(out_id);
    }
}

/// Adds each constraint of guard to lp, after rewriting it with
/// substitute_vars over info.in_vars.
///
/// Presumably this maps state-variable indices onto the input-state LP
/// variables, so that the guard restricts the source state (TODO confirm).
void add_guard(
    NNLP& lp,
    const EncodingInformation& info,
    const LinearConstraintConjunction& guard)
{
    for (const auto& guard_constraint : guard) {
        lp.add_constraint(substitute_vars(guard_constraint, info.in_vars));
    }
}

/// Adds the assignment constraints of a single outcome to lp.
///
/// The constraints come from get_assignment_constraints(info, outcome) and
/// are added one by one, each viewed as a LinearConstraint.
void add_successor_constraints_for_outcome(
    NNLP& lp,
    const EncodingInformation& info,
    const Outcome& outcome)
{
    const auto assignment_constraints = get_assignment_constraints(info, outcome);
    for (const LinearConstraint& constraint : assignment_constraints) {
        lp.add_constraint(constraint);
    }
}

/// Adds the successor-state constraints induced by edge to lp.
///
/// If the edge has one outcome, its assignment constraints are added directly.
/// Otherwise, the assignment constraints of all outcomes are combined into a
/// LinearCondition via operator|= and encoded with encode_linear_condition.
/// Presumably this makes them a disjunction over outcomes (TODO confirm the
/// semantics of operator|=).
///
/// @pre edge.outcomes is non-empty (asserted).
void add_successor_constraints_for_edge(
    NNLP& lp,
    const EncodingInformation& info,
    const Action& edge)
{
    assert(!edge.outcomes.empty());

    if (edge.outcomes.size() == 1u) {
        add_successor_constraints_for_outcome(lp, info, edge.outcomes.back());
        return;
    }

    LinearCondition combined;
    combined.reserve(edge.outcomes.size());
    for (const auto& outcome : edge.outcomes) {
        combined |= get_assignment_constraints(info, outcome);
    }
    encode_linear_condition(lp, combined);
}

} // namespace police
