#include "police/base_types.hpp"
#include "police/marabou_preprocessor.hpp"
#include "nnlp_factories.hpp"
#include "nnlp_generator.hpp"

#include <catch2/catch.hpp>
#include <memory>
#include <string>
#include <type_traits>

namespace {
using namespace police;

// Parameters of the randomly generated test instances.
//
// - INPUT_SIZE, OUTPUT_SIZE, NUM_LAYERS, LAYER_SIZE and DENSITY are forwarded,
//   in that order, to FFNNGenerator in GENERATE_RND_TEST_CASE.
// - NUM_TEST_CASES, NUM_CONSTRAINTS, NUM_DISJUNCTIONS (only for the "with
//   disjunctions" variant) and MAX_DISJUNCTION_SIZE are forwarded to
//   NNLPDataGenerator.
// - Each test case takes NUM_TEST_CASES * NUM_NETWORKS generated instances.
//
// Every constant is [[maybe_unused]] because it is referenced only from the
// test-case macros. Those macros are instantiated solely under #if POLICE_*
// guards, so a build with no backend enabled would otherwise warn.
[[maybe_unused]]
constexpr police::size_t LAYER_SIZE = 4;
[[maybe_unused]]
constexpr police::size_t NUM_LAYERS = 2;
[[maybe_unused]]
constexpr police::size_t INPUT_SIZE = 8;
[[maybe_unused]]
constexpr police::size_t OUTPUT_SIZE = 3;
[[maybe_unused]]
constexpr police::size_t NUM_NETWORKS = 4;
[[maybe_unused]]
constexpr police::size_t NUM_TEST_CASES = 4;
[[maybe_unused]]
constexpr police::size_t NUM_CONSTRAINTS = 8;
[[maybe_unused]]
constexpr police::size_t NUM_DISJUNCTIONS = 4;
[[maybe_unused]]
constexpr police::size_t MAX_DISJUNCTION_SIZE = 8;
[[maybe_unused]]
constexpr police::real_t DENSITY = 0.5;

// Builds the Catch2 test-case name, for example
//   test_name(NUM_DISJUNCTIONS, "Z3_NNLP")
//     == "Random NNLP with disjunctions is solvable: Z3_NNLP".
// `disj` is only tested for zero ("without") vs. non-zero ("with").
//
// Deliberately not constexpr:
// - A constexpr function returning std::string is ill-formed before C++20.
// - The name is only ever passed to TEST_CASE as a runtime argument.
[[maybe_unused]]
std::string test_name(std::size_t disj, const char* factory)
{
    return std::string("Random NNLP ") + (disj ? "with" : "without") +
           " disjunctions is solvable: " + factory;
}

// Overload chosen (via std::true_type) when the LP type derives from
// police::NNLPMarabouPreprocessor: runs its preprocess() step and checks the
// returned flag.
//
// - NOTE(review): the local's name suggests true means "not proven infeasible
//   by preprocessing" - confirm against NNLPMarabouPreprocessor.
// - CHECK (not REQUIRE) records a failure but lets the test case go on to
//   solve().
//
// Tag dispatch is used instead of `if constexpr` because the calling
// TEST_CASE body is not a template. Both branches of an `if constexpr` would
// therefore be type-checked, and `lp.preprocess()` would not compile for LP
// types lacking it.
[[maybe_unused]]
void require_preprocess(std::true_type, police::NNLPMarabouPreprocessor& lp)
{
    const bool feasible_in_preprocess = lp.preprocess();
    CHECK(feasible_in_preprocess);
}

// Overload chosen (via std::false_type) for LP types without Marabou
// preprocessing: intentionally does nothing. The argument must still bind to
// police::NNLP&.
[[maybe_unused]]
void require_preprocess(std::false_type, police::NNLP&)
{
}

} // namespace

