#include "police/smt_model_encoding.hpp"

#include "police/expressions/expression.hpp"
#include "police/expressions/variable.hpp"
#include "police/storage/variable_space.hpp"
#include "police/variable_substitution.hpp"

#include <algorithm>
#include <cassert>

namespace police {

// Allocates two SMT variables for every state variable in `vspace`: one in
// `info.in_vars` (the pre-state copy) and one in `info.out_vars` (the
// post-state copy). Both copies get the state variable's declared type and
// an empty name. Entry k of each vector corresponds to vspace[k].
//
// Precondition: `info` has not been populated yet (both vectors empty).
void add_state_variables(
    SMT& smt,
    EncodingInformation& info,
    const VariableSpace& vspace)
{
    assert(info.in_vars.empty());
    assert(info.out_vars.empty());

    const auto num_vars = vspace.size();
    info.in_vars.reserve(num_vars);
    info.out_vars.reserve(num_vars);

    for (auto var = 0u; var < num_vars; ++var) {
        const auto& var_type = vspace[var].type;
        // The in-copy is created before the out-copy, so SMT variable ids
        // are allocated in interleaved (in, out) pairs.
        info.in_vars.push_back(smt.add_variable("", var_type));
        info.out_vars.push_back(smt.add_variable("", var_type));
    }
}

namespace {
// Adds the implication `appl_var -> guard` to `smt`, written as the clause
// `!appl_var || guard`. Before conversion to an expression, the guard's
// state variables are renamed through `vars` (guard variable k becomes SMT
// variable vars[k]).
void add_guard(
    SMT& smt,
    size_t appl_var,
    const vector<size_t>& vars,
    const LinearConstraintConjunction& guard)
{
    auto renamed_guard = substitute_vars(guard, vars);
    smt.add_constraint(
        !expressions::Variable(appl_var) || renamed_guard.as_expression());
}

// Returns the equality `out[var] == value`, where `var` is the assigned
// state variable and `value` is the assignment's right-hand side with its
// variables renamed through `in_vars` (i.e. evaluated in the pre-state).
expressions::Expression get_assignment_constraint(
    const vector<size_t>& in_vars,
    const vector<size_t>& out_vars,
    const Assignment& assignment)
{
    const auto target_var = assignment.var_id;
    assert(target_var < out_vars.size());
    const size_t target_smt_var = out_vars[target_var];
    return expressions::equal(
        expressions::Variable(target_smt_var),
        substitute_vars(assignment.value, in_vars).as_expression());
}

// Encodes a single outcome as a conjunction over all state variables:
//  - for each assigned variable v: out[v] == value (evaluated over in_vars);
//  - for each unassigned variable v: out[v] == in[v] (frame condition).
// The conjunct for the first assignment comes first; the remaining
// conjuncts follow in increasing variable index.
//
// Preconditions: the outcome has at least one assignment, the assignments
// are ordered by strictly increasing var_id (no variable assigned twice),
// every var_id is < out_vars.size(), and in_vars covers every index of
// out_vars.
expressions::Expression get_successor_constraints(
    const vector<size_t>& in_vars,
    const vector<size_t>& out_vars,
    const Outcome& outcome)
{
    const auto& assignments = outcome.assignments;
    assert(!assignments.empty());
    // Strictly increasing: duplicates would otherwise be silently mishandled.
    assert(
        std::adjacent_find(
            assignments.begin(),
            assignments.end(),
            [](const auto& a, const auto& b) {
                return a.var_id >= b.var_id;
            }) == assignments.end());
    assert(assignments.back().var_id < out_vars.size());
    assert(in_vars.size() >= out_vars.size());

    const auto first_var = assignments.front().var_id;
    expressions::Expression result =
        get_assignment_constraint(in_vars, out_vars, assignments.front());

    // Walk every state variable exactly once. Bounding the loop by
    // out_vars.size(), rather than by the assignment iterator, guarantees
    // termination and in-bounds indexing even if the preconditions are
    // violated in a release build.
    auto next = assignments.begin() + 1;
    for (size_t i = 0; i < out_vars.size(); ++i) {
        if (i == first_var) continue; // already encoded above
        if (next != assignments.end() && next->var_id == i) {
            result = result &&
                     get_assignment_constraint(in_vars, out_vars, *next);
            ++next;
        } else {
            result = result && expressions::equal(
                                   expressions::Variable(in_vars[i]),
                                   expressions::Variable(out_vars[i]));
        }
    }
    // Every assignment must have been matched to a state variable.
    assert(next == assignments.end());
    return result;
}

// Encodes the possible successors of one action as the disjunction
//   (appl && o_1) || (appl && o_2) || ... || (appl && o_n)
// where `appl` is the action's applicability indicator `appl_var` and each
// o_k is the successor constraint of the k-th outcome.
// Precondition: the action has at least one outcome.
expressions::Expression get_successor_constraints(
    const vector<size_t>& in_vars,
    const vector<size_t>& out_vars,
    size_t appl_var,
    const Action& edge)
{
    assert(!edge.outcomes.empty());
    const auto applicable = expressions::Variable(appl_var);
    const auto guarded = [&](const auto& outcome) {
        return applicable &&
               get_successor_constraints(in_vars, out_vars, outcome);
    };

    auto outcome_it = edge.outcomes.begin();
    auto succs = guarded(*outcome_it);
    for (++outcome_it; outcome_it != edge.outcomes.end(); ++outcome_it) {
        succs = succs || guarded(*outcome_it);
    }
    return succs;
}

} // namespace

// Encodes the transition relation of the actions in [first, last) into
// `smt`. For each action, a fresh Boolean indicator is allocated and
// constrained to imply the action's guard (over info.in_vars). A single
// constraint is then added: the disjunction, over all actions, of each
// action's indicator-guarded successor constraints (relating
// info.in_vars to info.out_vars).
// Precondition: the range is non-empty.
void encode_transition(
    SMT& smt,
    const EncodingInformation& info,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last)
{
    assert(first != last);

    // Allocates the action's indicator, ties it to the guard, and returns
    // the action's successor disjunction. SMT side effects happen in range
    // order, one action at a time.
    const auto encode_action = [&](const Action& action) {
        const auto appl_var = smt.add_variable("", BoolType());
        add_guard(smt, appl_var, info.in_vars, action.guard);
        return get_successor_constraints(
            info.in_vars,
            info.out_vars,
            appl_var,
            action);
    };

    expressions::Expression succ_constraints = encode_action(*first);
    for (auto action_it = first + 1; action_it != last; ++action_it) {
        succ_constraints = succ_constraints || encode_action(*action_it);
    }
    smt.add_constraint(succ_constraints);
}

} // namespace police
