#include "police/verifiers/ic3/syntactic/policy_reasoner_lipa.hpp"

#include "police/base_types.hpp"
#include "police/cg_policy.hpp"
#include "police/cg_relaxation.hpp"
#include "police/compute_graph_factory.hpp"
#include "police/macros.hpp"
#include "police/nn_policy.hpp"
#include "police/option.hpp"
#include "police/variable_order_chooser.hpp"
#include "police/verifiers/ic3/syntactic/applicability_conditioner.hpp"
#include "police/verifiers/ic3/syntactic/variable_classification.hpp"

#include <memory>

#ifdef POLICE_EIGEN
#include "police/compute_graph.hpp"
#include "police/layer_bounds.hpp"
#include "police/utils/io.hpp"
#include <algorithm>
#endif

#define DEBUG_PRINTS 0

namespace police::ic3::syntactic {

#ifdef POLICE_EIGEN
namespace {
// Margin used in the comparison-layer biases: it is the bias of a competing
// label whenever the analysed label compares smaller than that competing
// label, and 0 otherwise (presumably mirrors the policy's tie-breaking among
// equal outputs -- TODO confirm against the policy implementation).
constexpr real_t EPSILON = 5e-5;
// Penalty subtracted from a competing label's bias in the applicability-masked
// variant whenever the applicability conditioner's entry for that label is
// false.
constexpr real_t LARGE_CONSTANT = 100000000.;

#ifndef NDEBUG
// vector<FeedForwardNeuralNetwork<>> NETS;
#endif
} // namespace
#endif

/**
 * Builds one bound analyzer per policy output label and fixes the order in
 * which input variables are later visited by compute_reason().
 *
 * For output label index i, the policy's compute graph is extended by
 *   1. a linear "comparison" layer with one row per other label j computing
 *      out[j] - out[i] + bias (bias is EPSILON if labels[i] < labels[j],
 *      0 otherwise),
 *   2. a ReLU layer, and
 *   3. a linear layer summing all ReLU outputs into a single value.
 * The relaxers of the policy graph are created once and shared by all
 * analyzers; only the three-layer suffix is created per label.
 *
 * @param model      model whose variables and labels the policy refers to;
 *                   the pointer is stored, not owned.
 * @param policy     compute-graph policy to analyse (ownership is shared).
 * @param var_order  variable ids in ranking order; every id must be smaller
 *                   than var_order.size(), and so must every policy input
 *                   variable (checked by assertions in debug builds).
 * @param options    relaxation options handed to every analyzer.
 *
 * Without POLICE_EIGEN the constructor only reports the missing dependency.
 */
PolicyReasonerLIPABase::PolicyReasonerLIPABase(
    const Model* model,
    std::shared_ptr<CGPolicy> policy,
    [[maybe_unused]] const vector<size_t>& var_order,
    [[maybe_unused]] cg::RelaxationOptions options)
    : var_class_(classify_variables(*model))
    , label_to_index_(model->labels.size(), -1)
    , model_(model)
    // `policy` is a by-value sink parameter: move it instead of copying the
    // shared_ptr. Only `policy_` may be used from here on.
    , policy_(std::move(policy))
{
#ifdef POLICE_EIGEN

    std::shared_ptr<cg::NodeRelaxerFactory> relax_factory =
        std::make_shared<cg::NodeRelaxerFactory>();

    const auto& labels = policy_->get_output();
    lipa_.reserve(labels.size());

    // Relaxers for the policy graph itself; reused as prefix of every
    // per-label relaxation below.
    auto base_layers =
        relax_factory->create_recursive(policy_->get_compute_graph().get());

    for (size_t i = 0; i < labels.size(); ++i) {
        // (1) comparison layer: row k holds out[j] - out[i] for the k-th label
        // j != i.
        std::shared_ptr<cg::Node> compilation = nullptr;
        {
            Matrix weights(labels.size() - 1, labels.size());
            weights.setZero();
            LinVector biases(labels.size() - 1);
            biases.setZero();
            for (size_t j = 0, k = 0; j < labels.size(); ++j) {
                if (j != i) {
                    weights(k, j) = 1;
                    weights(k, i) = -1;
                    biases(k) = labels[i] < labels[j] ? EPSILON : 0;
                    ++k;
                }
            }
            auto comparison_layer = std::make_shared<cg::LinearLayer>(
                std::move(weights),
                std::move(biases));
            compilation = comparison_layer;
            // Kept so that compute_reason() can rewrite the biases later.
            comparison_layer_.push_back(std::move(comparison_layer));
        }

        // (2) ReLU over the pairwise differences.
        std::shared_ptr<cg::Node> last_layer = compilation;
        {
            last_layer->set_successor(
                std::make_shared<cg::ReluLayer>(labels.size() - 1));
            last_layer = last_layer->successor();
        }

        // (3) sum of all ReLU outputs as the single output of the analyzer.
        {
            Matrix m(1, labels.size() - 1);
            m.setOnes();
            LinVector b(1);
            b.setZero();
            last_layer->set_successor(
                std::make_shared<cg::LinearLayer>(std::move(m), std::move(b)));
            last_layer = last_layer->successor();
        }

        // Shared policy prefix followed by the label-specific suffix.
        vector<std::shared_ptr<cg::NodeRelaxer>> relaxation = base_layers;
        auto suffix = relax_factory->create_recursive(compilation.get());
        relaxation.insert(relaxation.end(), suffix.begin(), suffix.end());
        lipa_.push_back(
            std::make_shared<cg::ComputeGraphRelaxation>(
                std::move(relaxation),
                options));

        assert(labels[i] < label_to_index_.size());
        label_to_index_[labels[i]] = i;

#ifndef NDEBUG
        // NETS.push_back(std::move(net));
#endif
    }

    // rank[v] = position of variable v in var_order.
    vector<size_t> rank(var_order.size());
    for (size_t order = 0; order < var_order.size(); ++order) {
        assert(var_order[order] < rank.size());
        rank[var_order[order]] = order;
    }

    // Sort the policy's input positions by descending rank of their variable
    // (the rank is negated so that the default ascending pair order can be
    // used); input_ receives the input positions in that order.
    const auto& policy_inputs = policy_->get_input();
    vector<std::pair<int, size_t>> vars;
    vars.reserve(policy_inputs.size());
    for (size_t i = 0; i < policy_inputs.size(); ++i) {
        assert(policy_inputs[i] < rank.size());
        vars.emplace_back(-static_cast<int>(rank[policy_inputs[i]]), i);
    }
    std::sort(vars.begin(), vars.end());
    for (const auto& [_, idx] : vars) {
        input_.push_back(idx);
    }

#else

    POLICE_MISSING_DEPENDENCY("eigen");

#endif
}

/**
 * Greedily generalises `state` into a condition over the policy's input
 * variables under which the bound check for `label` still succeeds.
 *
 * The check (`is_not_selected` below) succeeds when the lower bound that the
 * label's analyzer computes for its output 0 is strictly positive. Starting
 * from point bounds equal to the state's values, the input positions are
 * visited in the order stored in `input_`. For each variable:
 *   - it is widened to its full domain; if the check still succeeds, the
 *     variable stays widened and no condition is recorded;
 *   - otherwise [state value, upper bound] is tried and, on success, a
 *     LOWER_BOUND condition is recorded;
 *   - otherwise [lower bound, state value] is tried and, on success, an
 *     UPPER_BOUND condition is recorded;
 *   - otherwise the point bounds are restored and an EQUALITY condition is
 *     recorded.
 *
 * @tparam ApplicabilityMasked  when true, `applicability` is told that the
 *         currently visited variable is invalid and the comparison-layer
 *         biases are recomputed from it before the checks of that variable;
 *         the assumption is reverted only when the variable ends up with an
 *         EQUALITY condition.
 * @param state          state whose values seed the input bounds.
 * @param (unnamed)      guard of the action; not used by this implementation.
 * @param label          label to reason about; must map to a valid analyzer
 *                       via `label_to_index_` (asserted).
 * @param applicability  must be non-null if ApplicabilityMasked is true; it is
 *                       not accessed otherwise.
 * @return alternatives initialised from the single condition built here.
 *
 * Without POLICE_EIGEN the function only reports the missing dependency.
 */
template <bool ApplicabilityMasked>
SuffCondAlternatives PolicyReasonerLIPABase::compute_reason(
    [[maybe_unused]] const flat_state& state,
    [[maybe_unused]] const LinearConstraintConjunction&,
    [[maybe_unused]] size_t label,
    [[maybe_unused]] ApplicabilityConditioner* applicability)
{
// NOTE: `#ifdef` (not `#if`) to agree with every other POLICE_EIGEN guard in
// this file; `#if` breaks when the macro is defined without a value and
// disagrees with the other guards when it is defined as 0.
#ifdef POLICE_EIGEN

    assert(
        label < label_to_index_.size() &&
        label_to_index_[label] < lipa_.size());

#if DEBUG_PRINTS
    {
        std::cout << "Get reason for label " << label << std::flush
                  << " (index=" << label_to_index_[label]
                  << " or=" << lipa_.size() << ")..." << std::endl;
        vector<real_t> vals;
        for (const auto var : policy_->get_input()) {
            vals.push_back(static_cast<real_t>(state[var]));
        }
        const auto res = policy_->get_compute_graph()->operator()(vals);
        std::cout << "state=" << print_sequence(state) << std::endl;
        std::cout << "=> " << print_sequence(res) << std::endl;
        std::cout << "(size=" << res.size() << ")" << std::endl;
    }
#endif

    // get lipa analyzer
    const auto label_index = label_to_index_[label];
    auto& analyzer = *lipa_[label_index];

    // prepare input bounds: every input starts fixed to the state's value
    const auto& vars = policy_->get_input();
    LayerBounds input_bounds(vars.size());
    for (size_t i = 0; i < vars.size(); ++i) {
        const auto& var = vars[i];
        input_bounds.set_bounds(
            i,
            static_cast<real_t>(state[var]),
            static_cast<real_t>(state[var]));
    }

    // Biases of this label's comparison layer; rewritten in place by
    // sync_applicability_status (masked variant only).
    auto& biases = comparison_layer_[label_index]->get_biases();
    const auto& labels = policy_->get_output();
    // Recomputes the bias of every competing label: EPSILON if `label` is
    // smaller than the competitor, minus LARGE_CONSTANT if the conditioner's
    // entry for the competitor is false.
    auto sync_applicability_status = [&]() {
        assert(applicability != nullptr);
        for (size_t i = 0, j = 0; i < labels.size(); ++i) {
            if (i != label_index) {
                biases[j] = EPSILON * (label < labels[i]) -
                            (!(*applicability)[labels[i]]) * LARGE_CONSTANT;
                ++j;
            }
        }
    };

#ifndef NDEBUG
    // Only needed for the sanity check below; the main loop re-synchronises
    // the biases itself before every check.
    if constexpr (ApplicabilityMasked) {
        sync_applicability_status();
    }

    // Disabled diagnostic dump. It refers to NETS, which is itself commented
    // out at the top of this file, so it cannot be re-enabled as is.
    // if (analyzer.compute_bounds(input_bounds).lb(0) <= 0.) {
    //     vector<real_t> vals;
    //     auto get_vals = [&]() {
    //         vals.clear();
    //         for (size_t var : policy_->get_input()) {
    //             vals.push_back(static_cast<real_t>(state[var]));
    //         }
    //     };
    //
    //     get_vals();
    //     std::cout << "|> input: " << print_sequence(vals) << std::endl;
    //     std::cout << "|> output: " << print_sequence(policy_->get_output())
    //               << std::endl;
    //
    //     std::cout << "|> net output: "
    //               << print_sequence(policy_->get_nn()(vals)) << std::endl;
    //     std::cout << "|> label: " << label << " (index: " << label_index <<
    //     ")"
    //               << std::endl;
    //     std::cout << "|> policy: " << (*policy_)(state) << std::endl;
    //
    //     get_vals();
    //     std::cout << "|> policy net" << std::endl;
    //     const auto& net_nn = policy_->get_nn();
    //     for (size_t i = 0; i + 1 < net_nn.layers.size(); ++i) {
    //         auto res = net_nn.layers[i](vals, Relu());
    //         std::cout << "// layer[" << i << "]: " << print_sequence(res)
    //                   << std::endl;
    //         vals.swap(res);
    //     }
    //     vals = net_nn.layers.back()(vals);
    //     std::cout << "// layer[" << (net_nn.layers.size() - 1)
    //               << "]: " << print_sequence(vals) << std::endl;
    //
    //     std::cout << "|> comparison net" << std::endl;
    //     get_vals();
    //     const auto& cmp_nn = NETS[label_index];
    //     for (size_t i = 0; i + 1 < cmp_nn.layers.size(); ++i) {
    //         auto output = cmp_nn.layers[i](vals, Relu());
    //         std::cout << "// layer[" << i << "]: " << print_sequence(output)
    //                   << std::endl;
    //         vals.swap(output);
    //     }
    //     vals = cmp_nn.layers.back()(vals);
    //     std::cout << "// layer[" << (cmp_nn.layers.size() - 1)
    //               << "]: " << print_sequence(vals) << std::endl;
    //
    //     std::cout << "|> layer bounds" << std::endl;
    //     for (size_t l = 0; l < analyzer.layers().size(); ++l) {
    //         const auto& layer = analyzer.layers()[l];
    //         std::cout << "L" << l << ": " << layer.get_bounds() << std::endl;
    //         std::cout << " >>";
    //         for (size_t n = 0; n < layer.num_outputs(); ++n) {
    //             std::cout << " n#" << n << "->"
    //                       << static_cast<int>(layer.get_relu_state(n));
    //         }
    //         std::cout << std::endl;
    //     }
    // }

    // Precondition: the check must already succeed for the concrete state.
    assert(analyzer.compute_bounds(input_bounds).lb(0) > 0.);
#endif

    SufficientCondition result;
    // True iff the analyzer's lower bound for output 0 under the current
    // input bounds is strictly positive. The callback passed to the analyzer
    // returns true as soon as that lower bound is positive (presumably an
    // early-termination criterion -- confirm in ComputeGraphRelaxation).
    auto is_not_selected = [&]() {
        const auto bounds = analyzer.compute_lower_bounds(
            input_bounds,
            [](const auto& lbs, const auto&) { return lbs[0] > 0; });
        return bounds[0] > 0.;
    };

    for (const auto i : input_) {
        const auto var = vars[i];
        const auto type = model_->variables[var].type;
        // Step 1: drop the variable entirely (full domain).
        input_bounds.set_bounds(
            i,
            static_cast<real_t>(type.get_lower_bound()),
            static_cast<real_t>(type.get_upper_bound()));

        if constexpr (ApplicabilityMasked) {
            applicability->assume_invalid(var);
            sync_applicability_status();
        }

        // std::cout << "var" << var << " (input " << i << ") domain: ["
        //           << type.get_lower_bound() << ", " << type.get_upper_bound()
        //           << "] ..." << std::endl;

        if (!is_not_selected()) {
            // Step 2: keep only the lower bound, i.e. [state value, ub].
            input_bounds.set_bounds(
                i,
                static_cast<real_t>(state[var]),
                static_cast<real_t>(type.get_upper_bound()));
            // std::cout << " -- need lb or ub..." << std::endl;
            // If the state value already is the lower bound, this interval is
            // the full domain that just failed, so the check is skipped.
            if (state[var] > type.get_lower_bound() && is_not_selected()) {
                // std::cout << " --> need lb" << std::endl;
                result.emplace_back(var, VariableCondition::LOWER_BOUND);
            } else {
                // Step 3: keep only the upper bound, i.e. [lb, state value].
                // std::cout << " -- need lb and ub?..." << std::endl;
                input_bounds.set_bounds(
                    i,
                    static_cast<real_t>(type.get_lower_bound()),
                    static_cast<real_t>(state[var]));
                if (state[var] < type.get_upper_bound() && is_not_selected()) {
                    // std::cout << " --> need ub" << std::endl;
                    result.emplace_back(var, VariableCondition::UPPER_BOUND);
                } else {
                    // Step 4: both bounds are needed; restore the point
                    // interval and undo the applicability assumption.
                    // std::cout << " --> need lb and ub" << std::endl;
                    input_bounds.set_bounds(
                        i,
                        static_cast<real_t>(state[var]),
                        static_cast<real_t>(state[var]));
                    result.emplace_back(var, VariableCondition::EQUALITY);
                    if constexpr (ApplicabilityMasked) {
                        applicability->revert_last_assumption();
                    }
                }
            }
        }
    }

    return {std::move(result)};

#else
    POLICE_MISSING_DEPENDENCY("eigen");
#endif
}

/**
 * Computes the reason for `action` in `state` without applicability masking.
 *
 * Forwards to compute_reason() with masking disabled; no applicability
 * conditioner is supplied in that mode.
 */
SuffCondAlternatives PolicyReasonerLIPA::get_reason(
    const flat_state& state,
    const LinearConstraintConjunction& guard,
    size_t action)
{
    constexpr bool masked = false;
    ApplicabilityConditioner* const no_conditioner = nullptr;
    return this->compute_reason<masked>(state, guard, action, no_conditioner);
}

/**
 * Computes the reason for `action` in `state` with applicability masking.
 *
 * A fresh applicability conditioner is built from the model and `info_` for
 * every call and handed to compute_reason() with masking enabled.
 */
SuffCondAlternatives PolicyReasonerLIPAMasked::get_reason(
    const flat_state& state,
    const LinearConstraintConjunction& guard,
    size_t action)
{
    constexpr bool masked = true;
    ApplicabilityConditioner masking(*this->model_, info_);
    return this->compute_reason<masked>(state, guard, action, &masking);
}

namespace {
PointerOption<PolicyReasoner> _option(
    "lipa",
    [](const Arguments& args) -> std::shared_ptr<PolicyReasoner> {
        const Model& model = args.get_model();
        const auto get_order = args.get_ptr<VariableOrderChooser>("order");
        auto order = get_order->get_variable_order();
        std::cout << "Policy reasoner's greedy variable ordering: "
                  << print_sequence(order) << std::endl;
        cg::RelaxationOptions options;
        options.max_steps = args.get<int>("max_steps");
        options.max_total_time = args.get<double>("max_time");
        options.optimize_only_last = args.get<bool>("optimize_only_last_layer");
        std::shared_ptr<CGPolicy> policy = nullptr;
        if (args.has_nn_policy()) {
            std::shared_ptr<NeuralNetworkPolicy> nn_policy = args.nn_policy;
            auto g = cg::from_ffnn(nn_policy->get_nn());
            auto p =
                cg::post_process(g, nn_policy->get_input(), model.variables);
            policy = std::make_shared<CGPolicy>(
                std::move(p.cg),
                std::move(p.inputs),
                nn_policy->get_output(),
                model.labels.size());
        } else if (args.has_cg_policy()) {
            policy = args.cg_policy;
        } else {
            POLICE_RUNTIME_ERROR(
                "lipa policy reasoner currently supports neural network and "
                "compute graph policies only");
        }
        if (args.applicability_masking) {
            return std::make_shared<PolicyReasonerLIPAMasked>(
                &model,
                std::move(policy),
                std::move(order),
                options);
        }
        return std::make_shared<PolicyReasonerLIPA>(
            &model,
            std::move(policy),
            std::move(order),
            options);
    },
    [](ArgumentsDefinition& defs) {
        defs.add_ptr_argument<VariableOrderChooser>(
            "order",
            "Order in which to remove variables",
            "default");
        defs.add_argument<int>(
            "max_steps",
            "Maximal number of optimization iterations",
            "1000");
        defs.add_argument<double>(
            "max_time",
            "Maximal time spent on LIPA",
            "600");
        defs.add_argument<bool>(
            "optimize_only_last_layer",
            "Optimize bounds of last layer only",
            "true");
    });
} // namespace

} // namespace police::ic3::syntactic

#undef DEBUG_PRINTS