// Registers one Catch2 test case asserting that randomly generated NNLP
// instances are reported SAT by the LP object built with `Factory()`.
//
// Parameters:
//   Factory      - `Factory()` must yield an LP object supporting
//                  `test.encode(lp)` and `lp.solve()`. It is also stringized
//                  into the test name.
//   Tags         - string literal with extra Catch2 tags, concatenated after
//                  "[nnlp][int]".
//   DISJUNCTIONS - disjunction count passed to NNLPDataGenerator. Zero
//                  selects the "without disjunctions" test name.
//
// Behavior:
// - GENERATE takes NUM_TEST_CASES * NUM_NETWORKS instances, and Catch2
//   re-runs the body once per instance.
// - If the LP type derives from police::NNLPMarabouPreprocessor, its
//   preprocess() step is run and checked before solving (see
//   require_preprocess).
// - On failure, REQUIRE's expression decomposition already reports the
//   returned status.
#define GENERATE_RND_TEST_CASE(Factory, Tags, DISJUNCTIONS)                    \
    TEST_CASE(test_name(DISJUNCTIONS, #Factory), "[nnlp][int]" Tags)           \
    {                                                                          \
        auto test = GENERATE(take(                                             \
            NUM_TEST_CASES * NUM_NETWORKS,                                     \
            GeneratorWrapper<NNLPData>(std::make_unique<NNLPDataGenerator>(    \
                NUM_TEST_CASES,                                                \
                NUM_CONSTRAINTS,                                               \
                DISJUNCTIONS,                                                  \
                MAX_DISJUNCTION_SIZE,                                          \
                FFNNGenerator(                                                 \
                    INPUT_SIZE,                                                \
                    OUTPUT_SIZE,                                               \
                    NUM_LAYERS,                                                \
                    LAYER_SIZE,                                                \
                    DENSITY)))));                                              \
        auto lp = Factory();                                                   \
        test.encode(lp);                                                       \
        require_preprocess(                                                    \
            std::integral_constant<                                            \
                bool,                                                          \
                std::is_base_of_v<                                             \
                    police::NNLPMarabouPreprocessor,                            \
                    std::remove_cv_t<decltype(lp)>>>(),                        \
            lp);                                                               \
        const auto lp_status = lp.solve();                                     \
        REQUIRE(lp_status == police::NNLP::Status::SAT);                        \
    }

// Registers the pair of random-NNLP test cases for a single LP factory:
// - one whose instances have no disjunctions;
// - one whose instances have NUM_DISJUNCTIONS disjunctions, additionally
//   tagged "[disjunctions]".
// `ExtraTags` must be a string literal; it is concatenated with the
// "[disjunctions]" literal for the second case.
#define GENERATE_RND_TEST_CASES(LpFactory, ExtraTags)                          \
    GENERATE_RND_TEST_CASE(LpFactory, ExtraTags, 0)                            \
    GENERATE_RND_TEST_CASE(                                                    \
        LpFactory,                                                             \
        ExtraTags "[disjunctions]",                                            \
        NUM_DISJUNCTIONS)

// Instantiate the test-case pairs for every backend compiled in.
// - The POLICE_* macros are not defined in this file; presumably the build
//   configuration sets them.
// - An undefined identifier evaluates to 0 in #if, so a missing backend's
//   tests are simply omitted.
// - Catch2 runs tests in declaration order by default, so this list also fixes
//   the run order.
#if POLICE_Z3
GENERATE_RND_TEST_CASES(Z3_NNLP, "[smt][z3]")
GENERATE_RND_TEST_CASES(BnB_NNLP_Z3, "[bnb][smt][z3]")
#endif // POLICE_Z3

#if POLICE_GUROBI
GENERATE_RND_TEST_CASES(Gurobi_NNLP, "[lp][gurobi]")
GENERATE_RND_TEST_CASES(BnB_NNLP_Gurobi, "[bnb][gurobi]")
#endif // POLICE_GUROBI

#if POLICE_MARABOU
GENERATE_RND_TEST_CASES(MarabouLP, "[marabou]")
GENERATE_RND_TEST_CASES(BnB_NNLP_Marabou, "[bnb][marabou]")
#endif // POLICE_MARABOU

// The *_Preprocessing factories are guarded by both a solver flag and
// POLICE_MARABOU and are tagged "[marabou]". Presumably they pair the solver
// with Marabou preprocessing - confirm in nnlp_factories.hpp.
#if POLICE_Z3 && POLICE_MARABOU
GENERATE_RND_TEST_CASES(Z3_NNLP_Preprocessing, "[smt][z3][marabou]")
#endif // POLICE_Z3 && POLICE_MARABOU

#if POLICE_GUROBI && POLICE_MARABOU
GENERATE_RND_TEST_CASES(Gurobi_NNLP_Preprocessing, "[lp][gurobi][marabou]")
#endif // POLICE_GUROBI && POLICE_MARABOU

// The helper macros are only meant for this file.
#undef GENERATE_RND_TEST_CASE
#undef GENERATE_RND_TEST_CASES
