#include "police/verifiers/ic3/syntactic/policy_reasoner_veritas.hpp"
#include "police/addtree.hpp"

#if POLICE_VERITAS

#include "police/verifiers/ic3/syntactic/applicability_conditioner.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasoner_boxes.hpp"

#include <veritas/addtree.hpp>
#include <veritas/box.hpp>
#include <veritas/fp_search.hpp>
#include <veritas/interval.hpp>

namespace police::ic3::syntactic {

namespace {

/// Sentinel stored in class_index_ for labels that are not policy outputs.
constexpr size_t UNDEFINED = static_cast<size_t>(-1);

/// Penalty subtracted from the base score of inapplicable actions in
/// compute_reason(), effectively preventing the search from choosing them.
constexpr DecisionTreeValueType LARGE_CONSTANT = 1'000'000;

/// Width of the point intervals built around concrete values. Per the note in
/// compute_reason(), Veritas intervals are lower-bound inclusive and
/// upper-bound exclusive, so [v, v + EPSILON) still contains v.
constexpr DecisionTreeValueType EPSILON = 5e-5;

/// Veritas search configuration used for every reasoning query.
veritas::Config get_search_config()
{
    veritas::Config config =
        veritas::Config(veritas::HeuristicType::MULTI_MAX_MAX_OUTPUT_DIFF);
    // NOTE(review): presumably discards search states whose output
    // difference is not positive -- confirm against the Veritas docs.
    config.ignore_state_when_worse_than = 0.0;
    // Prune states in which the class-0 output drops below -15000.
    config.multi_ignore_state_when_class0_worse_than = -15000;
    // Callers only need to know whether *some* solution exists, so stop as
    // soon as the first one is found.
    config.stop_when_optimal = false;
    config.stop_when_num_solutions_exceeds = 1;
    config.stop_when_num_new_solutions_exceeds = 1;
    config.max_memory = 7ull * 1024ull * 1024ull * 1024ull; // 7 GiB
    return config;
}

} // namespace

/**
 * Builds the Veritas-backed reasoner for an additive-tree policy.
 *
 * @param model     model providing the variables (referenced, not copied)
 *                  and the action labels.
 * @param policy    additive-tree policy. Its tree ensemble is copied, so the
 *                  class swaps and base-score edits made by compute_reason()
 *                  never touch the caller's policy.
 * @param var_order variable ids in the order in which compute_reason() tries
 *                  to relax them (earlier = tried first). It is assumed to be
 *                  a permutation covering every policy input variable; TODO
 *                  confirm with callers.
 */
PolicyReasonerVeritasBase::PolicyReasonerVeritasBase(
    const Model& model,
    const AddTreePolicy& policy,
    const vector<size_t>& var_order)
    : variables_(&model.variables)
    , addtree_(std::make_shared<AddTree>(*policy.get_addtree()))
    , input_features_(policy.get_input())
    , class_index_(model.labels.size(), UNDEFINED)
    , var_class_(classify_variables(model))
{
    // class_index_[label] = position of `label` among the policy's outputs.
    size_t pos = 0;
    for (auto label : policy.get_output()) {
        assert(label < class_index_.size());
        class_index_[label] = pos;
        ++pos;
    }

    // rank_of[var] = position of `var` within var_order.
    vector<size_t> rank_of(var_order.size());
    for (size_t order = 0; order < var_order.size(); ++order) {
        assert(var_order[order] < rank_of.size());
        rank_of[var_order[order]] = order;
    }

    // Order the input features by the rank of their variable; ties are broken
    // by feature id (lexicographic pair ordering).
    vector<std::pair<size_t, size_t>> ranked_features;
    ranked_features.reserve(input_features_.size());
    for (size_t feature_id = 0; feature_id < input_features_.size();
         ++feature_id) {
        const size_t var = input_features_[feature_id];
        assert(var < rank_of.size());
        ranked_features.emplace_back(rank_of[var], feature_id);
    }
    std::sort(ranked_features.begin(), ranked_features.end());
    for (const auto& ranked : ranked_features) {
        min_order_.push_back(ranked.second);
    }

    // Snapshot the original base scores. compute_reason() rewrites the
    // ensemble's base scores from this copy.
    base_scores_.resize(addtree_->num_leaf_values());
    for (size_t i = 0; i < base_scores_.size(); ++i) {
        base_scores_[i] = addtree_->base_score(i);
    }
#ifndef POLICE_NO_STATISTICS
    stats_file_ = std::make_shared<std::ofstream>("synic3_veritas.stats");
#endif
}

/**
 * Computes a sufficient condition on `state` under which the policy still
 * selects `label`.
 *
 * Starting from the point box of `state`, the relevant input variables are
 * visited in `min_order_` order. For each variable the method tries:
 *   1. relaxing it to its full domain;
 *   2. if that fails, relaxing it to a one-sided bound, depending on the
 *      variable's category.
 * A relaxation is kept iff the Veritas search finds no solution in the
 * relaxed box. A solution is presumably an input on which another class
 * out-scores `label`. Otherwise the variable is pinned to its current value
 * (EQUALITY).
 *
 * @tparam ApplicabilityMasked if true, output labels that `applicability`
 *         marks as not applicable get their base score lowered by
 *         LARGE_CONSTANT.
 * @param state         state whose decision is being explained.
 * @param (unnamed)     guard; not used by this reasoner.
 * @param label         action label whose selection must be preserved.
 * @param applicability applicability tracker; must be non-null when
 *                      ApplicabilityMasked is true.
 * @return `{}` if `label` is not a policy output, otherwise a single
 *         alternative listing the constrained variables.
 */
template <bool ApplicabilityMasked>
SuffCondAlternatives PolicyReasonerVeritasBase::compute_reason(
    const flat_state& state,
    const LinearConstraintConjunction&,
    size_t label,
    ApplicabilityConditioner* applicability)
{
    assert(!ApplicabilityMasked || applicability != nullptr);

    if (class_index_[label] == UNDEFINED) {
        return {};
    }

    // veritas preparation
    auto config = get_search_config();

    // Runs a Veritas search over `query_box`. Returns true iff the box could
    // NOT be proven solution-free.
    // - Only NO_MORE_OPEN (search space exhausted) counts as a proof.
    // - Every other stop reason, including giving up on a resource limit such
    //   as config.max_memory, is conservatively reported as "solution
    //   found". The caller then keeps the current constraint.
    // The loop therefore ends on *any* reason other than NONE. Waiting only
    // for the two expected reasons could keep calling steps() indefinitely if
    // the search stops for another reason.
    auto solve = [&](const veritas::FlatBox& query_box) {
        auto search = config.get_search(*addtree_, query_box);
        veritas::StopReason reason = veritas::StopReason::NONE;
        do {
            reason = search->steps(100);
        } while (reason == veritas::StopReason::NONE);
        assert(
            reason == veritas::StopReason::NUM_SOLUTIONS_EXCEEDED ||
            reason == veritas::StopReason::NUM_NEW_SOLUTIONS_EXCEEDED ||
            reason == veritas::StopReason::NO_MORE_OPEN);
        return reason != veritas::StopReason::NO_MORE_OPEN;
    };

#ifndef POLICE_NO_STATISTICS
    bool separator = false;
    StopWatch total_time;
    (*stats_file_) << "{\"action\": " << label << ", \"calls\": [";
#endif

    // get box constraints corresponding to state, skip testing those features
    // not relevant for the tree decision anyway
    vector<bool> lbs(variables_->size(), false);
    vector<bool> ubs(variables_->size(), false);
    PolicyReasonerBoxes::collect_box_constraints(
        lbs,
        ubs,
        state,
        input_features_,
        *addtree_);

    // State as box constraints. Veritas intervals are lb inclusive but ub
    // exclusive, hence the point interval [v, v + EPSILON).
    veritas::FlatBox box;
    box.reserve(input_features_.size());
    for (size_t i = 0; i < input_features_.size(); ++i) {
        const auto var = input_features_[i];
        box.emplace_back(
            static_cast<real_t>(state[var]),
            static_cast<real_t>(state[var]) + EPSILON);
        // const auto& var_type = variables_->get_type(var);
        // box.emplace_back(
        //     static_cast<real_t>(
        //         lbs[var] ? state[var] : var_type.get_lower_bound()),
        //     static_cast<real_t>(
        //         ubs[var] ? state[var] : var_type.get_upper_bound()) +
        //         1.);
    }

    // Rewrites the ensemble's base scores for the swapped class layout.
    // - Slot 0 receives `label`'s original base score.
    // - The original class 0 is written to slot class_index_[label].
    // - Every other output label keeps its slot and is penalised by
    //   LARGE_CONSTANT when `applicability` reports it as not applicable.
    auto sync_applicability_with_box = [&]() {
        addtree_->base_score(0) = base_scores_[class_index_[label]];
        for (size_t label_ = 0; label_ < class_index_.size(); ++label_) {
            const auto cls = class_index_[label_];
            if (label != label_ && cls != UNDEFINED) {
                addtree_->base_score(cls == 0 ? class_index_[label] : cls) =
                    base_scores_[cls] -
                    (!(*applicability)[label_] ? LARGE_CONSTANT : 0.);
            }
        }
    };

    // mark action label as primary class
    addtree_->swap_class(class_index_[label]);

#ifndef NDEBUG
    if constexpr (ApplicabilityMasked) {
        sync_applicability_with_box();
    }
    // The unrelaxed point box must already be solution-free.
    assert(!solve(box));
#endif

    SufficientCondition result;
    for (size_t i = 0; i < min_order_.size(); ++i) {
        const size_t feature_id = min_order_[i];
        const size_t var_id = input_features_[feature_id];

        if (!lbs[var_id] && !ubs[var_id]) {
            // Variable irrelevant to the tree decision: keep it unconstrained
            // without querying Veritas.
#ifndef POLICE_NO_STATISTICS
            (*stats_file_) << (separator ? ", " : "") << "{\"sat\": 0"
                           << ", \"time\": " << (0) << ", \"var\":" << var_id
                           << ", \"pruned\": 0}";
            separator = true;
#endif

            continue;
        }

#ifndef POLICE_NO_STATISTICS
        StopWatch w;
#endif

        const auto& var_type = variables_->get_type(var_id);

        // fully relax variable
        const auto old_box = box[feature_id];
        box[feature_id] = veritas::Interval(
            static_cast<real_t>(var_type.get_lower_bound()),
            static_cast<real_t>(var_type.get_upper_bound()) + EPSILON);

        // update applicability status and synchronize with relaxation vars
        if constexpr (ApplicabilityMasked) {
            assert(applicability != nullptr);
            applicability->assume_invalid(var_id);
            sync_applicability_with_box();
        }

        if (solve(box)) {
            // Full relaxation failed; try a partial (one-sided) relaxation.
            if (lbs[var_id] &&
                var_class_[var_id] == VariableCategory::LOWER_BOUNDED) {
                box[feature_id] = veritas::Interval(
                    static_cast<real_t>(state[var_id]),
                    static_cast<real_t>(var_type.get_upper_bound()) + EPSILON);
                if (solve(box)) {
                    // Lower-bound relaxation failed too: pin the value.
                    box[feature_id] = old_box;
                    result.emplace_back(var_id, VariableCondition::EQUALITY);
                    if constexpr (ApplicabilityMasked) {
                        applicability->revert_last_assumption();
                    }
                } else {
                    result.emplace_back(var_id, VariableCondition::LOWER_BOUND);
                }
            } else if (
                ubs[var_id] &&
                var_class_[var_id] == VariableCategory::UPPER_BOUNDED) {
                box[feature_id] = veritas::Interval(
                    static_cast<real_t>(var_type.get_lower_bound()),
                    static_cast<real_t>(state[var_id]) + EPSILON);
                if (solve(box)) {
                    // Upper-bound relaxation failed too: pin the value.
                    box[feature_id] = old_box;
                    result.emplace_back(var_id, VariableCondition::EQUALITY);
                    if constexpr (ApplicabilityMasked) {
                        applicability->revert_last_assumption();
                    }
                } else {
                    result.emplace_back(var_id, VariableCondition::UPPER_BOUND);
                }
            } else {
                // No partial relaxation applies: pin the value.
                box[feature_id] = old_box;
                result.emplace_back(var_id, VariableCondition::EQUALITY);
                if constexpr (ApplicabilityMasked) {
                    applicability->revert_last_assumption();
                }
            }
        }

#ifndef POLICE_NO_STATISTICS
        w.stop();
        (*stats_file_) << (separator ? ", " : "") << "{\"sat\": "
                       << (!result.empty() &&
                           result.back().variable_id == var_id)
                       << ", \"time\": "
                       << (static_cast<int>(w.get_milliseconds()))
                       << ", \"var\":" << var_id << ", \"pruned\": 1}";
        separator = true;
#endif
    }

#ifndef NDEBUG
    if constexpr (ApplicabilityMasked) {
        // sync applicability flags in case last iteration's assumption was
        // reverted
        sync_applicability_with_box();
    }
    // The final (relaxed) box must still be solution-free.
    assert(!solve(box));
#endif

    // restore classes
    addtree_->swap_class(class_index_[label]);

#ifndef POLICE_NO_STATISTICS
    (*stats_file_) << "], \"total_time\": "
                   << static_cast<int>(total_time.get_milliseconds()) << "}\n";
#endif

    return {std::move(result)};
}

// Unmasked variant: every action is considered during the search, so no
// applicability tracker is passed to compute_reason().
SuffCondAlternatives PolicyReasonerVeritas::get_reason(
    const flat_state& state,
    const LinearConstraintConjunction& guard,
    size_t label)
{
    constexpr bool masked = false;
    ApplicabilityConditioner* const no_conditioner = nullptr;
    return compute_reason<masked>(state, guard, label, no_conditioner);
}

// Masked variant: builds a fresh ApplicabilityConditioner for this query from
// the information cached by the most recent prepare() call, and lets
// compute_reason() penalise inapplicable actions.
SuffCondAlternatives PolicyReasonerVeritasMasked::get_reason(
    const flat_state& state,
    const LinearConstraintConjunction& guard,
    size_t label)
{
    constexpr bool masked = true;
    ApplicabilityConditioner conditioner(*model_, infos_);
    return compute_reason<masked>(state, guard, label, &conditioner);
}

void PolicyReasonerVeritasMasked::prepare(const flat_state& state)
{
    infos_ = ApplicabilityInformation(*model_, state);
}

} // namespace police::ic3::syntactic

#else

namespace police::ic3::syntactic {

// Fallback definitions for builds without Veritas support (POLICE_VERITAS
// disabled). Every entry point just reports the missing dependency.
// NOTE(review): the non-void stubs below have no return statement; this is
// only well-defined if POLICE_MISSING_DEPENDENCY never returns (throws or
// aborts). Confirm the macro's definition.

PolicyReasonerVeritasBase::PolicyReasonerVeritasBase(
    const Model&,
    const AddTreePolicy&,
    const vector<size_t>&)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

SuffCondAlternatives PolicyReasonerVeritas::get_reason(
    const flat_state&,
    const LinearConstraintConjunction&,
    size_t)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

SuffCondAlternatives PolicyReasonerVeritasMasked::get_reason(
    const flat_state&,
    const LinearConstraintConjunction&,
    size_t)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

void PolicyReasonerVeritasMasked::prepare(const flat_state&)
{
    POLICE_MISSING_DEPENDENCY("Veritas");
}

} // namespace police::ic3::syntactic

#endif
