#include "police/verifiers/ic3/sat_based/ic3_sat.hpp"
#include "police/variable_substitution.hpp"
#include "police/smt.hpp"
#include "police/smt_model_encoding.hpp"
#include "police/storage/variable_space.hpp"
#include "police/verifiers/ic3/sat_based/sat_interface/smt.hpp"

#include <algorithm>

namespace police::ic3 {

// Wraps an already-populated SMT instance as an IC3 SAT interface.
//
// The base SatInterfaceSMT is handed a non-owning raw pointer (smt.get()),
// while shared ownership of the solver is retained in the member smt_.
// This relies on C++ initializing base classes before members: smt.get()
// is evaluated for the base before smt is moved into smt_. Do not reorder
// or restructure this initializer list without preserving that property.
//
// Parameters:
//   smt       - solver instance; ownership is shared with this object.
//   frame_var - index of the frame-selector variable inside smt.
//   in_vars   - variable indices forwarded to the base (presumably the
//               current-state copy; see EncodingInformation::in_vars).
//   out_vars  - variable indices forwarded to the base (presumably the
//               next-state copy; see EncodingInformation::out_vars).
//
// NOTE(review): members are destroyed before base classes, so smt_ may
// drop the last reference to the SMT before ~SatInterfaceSMT runs. Confirm
// the base destructor never dereferences its SMT pointer.
IC3SatInterface::IC3SatInterface(
    std::shared_ptr<SMT> smt,
    size_t frame_var,
    vector<size_t> in_vars,
    vector<size_t> out_vars)
    : ic3::SatInterfaceSMT(
          smt.get(), // read before smt is moved into smt_ below
          frame_var,
          std::move(in_vars),
          std::move(out_vars))
    , smt_(std::move(smt)) // keeps the solver alive for the base's raw pointer
{
}

// Declares the state variables of vspace in smt.
//
// Returns the EncodingInformation filled in by add_state_variables; its
// in_vars / out_vars mappings are what the other encoders in this file
// substitute into constraints.
EncodingInformation
IC3SatInterface::initialize(SMT& smt, const VariableSpace& vspace)
{
    EncodingInformation encoding;
    add_state_variables(smt, encoding, vspace);
    return encoding;
}

// Asserts each disjunction in not_terminal as a constraint of smt.
//
// Before being added, every disjunction is rewritten through the
// info.in_vars variable mapping. An empty list adds nothing.
void IC3SatInterface::add_no_terminals(
    SMT& smt,
    const EncodingInformation& info,
    const vector<LinearConstraintDisjunction>& not_terminal)
{
    for (const auto& disjunction : not_terminal) {
        smt.add_constraint(
            substitute_vars(disjunction, info.in_vars).as_expression());
    }
}

// Encodes the transitions of the actions given by the iterator pair
// first/last into smt.
//
// This is a thin forwarding wrapper: all the work, including how info's
// variable mappings are used, is done by encode_transition.
void IC3SatInterface::add_transitions(
    SMT& smt,
    const EncodingInformation& info,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last)
{
    encode_transition(smt, info, first, last);
}

// Adds the guarded goal clause
//
//     goal[out_vars]  OR  (1 * frame_var >= 2)
//
// to smt. The goal condition, rewritten through info.out_vars, is therefore
// only enforced when frame_var < 2. Setting frame_var >= 2 satisfies the
// clause trivially.
// NOTE(review): the meaning of the threshold 2 (which frame values count
// as "active") is defined by the caller/base class; confirm there.
void IC3SatInterface::add_goal_frame(
    SMT& smt,
    size_t frame_var,
    const EncodingInformation& info,
    const LinearCondition& goal)
{
    auto guarded_goal = substitute_vars(goal, info.out_vars);

    // Escape literal: frame_var >= 2 deactivates the goal.
    LinearConstraint inactive_guard(LinearConstraint::GREATER_EQUAL);
    inactive_guard.insert(frame_var, 1.);
    inactive_guard.rhs = 2;

    guarded_goal |= inactive_guard;
    smt.add_constraint(guarded_goal.as_expression());
}

// Builds a fresh solver from smt_factory, encodes the IC3 query into it,
// and wraps the result in an IC3SatInterface.
//
// The solver is populated in this order:
//   1. state variables of vspace (initialize),
//   2. the not_terminal disjunctions (add_no_terminals),
//   3. the transitions of the actions given by first/last (add_transitions),
//   4. a real-valued variable named "frame_var",
//   5. the avoid condition, guarded by frame_var (add_goal_frame).
// NOTE: keep this order; variable indices returned by add_variable may
// depend on how many variables were created before.
IC3SatInterface IC3SatInterface::create(
    const SMTFactory& smt_factory,
    const VariableSpace& vspace,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last,
    const vector<LinearConstraintDisjunction>& not_terminal,
    const LinearCondition& avoid)
{
    std::shared_ptr<SMT> solver = smt_factory.make_shared();
    SMT& target = *solver;

    EncodingInformation encoding = initialize(target, vspace);
    add_no_terminals(target, encoding, not_terminal);
    add_transitions(target, encoding, first, last);

    const auto frame = target.add_variable("frame_var", RealType());
    add_goal_frame(target, frame, encoding, avoid);

    return IC3SatInterface(
        std::move(solver),
        frame,
        std::move(encoding.in_vars),
        std::move(encoding.out_vars));
}

} // namespace police::ic3
