#include "police/verifiers/ic3/sat_based/pic3sat_label.hpp"
#include "police/arguments.hpp"
#include "police/defaults.hpp"
#include "police/option.hpp"
#include "police/action.hpp"
#include "police/model.hpp"
#include "police/variable_substitution.hpp"
#include "police/nnlp_encoders.hpp"
#include "police/lp_model_encoding.hpp"
#include "police/lp_policy_encoding.hpp"

#include <algorithm>
#include <memory>

namespace police::ic3 {

/// Sets up the satisfiability interfaces used by the label-separated IC3
/// variant.
///
/// The (optionally determinized) action list is expected to start with all
/// silent actions, followed by all labeled actions. Silent actions share one
/// SMT-based interface (`silent_`). Labeled actions get one NNLP per edge,
/// combined into `policy_`. If `params.non_policy_prefilter` is set, an extra
/// SMT interface over *all* actions is built as well.
///
/// NOTE(review): when `params.determinize` is set, the determinized copy only
/// lives until the end of this constructor; the helpers below presumably must
/// not retain iterators into it.
PIC3SatEdgeIndividual::PIC3SatEdgeIndividual(
    const PIC3SatParameters& params,
    const Model& model,
    const NeuralNetworkPolicy& policy,
    const vector<LinearConstraintDisjunction>& not_terminal,
    const LinearCondition& goal)
{
    assert(!model.actions.empty());

    // Either work on the model's own actions or on a determinized copy.
    vector<Action> determinized_actions;
    const vector<Action>* acts = &model.actions;
    if (params.determinize) {
        determinized_actions = model.determinize();
        acts = &determinized_actions;
    }

    const auto is_silent = [](const Action& a) {
        return a.label == SILENT_ACTION;
    };
    // Silent actions must form a prefix of the action list.
    assert(std::is_partitioned(acts->begin(), acts->end(), is_silent));
    const auto labeled_begin =
        std::find_if_not(acts->begin(), acts->end(), is_silent);

    const bool has_silent = labeled_begin != acts->begin();
    if (has_silent) {
        initialize_sat_silent_edges(
            *params.smt,
            model.variables,
            not_terminal,
            goal,
            acts->begin(),
            labeled_begin);
    }

    // One pruning flag per labeled edge, indexed relative to labeled_begin.
    vector<bool> pruned(std::distance(labeled_begin, acts->end()), false);
    policy_ = std::make_shared<ic3::MultiSat<ic3::SatInterfaceNNLP>>(
        initialize_sat_policy_edges(
            *params.nnlp,
            model.variables,
            not_terminal,
            goal,
            policy,
            labeled_begin,
            acts->end(),
            pruned,
            params.app_filter,
            params.edge_pruning));

    if (!params.non_policy_prefilter) {
        return;
    }
    non_policy_prefilter_ =
        std::make_shared<IC3SatInterface>(IC3SatInterface::create(
            *params.smt,
            model.variables,
            acts->begin(),
            acts->end(),
            not_terminal,
            goal));
}

/// Builds the single SMT-based interface covering the silent actions in
/// [first, last) and stores it in `silent_`.
///
/// @param smt_factory   factory used by IC3SatInterface::create
/// @param vspace        the model's variable space
/// @param not_terminal  non-terminal condition forwarded to the interface
/// @param goal          goal condition forwarded to the interface
/// @param first,last    range of silent actions
void PIC3SatEdgeIndividual::initialize_sat_silent_edges(
    const SMTFactory& smt_factory,
    const VariableSpace& vspace,
    const vector<LinearConstraintDisjunction>& not_terminal,
    const LinearCondition& goal,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last)
{
    silent_ = std::make_shared<IC3SatInterface>(
        IC3SatInterface::create(
            smt_factory, vspace, first, last, not_terminal, goal));
}

/// Builds one NNLP-based interface per non-pruned labeled edge in
/// [first, last).
///
/// Each edge's NNLP contains:
/// - the policy network (add_nn);
/// - the non-terminal constraints over the input variables;
/// - the edge guard;
/// - the constraints that the policy selects the edge's label (with or
///   without applicability filtering);
/// - the successor constraints for the edge;
/// - the goal condition over the output variables, disjoined with
///   `frame_var >= 2`.
///
/// The NNLPs are owned by `per_edge_nnlp_`. The returned interfaces refer to
/// them, in edge order.
///
/// @param pruned      in/out per-edge flags, indexed relative to `first`.
///                    Edges already flagged are skipped. With `prune_edges`,
///                    edges whose partial LP is UNSAT get flagged here.
/// @param app_filter  use the applicability-filtering variant of the policy
///                    selection constraints
/// @param prune_edges solve each edge LP once (before successor/goal
///                    constraints) and drop the edge if it is UNSAT
/// @return one SatInterfaceNNLP per surviving edge (possibly empty if every
///         edge was pruned)
vector<ic3::SatInterfaceNNLP>
PIC3SatEdgeIndividual::initialize_sat_policy_edges(
    const NNLPFactory& nnlp_factory,
    const VariableSpace& vspace,
    const vector<LinearConstraintDisjunction>& not_terminal,
    const LinearCondition& goal,
    const NeuralNetworkPolicy& policy,
    vector<Action>::const_iterator first,
    vector<Action>::const_iterator last,
    vector<bool>& pruned,
    bool app_filter,
    bool prune_edges)
{
    assert(first != last);
    assert(std::is_sorted(first, last, [](const auto& a, const auto& b) {
        return a.label < b.label;
    }));
    assert((last - 1)->label + 1 == policy.get_output().size());

    const auto num_edges = static_cast<size_t>(std::distance(first, last));
    // NNLPs created by this call are appended after this index, so a rebuild
    // (see end of function) can release exactly those.
    const size_t nnlp_base = per_edge_nnlp_.size();

    vector<ic3::SatInterfaceNNLP> result;
    result.reserve(num_edges);
    per_edge_nnlp_.reserve(nnlp_base + num_edges);

    // special handling of first edge (need to process formulas, mapping the
    // variable indices into the NNLP)
    per_edge_nnlp_.push_back(nnlp_factory.make_shared());
    auto* lp = per_edge_nnlp_.back().get();
    auto info = initialize_nnlp(*lp, vspace);
    info.action_vars = add_nn(*lp, info, policy);
    const auto frame_var = lp->add_variable(BoundedRealType(0));

    // Non-terminal constraints over the input variables. Disjunctions that
    // presolve reports as SAT are not kept.
    vector<LinearConstraintDisjunction> x_term;
    {
        auto term = substitute_vars_container(not_terminal, info.in_vars);
        for (auto& d : term) {
            const auto status = presolve(*lp, d);
            assert(status != PresolveStatus::UNSAT);
            if (status != PresolveStatus::SAT) {
                x_term.push_back(std::move(d));
            }
        }
    }

    // Goal over the output variables, or frame_var >= 2.
    auto x_goal = substitute_vars(goal, info.out_vars);
    {
        LinearConstraint constr(LinearConstraint::GREATER_EQUAL);
        constr.insert(frame_var, 1.);
        constr.rhs = 2.;
        x_goal |= std::move(constr);
        presolve(*lp, x_goal);
    }
    const size_t num_actions = info.action_vars.size();

    // finish setup: add non-terminal constraints, goal frame, and policy
    // transition
    auto edge = first;
    size_t edge_idx = 0;
    // Advances (edge, edge_idx) past all edges already flagged as pruned.
    const auto skip_pruned = [&]() {
        while (edge != last && pruned[edge_idx]) {
            ++edge;
            ++edge_idx;
        }
    };

    skip_pruned();
    assert(edge != last);

    bool some_pruned = false;

    for (;;) {
        add_no_terminals(*lp, x_term);
        add_guard(*lp, info, edge->guard);

        if (app_filter) {
            add_policy_selection_constraints_with_app_filter(
                *lp,
                info,
                edge->label,
                first,
                last);
        } else {
            add_policy_selection_constraints_without_app_filter(
                *lp,
                info,
                edge->label,
                num_actions);
        }

        if (prune_edges && lp->solve() == NNLP::UNSAT) {
            std::cout << "pruning edge #" << edge_idx << std::endl;
            // `lp` dangles after this; it is reassigned below before reuse.
            per_edge_nnlp_.pop_back();
            pruned[edge_idx] = true;
            some_pruned = true;
        } else {
            add_successor_constraints_for_edge(*lp, info, *edge);
            encode_linear_condition(*lp, x_goal);

            result.emplace_back(lp, frame_var, info.in_vars, info.out_vars);
        }

        // proceed to next edge
        ++edge;
        ++edge_idx;
        skip_pruned();
        if (edge == last) {
            break;
        }

        // prepare lp for next edge; variable indices must coincide with the
        // first NNLP's, since the formulas above were mapped only once
        per_edge_nnlp_.push_back(nnlp_factory.make_shared());
        lp = per_edge_nnlp_.back().get();
#ifndef NDEBUG
        auto info_ = initialize_nnlp(*lp, vspace);
        info_.action_vars = add_nn(*lp, info, policy);
        assert(info_.in_vars == info.in_vars);
        assert(info_.out_vars == info.out_vars);
        assert(info_.action_vars == info.action_vars);
#else
        initialize_nnlp(*lp, vspace);
        add_nn(*lp, info, policy);
#endif
        // frame_var
        lp->add_variable(BoundedRealType(0));
    }

    // With applicability filtering, rebuild the surviving edges' LPs from
    // scratch once pruning is known. If no edge survived, there is nothing to
    // rebuild. Recursing would hit `assert(edge != last)`, or dereference
    // `last` under NDEBUG.
    if (app_filter && some_pruned && !result.empty()) {
        // Drop the first-pass interfaces before the NNLPs they point into, so
        // the rebuild does not leave duplicate NNLPs in per_edge_nnlp_.
        result.clear();
        per_edge_nnlp_.resize(nnlp_base);
        return initialize_sat_policy_edges(
            nnlp_factory,
            vspace,
            not_terminal,
            goal,
            policy,
            first,
            last,
            pruned,
            app_filter,
            false);
    }

    return result;
}

/// Checks whether `cube` is blocked at `frame`.
///
/// If the non-policy prefilter exists and already blocks the cube, its answer
/// is returned and `prefiltered_` is set, so get_unsat_core() uses it.
/// Otherwise the silent interface (if any) and the policy interface are both
/// consulted; the cube counts as blocked only if both block it. In that case
/// the second component is the minimum of the two reported values.
std::pair<bool, size_t>
PIC3SatEdgeIndividual::is_blocked(const ic3::Cube& cube, size_t frame)
{
    if (non_policy_prefilter_) {
        prefiltered_ = false;
        const auto pre = non_policy_prefilter_->is_blocked(cube, frame);
        if (pre.first) {
            prefiltered_ = true;
            return pre;
        }
    }

    if (silent_ == nullptr) {
        return policy_->is_blocked(cube, frame);
    }

    const auto via_silent = silent_->is_blocked(cube, frame);
    if (!via_silent.first) {
        return via_silent;
    }
    const auto via_policy = policy_->is_blocked(cube, frame);
    return {via_policy.first, std::min(via_policy.second, via_silent.second)};
}

/// Records `cube` as blocked at `frame` in every existing interface: the
/// policy interface first, then silent, then prefilter (when present).
void PIC3SatEdgeIndividual::set_blocked(const ic3::Cube& cube, size_t frame)
{
    policy_->set_blocked(cube, frame);
    const auto block_in = [&](const auto& iface) {
        if (iface != nullptr) {
            iface->set_blocked(cube, frame);
        }
    };
    block_in(silent_);
    block_in(non_policy_prefilter_);
}

/// Clears the frames of every existing interface: policy first, then silent,
/// then prefilter (when present).
void PIC3SatEdgeIndividual::clear_frames()
{
    policy_->clear_frames();
    const auto clear_in = [](const auto& iface) {
        if (iface != nullptr) {
            iface->clear_frames();
        }
    };
    clear_in(silent_);
    clear_in(non_policy_prefilter_);
}

/// Adds a new frame to every existing interface: policy first, then silent,
/// then prefilter (when present).
void PIC3SatEdgeIndividual::add_frame()
{
    policy_->add_frame();
    const auto extend = [](const auto& iface) {
        if (iface != nullptr) {
            iface->add_frame();
        }
    };
    extend(silent_);
    extend(non_policy_prefilter_);
}

/// Dumps the silent SMT encoding (if any) to `out`, then each per-edge LP.
///
/// NOTE(review): the per-edge LPs are dumped via `dump()` without `out`, so
/// they presumably go to a default stream rather than `out` — confirm whether
/// an ostream overload exists.
void PIC3SatEdgeIndividual::dump(std::ostream& out) const
{
    if (silent_) {
        silent_->get_smt()->dump(out);
    }
    const auto& policy_edges = policy_->base();
    const size_t num_edges = policy_edges.size();
    for (size_t idx = 0; idx != num_edges; ++idx) {
        policy_edges[idx].get_lp()->dump();
    }
}

/// Returns the unsat core of the last blocking query.
///
/// If the prefilter answered the last is_blocked() call, its core is returned.
/// Otherwise the policy core is combined with the silent core (via `&=`) when
/// a silent interface exists.
ic3::Cube PIC3SatEdgeIndividual::get_unsat_core() const
{
    if (prefiltered_) {
        return non_policy_prefilter_->get_unsat_core();
    }
    if (!silent_) {
        return policy_->get_unsat_core();
    }
    ic3::Cube combined = policy_->get_unsat_core();
    combined &= silent_->get_unsat_core();
    return combined;
}

namespace {
// Registers the "label_separate" SAT-interface option. It parses the solver
// factories and boolean switches declared below into PIC3SatParameters and
// wraps them in an EdgeIndividualSatOption.
PointerOption<SatInterfaceOption> _opt(
    "label_separate",
    [](const Arguments& arguments) {
        // Positional; presumably maps onto params.smt, .nnlp, .app_filter,
        // .determinize, .edge_pruning and .non_policy_prefilter — keep in sync
        // with the PIC3SatParameters constructor.
        return std::make_shared<EdgeIndividualSatOption>(PIC3SatParameters(
            arguments.get_ptr<SMTFactory>("smt_solver"),
            arguments.get_ptr<NNLPFactory>("lp_solver"),
            arguments.get<bool>("app_filter"),
            arguments.get<bool>("determinize"),
            arguments.get<bool>("prune_edges"),
            arguments.get<bool>("non_policy_prefilter")));
    },
    [](ArgumentsDefinition& definitions) {
        definitions.add_ptr_argument<NNLPFactory>("lp_solver", "NNLP solver");
        definitions.add_ptr_argument<SMTFactory>(
            "smt_solver", "SMT solver", DEFAULT_SMT_SOLVER);
        definitions.add_argument<bool>(
            "app_filter", "Applicability filtering", "false");
        definitions.add_argument<bool>(
            "determinize",
            "Create separate NNLPs for each non-deterministic action outcome.",
            "true");
        definitions.add_argument<bool>(
            "prune_edges",
            "Find and prune edges that will never be taken.",
            "false");
        definitions.add_argument<bool>(
            "non_policy_prefilter",
            "Check first whether any transition exists.",
            "false");
    });
} // namespace

} // namespace police::ic3
