#include "../ffnn_generator.hpp"
#include "../numbers.hpp"

#include "police/ffnn_lp_encoder.hpp"
#include "police/lp.hpp"
#include "police/lp_gurobi.hpp"
#include "police/lp_z3.hpp"
#include "police/nnlp_bnb.hpp"
#include "police/nnlp_lp.hpp"
#include "police/nnlp_smt.hpp"
#include "police/smt_z3.hpp"
#include "police/storage/ffnn.hpp"
#include "police/utils/rng.hpp"

#include <catch2/catch.hpp>
#include <memory>

namespace {
// Shared parameters of every test case registered below. Each is marked
// [[maybe_unused]] because every CREATE_TESTS expansion in this file sits
// behind a POLICE_* feature macro, so a given build may expand none of them.

// Networks drawn from FFNNGenerator per test case.
[[maybe_unused]] constexpr int NUM_TESTS{10};
// Sampled input vectors evaluated per generated network.
[[maybe_unused]] constexpr int NUM_INPUTS{100};
// Network input width, and number of input variables added to the NNLP.
[[maybe_unused]] constexpr int INPUT_SIZE{8};
// Network output width; this many model outputs are compared per sample.
[[maybe_unused]] constexpr int OUTPUT_SIZE{5};
// Forwarded to FFNNGenerator; presumably the connection density of the
// generated layers (1 = fully connected) -- TODO confirm.
[[maybe_unused]] constexpr double DENSITY{1.0};
// Seed of the RNG that samples the input vectors.
[[maybe_unused]] constexpr int SEED{1734};
// Bounds of each input variable's domain and of the sampled input values.
[[maybe_unused]] constexpr int MIN_VALUE{-100};
[[maybe_unused]] constexpr int MAX_VALUE{100};

using namespace police;
} // namespace

// Defines one Catch2 test case that checks encode_ffnn_in_lp() against direct
// evaluation of the network.
//
//   Factory         nullary callable returning a pointer-like handle to an
//                   NNLP factory; its name is stringified into the test name.
//   Layers, Neurons network shape forwarded to FFNNGenerator (and into the
//                   test name).
//
// For each of NUM_TESTS generated networks the test adds INPUT_SIZE bounded
// integer input variables (ids 0..INPUT_SIZE-1) and encodes the network on top
// of them. Then, for NUM_INPUTS sampled input vectors, it pins every input
// with an equality constraint inside a snapshot and requires the problem to be
// SAT. It checks each of the OUTPUT_SIZE model values at out_var + k against
// net(in) within PRECISION, and finally pops the snapshot.
//
// Only block comments appear inside the macro body: a line comment would also
// swallow the following backslash-spliced line.
#define CREATE_TEST(Factory, Layers, Neurons)                                  \
    TEST_CASE(                                                                 \
        "Test NN NNLP encoding - " #Layers "x" #Neurons " - " #Factory,        \
        "[nnlp][nn-encoding]")                                                 \
    {                                                                          \
        auto net = GENERATE(take(                                              \
            NUM_TESTS,                                                         \
            GeneratorWrapper<FeedForwardNeuralNetwork<>>(                      \
                std::make_unique<FFNNGenerator>(                               \
                    INPUT_SIZE,                                                \
                    OUTPUT_SIZE,                                               \
                    Layers,                                                    \
                    Neurons,                                                   \
                    DENSITY))));                                               \
        /* Hold the factory for the whole test instead of a temporary:     \
           nothing visible here guarantees the NNLP it creates does not    \
           refer back to it. */                                            \
        const auto factory = Factory();                                        \
        auto lp = factory->make_unique();                                      \
        vector<police::size_t> invars(INPUT_SIZE);                             \
        for (auto v = 0u; v < invars.size(); ++v) {                            \
            invars[v] = v;                                                     \
            lp->add_variable(BoundedIntType(MIN_VALUE, MAX_VALUE));            \
        }                                                                      \
        const auto out_var = encode_ffnn_in_lp(*lp, net, invars);              \
        RNG rng(SEED);                                                         \
        /* Reused across samples; every entry is rewritten each time. */      \
        vector<double> in(INPUT_SIZE);                                         \
        for (int sample = 0; sample < NUM_INPUTS; ++sample) {                  \
            lp->push_snapshot();                                               \
            /* Filled back to front: this fixes which RNG draw lands in     \
               which input, so the sampled points stay the same. */         \
            for (auto j = in.size(); j-- > 0;) {                               \
                in[j] = rng(MIN_VALUE, MAX_VALUE);                             \
                LinearConstraint eq(LinearConstraint::EQUAL);                  \
                eq.insert(j, 1.);                                              \
                eq.rhs = in[j];                                                \
                lp->add_constraint(eq);                                        \
            }                                                                  \
            const auto status = lp->solve();                                   \
            REQUIRE(status == NNLP::SAT);                                      \
            const auto model = lp->get_model();                                \
            const auto out = net(in);                                          \
            for (int k = 0; k < OUTPUT_SIZE; ++k) {                            \
                CHECK_THAT(                                                    \
                    static_cast<real_t>(model[out_var + k]),                   \
                    Catch::Matchers::WithinAbs(out[k], PRECISION));            \
            }                                                                  \
            lp->pop_snapshot();                                                \
        }                                                                      \
    }

