#include "police/verifiers/ic3/syntactic/policy_reasoner_boxes.hpp"
#include "police/macros.hpp"

#if POLICE_VERITAS

#include "police/model.hpp"
#include "police/option.hpp"
#include "police/storage/variable_space.hpp"

#include <veritas/addtree.hpp>
#include <veritas/basics.hpp>
#include <veritas/tree.hpp>

#ifndef NDEBUG
#include <veritas/box.hpp>
#include <veritas/fp_search.hpp>
#endif

namespace police::ic3::syntactic {

// Records, for every state variable tested along the decision paths that
// `state` follows through the trees of `addtree`, which side of each visited
// split the state lies on:
//  - ubs[v] is set when a visited split on v sends the state to the left
//    child (split.test(value) holds);
//  - lbs[v] is set when a visited split on v sends the state to the right
//    child.
// Presumably (veritas `value < threshold` splits) an upper bound at the
// current value preserves a left branch and a lower bound a right branch.
//
// `input` maps tree feature ids to state variable ids. Flags are only ever
// set, never cleared, so callers are expected to pass vectors initialised to
// false.
void PolicyReasonerBoxes::collect_box_constraints(
    vector<bool>& lbs,
    vector<bool>& ubs,
    const flat_state& state,
    const vector<size_t>& input,
    const AddTree& addtree)
{
    assert(lbs.size() == ubs.size() && lbs.size() <= state.size());
    for (const auto& tree : addtree) {
        // The state selects exactly one root-to-leaf path per tree, so a
        // plain descent loop suffices (no recursion through std::function).
        veritas::NodeId node = tree.root();
        while (!tree.is_leaf(node)) {
            const auto& split = tree.get_split(node);
            assert(
                split.feat_id < static_cast<int>(input.size()) &&
                input[split.feat_id] < lbs.size());
            const auto var_id = input[split.feat_id];
            const real_t value = static_cast<real_t>(state[var_id]);
            if (split.test(value)) {
                ubs[var_id] = true;
                node = tree.left(node);
            } else {
                lbs[var_id] = true;
                node = tree.right(node);
            }
        }
    }
}

PolicyReasonerBoxes::PolicyReasonerBoxes(
    const AddTreePolicy* policy,
    const VariableSpace* variables)
    : policy_(policy)
    , variables_(variables)
{
}

// Derives a sufficient condition for the policy's decision in `state`.
//
// collect_box_constraints marks, per variable, whether the decision paths of
// `state` require an upper bound (a split sent the state left), a lower
// bound (a split sent it right), or both. This is turned into one
// VariableCondition per constrained variable: EQUALITY when both bounds are
// needed, otherwise LOWER_BOUND / UPPER_BOUND. Variables never tested on a
// visited path are left unconstrained.
//
// The LinearConstraintConjunction argument is ignored. `label` is read only
// by the debug-build self-check, which runs a veritas search over the box
// spanned by the derived bounds and asserts that it finds no solution.
//
// Returns a single alternative holding the collected conditions.
SuffCondAlternatives PolicyReasonerBoxes::get_reason(
    const flat_state& state,
    const LinearConstraintConjunction&,
    [[maybe_unused]] size_t label)
{
    vector<bool> lbs(variables_->size(), false);
    vector<bool> ubs(variables_->size(), false);
    collect_box_constraints(
        lbs,
        ubs,
        state,
        policy_->get_input(),
        *policy_->get_addtree());

    SufficientCondition result;
    for (size_t var = 0; var < lbs.size(); ++var) {
        if (lbs[var] && ubs[var]) {
            result.emplace_back(var, VariableCondition::EQUALITY);
        } else if (lbs[var]) {
            result.emplace_back(var, VariableCondition::LOWER_BOUND);
        } else if (ubs[var]) {
            result.emplace_back(var, VariableCondition::UPPER_BOUND);
        }
    }

#ifndef NDEBUG
    // Self-check: search the derived box with veritas and require that no
    // solution exists. Presumably a solution is a state in the box where
    // another class outscores `label`; confirm against veritas semantics.
    {
        AddTree add_tree(*policy_->get_addtree());
        veritas::Config config =
            veritas::Config(veritas::HeuristicType::MULTI_MAX_MAX_OUTPUT_DIFF);
        config.ignore_state_when_worse_than = 0.0;
        // Pruning threshold on the class-0 output (per the field name);
        // class 0 is `label` after the swap_class call below.
        config.multi_ignore_state_when_class0_worse_than = -15000;
        config.stop_when_optimal = false;
        // Only the existence of a solution matters, so stop early.
        config.stop_when_num_solutions_exceeds = 1;
        config.stop_when_num_new_solutions_exceeds = 1;
        config.max_memory = 7ull * 1024ull * 1024ull * 1024ull; // 7 GiB

        // Returns true iff the search stopped with NUM_SOLUTIONS_EXCEEDED,
        // false iff it exhausted its open list (NO_MORE_OPEN).
        // NOTE(review): every other stop reason (e.g. time or memory limits)
        // just triggers another round of steps; if veritas keeps reporting
        // such a reason, this loop may not terminate.
        auto solve = [&](const veritas::FlatBox& box) {
            auto search = config.get_search(add_tree, box);
            veritas::StopReason r = veritas::StopReason::NONE;
            do {
                r = search->steps(100);
            } while (r != veritas::StopReason::NUM_SOLUTIONS_EXCEEDED &&
                     r != veritas::StopReason::NO_MORE_OPEN);
            return r != veritas::StopReason::NO_MORE_OPEN;
        };

        // Each input spans its declared domain, except that pinned sides
        // are replaced by the state's value. The upper end is shifted by 1,
        // presumably because veritas intervals are half-open and the
        // variables are integer-valued (TODO confirm).
        veritas::FlatBox box;
        box.reserve(policy_->get_input().size());
        for (auto var : policy_->get_input()) {
            const auto& var_type = variables_->get_type(var);
            box.emplace_back(
                static_cast<real_t>(
                    lbs[var] ? state[var] : var_type.get_lower_bound()),
                static_cast<real_t>(
                    ubs[var] ? state[var] : var_type.get_upper_bound()) +
                    1.);
        }

        // Mark the action label as the primary class. Bind the outputs once
        // so begin/end are guaranteed to come from the same container.
        const auto& outputs = policy_->get_output();
        const auto label_pos =
            std::find(outputs.begin(), outputs.end(), label);
        // An unknown label would hand swap_class an out-of-range index.
        assert(label_pos != outputs.end());
        add_tree.swap_class(std::distance(outputs.begin(), label_pos));

        assert(!solve(box));
    }
#endif

    return {std::move(result)};
}

namespace {
PointerOption<PolicyReasoner> _opt(
    "tree_paths",
    [](const Arguments& args) -> std::shared_ptr<PolicyReasoner> {
        if (!args.has_addtree_policy()) {
            POLICE_EXIT_INVALID_INPUT(
                "tree_paths policy reasoner only supports tree policies");
        }
        if (args.applicability_masking) {
            POLICE_EXIT_INVALID_INPUT(
                "tree_paths policy reasoner doesn't "
                "support applicability masking");
        }
        const auto& policy = args.get_addtree_policy();
        return std::make_shared<PolicyReasonerBoxes>(
            &policy,
            &args.get_model().variables);
    });
} // namespace

} // namespace police::ic3::syntactic

#else

namespace police::ic3::syntactic {

// Fallback used when police is built without Veritas support: constructing
// the reasoner reports the missing dependency via POLICE_MISSING_DEPENDENCY.
PolicyReasonerBoxes::PolicyReasonerBoxes(
    const AddTreePolicy* /* policy */,
    const VariableSpace* /* variables */)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

// Fallback used when police is built without Veritas support: reports the
// missing dependency via POLICE_MISSING_DEPENDENCY instead of computing a
// reason.
SuffCondAlternatives PolicyReasonerBoxes::get_reason(
    const flat_state& /* state */,
    const LinearConstraintConjunction& /* constraints */,
    size_t /* label */)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

} // namespace police::ic3::syntactic

#endif
