#include "police/verifiers/ic3/sat_based/pic3sat_singleton.hpp"
#include "police/defaults.hpp"
#include "police/option.hpp"
#include "police/option_parser.hpp"
#include "police/model.hpp"
#include "police/variable_substitution.hpp"
#include "police/macros.hpp"
#include "police/nnlp_encoders.hpp"
#include "police/nnlp_lp.hpp"
#include "police/encoding_information.hpp"
#include "police/lp_model_encoding.hpp"
#include "police/lp_policy_encoding.hpp"
#include "police/storage/variable_space.hpp"
#include "police/verifiers/ic3/sat_based/sat_interface/nnlp.hpp"

#include <memory>

namespace police::ic3 {

namespace {
/**
 * Builds the part of the LP encoding that is shared by every query.
 *
 * Steps, in order:
 *   1. initialize the LP for the given variable space,
 *   2. add the neural network policy and remember the action variables it
 *      yields in the returned encoding information,
 *   3. add the non-terminal constraints, rewritten over the encoding's
 *      input variables.
 *
 * @param lp            LP that receives the encoding.
 * @param variables     Variable space handed to initialize_nnlp.
 * @param policy        Policy handed to add_nn.
 * @param not_terminal  Constraints that are substituted over `in_vars` and
 *                      passed to add_no_terminals.
 * @return The encoding information with `action_vars` filled in.
 */
EncodingInformation base_setup(
    NNLPLP& lp,
    const VariableSpace& variables,
    const NeuralNetworkPolicy& policy,
    const vector<LinearConstraintDisjunction>& not_terminal)
{
    auto encoding = initialize_nnlp(lp, variables);
    encoding.action_vars = add_nn(lp, encoding, policy);

    // Express the non-terminal constraints in terms of the LP input variables
    // before handing them to the LP.
    auto non_terminal_on_inputs =
        substitute_vars_container(not_terminal, encoding.in_vars);
    add_no_terminals(lp, std::move(non_terminal_on_inputs));

    return encoding;
}

/**
 * Encodes the goal condition, guarded by the frame variable.
 *
 * The goal is rewritten over `info.out_vars` and extended via `|=` with the
 * constraint `frame_var >= 2` (presumably `|=` adds it as an alternative
 * disjunct -- confirm against LinearCondition). The combined condition is
 * presolved against `lp` and then encoded into it.
 *
 * @param lp         LP that receives the encoded condition.
 * @param frame_var  Index of the LP variable used in the guard constraint.
 * @param goal       Goal condition over the model variables.
 * @param info       Encoding information providing `out_vars`.
 */
void add_goal_frame(
    NNLP& lp,
    size_t frame_var,
    const LinearCondition& goal,
    const EncodingInformation& info)
{
    // Guard: 1 * frame_var >= 2.
    LinearConstraint frame_guard(LinearConstraint::GREATER_EQUAL);
    frame_guard.insert(frame_var, 1.);
    frame_guard.rhs = 2.;

    auto guarded_goal = substitute_vars(goal, info.out_vars);
    guarded_goal |= std::move(frame_guard);
    presolve(lp, guarded_goal);

    encode_linear_condition(lp, guarded_goal);
}
} // namespace

/**
 * Sets up the interfaces used by the singleton SAT check.
 *
 * - `silent_`: created only if the determinized model starts with at least
 *   one action labelled SILENT_ACTION; it covers the actions in front of the
 *   first non-silent one.
 * - `lp_` / `policy_`: the NNLP encoding of the policy, the goal frame and
 *   the transitions of all actions from the first non-silent one onwards.
 * - `non_policy_prefilter_`: created only if requested via
 *   `params.non_policy_prefilter`; built from all actions of the
 *   (non-determinized) model.
 *
 * Raises a runtime error (POLICE_RUNTIME_ERROR) if the configured NNLP
 * factory does not produce an NNLPLP.
 *
 * NOTE(review): splitting at the first non-silent action assumes that
 * determinize() places all silent actions first -- confirm.
 */
PIC3SatSingleton::PIC3SatSingleton(
    const PIC3SatSingletonParameters& params,
    const Model& model,
    const NeuralNetworkPolicy& policy,
    const vector<LinearConstraintDisjunction>& not_terminal,
    const LinearCondition& goal)
{
    auto determinized = model.determinize();

    // Locate the boundary between the leading silent actions and the rest.
    const auto is_labeled = [](const Action& a) {
        return a.label != SILENT_ACTION;
    };
    auto first_labeled =
        std::find_if(determinized.begin(), determinized.end(), is_labeled);

    const bool has_silent_prefix = first_labeled != determinized.begin();
    if (has_silent_prefix) {
        silent_ = std::make_shared<IC3SatInterface>(IC3SatInterface::create(
            *params.smt,
            model.variables,
            determinized.begin(),
            first_labeled,
            not_terminal,
            goal));
    }

    // The policy part needs the LP-backed NNLP implementation.
    auto policy_lp =
        std::dynamic_pointer_cast<NNLPLP>(params.nnlp->make_shared());
    if (!policy_lp) {
        POLICE_RUNTIME_ERROR("PIC3 singleton sat interface requires NNLPLP");
    }

    auto encoding =
        base_setup(*policy_lp, model.variables, policy, not_terminal);
    const auto frame_var = policy_lp->add_variable(BoundedRealType(0));
    add_goal_frame(*policy_lp, frame_var, goal, encoding);

    if (!params.app_filter) {
        encode_all_transitions_without_app_filter(
            *policy_lp,
            encoding,
            first_labeled,
            determinized.end());
    } else {
        encode_all_transitions_with_app_filter(
            *policy_lp,
            encoding,
            first_labeled,
            determinized.end());
    }

    // Keep the LP alive in the member; the policy interface receives only a
    // raw pointer to it.
    lp_ = std::move(policy_lp);
    policy_ = std::make_shared<ic3::SatInterfaceNNLP>(
        lp_.get(),
        frame_var,
        encoding.in_vars,
        encoding.out_vars);

    if (params.non_policy_prefilter) {
        non_policy_prefilter_ =
            std::make_shared<IC3SatInterface>(IC3SatInterface::create(
                *params.smt,
                model.variables,
                model.actions.begin(),
                model.actions.end(),
                not_terminal,
                goal));
    }
}

/**
 * Checks whether `cube` is blocked at `frame`.
 *
 * Order of checks:
 *   1. If a prefilter exists and reports the cube as blocked, its result is
 *      returned and `prefiltered_` is set, so that get_unsat_core() reads
 *      the core from the prefilter.
 *   2. If a silent-action interface exists and reports the cube as not
 *      blocked, that result is returned without consulting the policy.
 *   3. Otherwise the policy interface decides. With a silent interface
 *      present, the returned index is the smaller of the two indices.
 *
 * @return Pair of (blocked flag, index as reported by the interfaces).
 */
std::pair<bool, size_t>
PIC3SatSingleton::is_blocked(const ic3::Cube& cube, size_t frame)
{
    if (non_policy_prefilter_) {
        // Reset before the query so the flag reflects this call only.
        prefiltered_ = false;
        const auto prefilter_result =
            non_policy_prefilter_->is_blocked(cube, frame);
        if (prefilter_result.first) {
            prefiltered_ = true;
            return prefilter_result;
        }
    }

    if (silent_ == nullptr) {
        return policy_->is_blocked(cube, frame);
    }

    const auto silent_result = silent_->is_blocked(cube, frame);
    if (!silent_result.first) {
        return silent_result;
    }

    const auto policy_result = policy_->is_blocked(cube, frame);
    return {
        policy_result.first,
        std::min(policy_result.second, silent_result.second)};
}

/**
 * Records `cube` as blocked at `frame` in every interface that exists:
 * always the policy interface, and the silent-action interface and the
 * prefilter when they were created.
 */
void PIC3SatSingleton::set_blocked(const ic3::Cube& cube, size_t frame)
{
    policy_->set_blocked(cube, frame);

    if (silent_ != nullptr) {
        silent_->set_blocked(cube, frame);
    }

    if (non_policy_prefilter_ != nullptr) {
        non_policy_prefilter_->set_blocked(cube, frame);
    }
}

/**
 * Forwards clear_frames() to the policy interface and to the optional
 * silent-action and prefilter interfaces when they exist.
 */
void PIC3SatSingleton::clear_frames()
{
    policy_->clear_frames();

    if (silent_ != nullptr) {
        silent_->clear_frames();
    }

    if (non_policy_prefilter_ != nullptr) {
        non_policy_prefilter_->clear_frames();
    }
}

/**
 * Forwards add_frame() to the policy interface and to the optional
 * silent-action and prefilter interfaces when they exist.
 */
void PIC3SatSingleton::add_frame()
{
    policy_->add_frame();

    if (silent_ != nullptr) {
        silent_->add_frame();
    }

    if (non_policy_prefilter_ != nullptr) {
        non_policy_prefilter_->add_frame();
    }
}

/**
 * Dumps the underlying solvers for debugging.
 *
 * The SMT solver of the silent-action interface (if that interface exists)
 * is dumped to `out`, followed by a dump of the policy LP.
 *
 * NOTE(review): the LP dump is called without `out`, so its destination is
 * whatever NNLPLP::dump() chooses -- confirm this is intended. The prefilter
 * interface is not dumped.
 */
void PIC3SatSingleton::dump(std::ostream& out) const
{
    const bool has_silent_interface = (silent_ != nullptr);
    if (has_silent_interface) {
        auto smt = silent_->get_smt();
        smt->dump(out);
    }

    lp_->dump();
}

/**
 * Returns the unsat core belonging to the most recent is_blocked() query.
 *
 * If that query was answered by the prefilter (`prefiltered_` is set), the
 * prefilter's core is returned. Otherwise the policy interface's core is
 * used, combined via `&=` with the silent-action interface's core when that
 * interface exists.
 */
ic3::Cube PIC3SatSingleton::get_unsat_core() const
{
    if (prefiltered_) {
        return non_policy_prefilter_->get_unsat_core();
    }

    if (silent_ == nullptr) {
        return policy_->get_unsat_core();
    }

    // Query the policy first, then fold in the silent-action core.
    ic3::Cube combined = policy_->get_unsat_core();
    combined &= silent_->get_unsat_core();
    return combined;
}

namespace {
// Registers the "singleton" SAT interface option.
//
// The first callback builds a SingletonSatOption from the parsed arguments;
// the second one declares the arguments that the first one reads. The third
// parameter of each add_*argument call is presumably the default value --
// confirm against ArgumentsDefinition.
PointerOption<SatInterfaceOption> _opt(
    "singleton",
    [](const Arguments& args) {
        auto smt_factory = args.get_ptr<SMTFactory>("smt_solver");
        auto nnlp_factory = args.get_ptr<NNLPFactory>("lp_solver");
        const bool use_app_filter = args.get<bool>("app_filter");
        const bool use_prefilter = args.get<bool>("non_policy_prefilter");

        PIC3SatSingletonParameters params(
            std::move(smt_factory),
            std::move(nnlp_factory),
            use_app_filter,
            use_prefilter);
        return std::make_shared<SingletonSatOption>(std::move(params));
    },
    [](ArgumentsDefinition& defs) {
        // Solver back ends.
        defs.add_ptr_argument<NNLPFactory>("lp_solver", "NNLP solver", "lp");
        defs.add_ptr_argument<SMTFactory>(
            "smt_solver",
            "SMT solver",
            DEFAULT_SMT_SOLVER);

        // Boolean switches.
        defs.add_argument<bool>(
            "app_filter",
            "Applicability filtering",
            "false");
        defs.add_argument<bool>(
            "non_policy_prefilter",
            "Check first whether any transition exists.",
            "false");
    });
} // namespace

} // namespace police::ic3
