#include "police/verifiers/ic3/syntactic/policy_reasoner_nnlp.hpp"

#include "police/base_types.hpp"
#include "police/nnlp_model_encoding.hpp"
#include "police/ffnn_lp_encoder.hpp"
#include "police/nnlp_factory.hpp"
#include "police/utils/stopwatch.hpp"
#include "police/verifiers/ic3/syntactic/applicability_conditioner.hpp"
#include "police/verifiers/ic3/syntactic/variable_classification.hpp"

#include <algorithm>
#include <fstream>
#include <numeric>

namespace police::ic3::syntactic {

namespace {
// Margin subtracted from the right-hand side of a pairwise output comparison
// whenever the competing action has a smaller label index than the action
// the NNLP encodes (see the constructor's comparison constraints).
constexpr real_t EPSILON = 5e-5;

// Value assigned to an action label's relaxation variable when that label is
// reported as not applicable (see compute_reason).
constexpr real_t LARGE_CONSTANT = 1e6;
} // anonymous namespace

// Builds one NNLP per policy output.
//
// Each NNLP contains:
//   * one LP variable (with the relaxed type) per model variable,
//   * if `applicability_filter` is set, one real-valued relaxation variable
//     per action label (their indices are stored in relaxation_vars_),
//   * an encoding of the policy network (encode_ffnn_in_lp), and
//   * pairwise comparison constraints between the output associated with the
//     NNLP's action and every other output.
// Finally, input_vars_ is sorted by descending position in `var_order`.
PolicyReasonerNNLPBase::PolicyReasonerNNLPBase(
    NNLPFactory& nnlp_factory,
    const Model* model,
    const police::NeuralNetworkPolicy* policy,
    const vector<size_t>& var_order,
    double total_time_limit,
    double individual_time_limit,
    bool applicability_filter)
    : input_vars_(policy->get_input())
    , var_class_(classify_variables(*model))
    , limits_(model->variables.size(), total_time_limit, individual_time_limit)
    , model_(model)
    , policy_(policy)
{
    const auto& output_action = policy->get_output();
    const size_t num_outputs = output_action.size();

    // prepare nnlps: model variables first, then (optionally) one relaxation
    // variable per action label
    nnlps_.reserve(num_outputs);
    size_t rel_var = 0;
    for (size_t out_idx = 0; out_idx < num_outputs; ++out_idx) {
        nnlps_.push_back(nnlp_factory.make_shared());
        for (size_t var = 0; var < model->variables.size(); ++var) {
            nnlps_.back()->add_variable(model->variables[var].type.relax());
        }
        if (applicability_filter) {
            // index of the first relaxation variable in this NNLP
            rel_var = nnlps_.back()->num_variables();
            for (size_t lbl = 0; lbl < model->labels.size(); ++lbl) {
                nnlps_.back()->add_variable(RealType());
            }
        }
    }
    if (applicability_filter) {
        relaxation_vars_.resize(model->labels.size());
        std::iota(relaxation_vars_.begin(), relaxation_vars_.end(), rel_var);
    }

    // generate network for each action
    vector<size_t> vars(model->variables.size(), 0);
    std::iota(vars.begin(), vars.end(), 0);
    auto invars = get_vars_in_input_order(vars, policy->get_input());
    for (size_t i = 0; i < num_outputs; ++i) {
        const size_t i_action = output_action[i];
        assert(i_action < nnlps_.size());
        NNLP& action_lp = *nnlps_[i_action];
        const auto out = encode_ffnn_in_lp(action_lp, policy->get_nn(), invars);

        // two auxiliary variables for each competing output j != i
        const size_t aux = action_lp.num_variables();
        for (size_t j = 1; j < num_outputs; ++j) {
            action_lp.add_variable(RealType());
            action_lp.add_variable(RealType());
        }

        // For the k-th competitor j:
        //   aux_{2k} = out_j - out_i - relax_j - rhs,
        //   with rhs = -EPSILON if i_action > j_action, else 0,
        //   ReluConstraint(aux_{2k}, aux_{2k+1}),
        // and finally sum_k aux_{2k+1} <= 0. Assuming ReluConstraint(x, y)
        // encodes y = max(0, x), this forces out_i >= out_j - relax_j,
        // with an EPSILON margin when i_action > j_action.
        LinearConstraint c(LinearConstraint::LESS_EQUAL);
        c.rhs = 0;
        size_t k = 0;
        for (size_t j = 0; j < num_outputs; ++j) {
            if (i == j) {
                continue;
            }
            const size_t j_action = output_action[j];
            {
                LinearConstraint eq(LinearConstraint::EQUAL);
                eq.insert(out + i, -1);
                eq.insert(out + j, 1);
                if (applicability_filter) {
                    eq.insert(relaxation_vars_[j_action], -1);
                }
                eq.insert(aux + 2 * k, -1);
                eq.rhs = i_action > j_action ? -EPSILON : 0.;
                action_lp.add_constraint(std::move(eq));
            }
            {
                ReluConstraint relu(aux + 2 * k, aux + 2 * k + 1);
                action_lp.add_constraint(std::move(relu));
            }
            c.insert(aux + 2 * k + 1, 1);
            ++k;
        }
        action_lp.add_constraint(std::move(c));
    }

    // sort input variables so that the variable appearing earliest in
    // var_order ends up at the back of input_vars_
    vector<size_t> rank(var_order.size());
    for (size_t order = 0; order < var_order.size(); ++order) {
        rank[var_order[order]] = order;
    }
    std::sort(
        input_vars_.begin(),
        input_vars_.end(),
        [&rank](size_t a, size_t b) { return rank[a] > rank[b]; });
#ifndef POLICE_NO_STATISTICS
    stats_file_ = std::make_shared<std::ofstream>("synic3_pigreedy.stats");
#endif
}

// NOTE(review): disabled (#if 0) code kept for reference; it is not compiled.
// Whether the class header still declares can_be_selected is not visible
// here. Confirm before re-enabling.
#if 0
// Checks whether `action` can still be selected once every variable in `cube`
// is fixed: each (var, val) pair becomes an equality assumption var == val.lb
// on the NNLP of `action`. Returns true unless that NNLP is UNSAT.
bool PolicyReasonerNNLPBase::can_be_selected(const Cube& cube, size_t action)
{
    auto& nnlp = nnlps_[action];
    for (const auto& [var, val] : cube) {
        LinearConstraint c(LinearConstraint::EQUAL);
        c.insert(var, 1.);
        c.rhs = static_cast<real_t>(val.lb);
        nnlp->add_assumption(std::move(c));
    }
    return nnlp->solve() != NNLP::UNSAT;
}
#endif

// Greedily computes a sufficient condition, over the variables in
// input_vars_, under which the NNLP associated with `label` stays UNSAT.
//
// All inputs start fixed to their values in `state`. Then, processing
// input_vars_ from back to front while the time budget (limits_) allows:
//   * drop the variable's constraint if the NNLP remains UNSAT without it;
//   * otherwise, for LOWER_BOUNDED / UPPER_BOUNDED variables, try a
//     half-interval (>= / <= state value) and keep it if that is UNSAT;
//   * otherwise keep the equality (VariableCondition::EQUALITY).
// Variables not processed because the budget ran out keep their equality.
//
// In the masked variant (ApplicabilityMasked == true), every label's
// relaxation variable is fixed to 0 if the ApplicabilityConditioner reports
// the label as applicable, and to LARGE_CONSTANT otherwise.
// The LinearConstraintConjunction parameter is unused.
template <bool ApplicabilityMasked>
SuffCondAlternatives PolicyReasonerNNLPBase::compute_reason(
    std::integral_constant<bool, ApplicabilityMasked>,
    const flat_state& state,
    const LinearConstraintConjunction&,
    size_t label,
    ApplicabilityConditioner* applicability)
{
    limits_.start();

#ifndef POLICE_NO_STATISTICS
    bool separator = false;
    (*stats_file_) << "{\"action\": " << label << ", \"calls\": [";
#endif

    // prepare assumptions fixing variable values
    vector<std::pair<size_t, LinearConstraint>> assumptions;
    assumptions.reserve(input_vars_.size());
    for (const size_t var : input_vars_) {
        assumptions.emplace_back(
            var,
            LinearConstraint::unit_constraint(
                var,
                LinearConstraint::EQUAL,
                static_cast<real_t>(state[var])));
    }

    // get NNLP associated with the given action label
    auto& nnlp = nnlps_[label];

    auto enforce_assumptions = [&]() {
        // Enforce every assumption except the last one. The main loop keeps
        // the variable currently under test at the back, so it is skipped.
        // Order is preserved: from index size-2 down to 0.
        for (size_t j = assumptions.size(); j > 1; --j) {
            nnlp->add_assumption(assumptions[j - 2].second);
        }

        // if applicability filter is enabled
        if constexpr (ApplicabilityMasked) {
            assert(applicability != nullptr);
            // Fix each label's relaxation variable according to that label's
            // applicability status.
            for (size_t l = 0; l < model_->labels.size(); ++l) {
                nnlp->add_assumption(LinearConstraint::unit_constraint(
                    relaxation_vars_[l],
                    LinearConstraint::EQUAL,
                    (*applicability)[l] ? 0. : LARGE_CONSTANT));
            }
        }
    };

#ifndef NDEBUG
    // sanity check: with all inputs fixed, the NNLP must be UNSAT
    nnlp->add_assumption(assumptions.back().second);
    enforce_assumptions();
    assert(nnlp->solve() == NNLP::UNSAT);
#endif

    SufficientCondition result;

    int i = static_cast<int>(input_vars_.size()) - 1;
    for (; i >= 0 && !limits_.out_of_budget(); --i) {
        const size_t var = assumptions[i].first;
        auto step_limits = limits_.step(var);

        // check compute budget for variable, if budget exceeded skip ahead
        if (step_limits.out_of_budget()) {
            result.emplace_back(var, VariableCondition::EQUALITY);
            continue;
        }

        // Move to the back of the assumption vector, which allows a simple
        // pop_back() later. Indices [0, i) hold untested variables; the
        // element swapped into slot i is an already-kept one.
        if (i != static_cast<int>(input_vars_.size() - 1)) {
            std::swap(assumptions[i], assumptions.back());
        }

        if constexpr (ApplicabilityMasked) {
            // update applicability information ignoring conditions on var
            assert(applicability != nullptr);
            applicability->assume_invalid(var);
        }

        // enforce assumptions and solve NNLP
        enforce_assumptions();
        auto status = nnlp->solve();

        if (status == NNLP::UNSAT) {
            // NNLP is unsolvable without var's constraint -> drop it
            assumptions.pop_back();
        } else {
            // otherwise check relaxed conditions, where we enforce only
            // half-intervals
            const auto category = var_class_[var];
            const bool lower_bounded =
                category == VariableCategory::LOWER_BOUNDED;
            if (lower_bounded || category == VariableCategory::UPPER_BOUNDED) {
                // re-enforce assumptions (which were cleared after last solve()
                // call)
                enforce_assumptions();
                // additionally add half-interval constraint
                LinearConstraint relaxed = LinearConstraint::unit_constraint(
                    var,
                    lower_bounded ? LinearConstraint::GREATER_EQUAL
                                  : LinearConstraint::LESS_EQUAL,
                    static_cast<real_t>(state[var]));
                nnlp->add_assumption(relaxed);
                // check if now unsolvable
                status = nnlp->solve();
                if (status == NNLP::UNSAT) {
                    // -> add half-interval constraint to result
                    // Note: applicability conditions on var are still ignored,
                    // i.e., the set of actions potentially applicable in the
                    // resulting minimized state is still an overapproximation
                    assumptions.back().second = std::move(relaxed);
                    result.push_back(VariableCondition(
                        var,
                        lower_bounded ? VariableCondition::LOWER_BOUND
                                      : VariableCondition::UPPER_BOUND));
                }
            }

            // if NNLP remained solvable -> must include var=state[var] as
            // constraint
            if (status != NNLP::UNSAT) {
                result.push_back(
                    VariableCondition(var, VariableCondition::EQUALITY));
                // revert applicability conditions
                if constexpr (ApplicabilityMasked) {
                    applicability->revert_last_assumption();
                }
            }
        }

#ifndef POLICE_NO_STATISTICS
        (*stats_file_) << (separator ? ", " : "")
                       << "{\"sat\": " << (status != NNLP::UNSAT)
                       << ", \"time\": "
                       << (static_cast<int>(
                              step_limits.time.get_milliseconds()))
                       << ", \"var\":" << var << ", \"relaxed\": "
                       << (status == NNLP::UNSAT && !result.empty() &&
                           result.back().variable_id == var)
                       << "}";
        separator = true;
#endif
    }

    // global budget exhausted: keep equality constraints on remaining vars
    for (; i >= 0; --i) {
        const size_t var = assumptions[i].first;
        result.push_back(VariableCondition(var, VariableCondition::EQUALITY));
#ifndef POLICE_NO_STATISTICS
        (*stats_file_) << (separator ? ", " : "") << "{\"sat\": " << 1
                       << ", \"time\": " << -1 << ", \"var\":" << var
                       << ", \"relaxed\": " << 0 << "}";
        separator = true;
#endif
    }

#ifndef POLICE_NO_STATISTICS
    (*stats_file_) << "], \"total_time\": "
                   << static_cast<int>(limits_.get_time().get_milliseconds())
                   << "}\n";
#endif

    limits_.end();

    return {std::move(result)};
}

SuffCondAlternatives PolicyReasonerNNLP::get_reason(
    const flat_state& state,
    const LinearConstraintConjunction& guard,
    size_t action)
{
    // Unmasked variant: no ApplicabilityConditioner is involved, so the
    // relaxation variables are never constrained by compute_reason.
    constexpr std::false_type unmasked{};
    ApplicabilityConditioner* const no_conditioner = nullptr;
    return compute_reason(unmasked, state, guard, action, no_conditioner);
}

SuffCondAlternatives PolicyReasonerNNLPMasked::get_reason(
    const flat_state& state,
    const LinearConstraintConjunction& guard,
    size_t action)
{
    // Masked variant: a fresh conditioner built from the model and info_
    // lives for the duration of this call and is handed to compute_reason,
    // which queries and updates it per tested variable.
    ApplicabilityConditioner conditioner(*model_, info_);
    constexpr std::true_type masked{};
    return compute_reason(masked, state, guard, action, &conditioner);
}

} // namespace police::ic3::syntactic