// Registers the network shapes (Layers x Neurons) exercised for one factory:
// 2x8 first, then 2x16.
#define CREATE_TESTS(Factory) \
    CREATE_TEST(Factory, 2, 8) CREATE_TEST(Factory, 2, 16)

#if POLICE_Z3
namespace {
// NNLP backend that hands the encoding to Z3 through the SMT interface
// (NNLPSMTFactory over Z3SMTFactory).
auto Z3_NNLP() -> std::shared_ptr<NNLPSMTFactory>
{
    // The NNLP factory receives a raw pointer to this solver factory; the
    // function-local static keeps that pointer valid for the whole run.
    static Z3SMTFactory smt_solver_factory;
    return std::make_shared<NNLPSMTFactory>(&smt_solver_factory, false);
}

// NNLP backend that goes through the generic LP interface (NNLPLPFactory)
// with Z3LPFactory underneath.
auto Z3_NNLP_LP() -> std::shared_ptr<NNLPLPFactory>
{
    static Z3LPFactory lp_solver_factory;
    return std::make_shared<NNLPLPFactory>(&lp_solver_factory, false);
}
} // namespace

// The function names double as the "#Factory" suffix of each test name.
CREATE_TESTS(Z3_NNLP)
CREATE_TESTS(Z3_NNLP_LP)
#endif

#if POLICE_GUROBI
namespace {
// NNLP built by NNLPLPFactory on top of Gurobi's LP backend, with the
// factory's second argument set to false.
auto Gurobi_NNLP() -> std::shared_ptr<NNLPLPFactory>
{
    // Function-local static: the NNLP factory is handed a raw pointer to it.
    static GurobiLPFactory lp_solver_factory;
    return std::make_shared<NNLPLPFactory>(&lp_solver_factory, false);
}

#if POLICE_MARABOU
// Same as Gurobi_NNLP but passes `true`. Judging by the name (and the
// POLICE_MARABOU guard) this presumably enables a Marabou-based
// preprocessing step -- confirm against NNLPLPFactory. It deliberately owns
// a separate GurobiLPFactory instance.
auto Gurobi_NNLP_Preprocessed() -> std::shared_ptr<NNLPLPFactory>
{
    static GurobiLPFactory lp_solver_factory;
    return std::make_shared<NNLPLPFactory>(&lp_solver_factory, true);
}
#endif
} // namespace

CREATE_TESTS(Gurobi_NNLP)

#if POLICE_MARABOU
CREATE_TESTS(Gurobi_NNLP_Preprocessed)
#endif
#endif

#if POLICE_MARABOU
namespace {
// NNLP solved by NNLPBranchNBoundFactory over a Marabou LP backend.
// NOTE(review): no Marabou-specific header is included in this file;
// MarabouLPFactory is presumably declared by one of the police/ headers
// above when POLICE_MARABOU is set -- confirm.
std::shared_ptr<NNLPFactory> Marabou_NNLP()
{
    // The branch-and-bound factory is handed a raw pointer to this object;
    // the function-local static keeps that pointer valid for the whole run.
    static MarabouLPFactory factory;
    return std::make_shared<NNLPBranchNBoundFactory>(&factory);
}
} // namespace

// No trailing ';' (unlike the original): CREATE_TESTS expands to TEST_CASE
// bodies that end in '}', so a semicolon would be a stray empty declaration
// (-Wextra-semi) and inconsistent with the other registrations.
CREATE_TESTS(Marabou_NNLP)
#endif

// Keep the test-registration helpers local to this translation unit.
#undef CREATE_TESTS
#undef CREATE_TEST
