#include "../ffnn_generator.hpp"

#include "models/simple_counter.hpp"
#include "police/lp_gurobi.hpp"
#include "police/lp_z3.hpp"
#include "police/nn_policy.hpp"
#include "police/nnlp_bnb.hpp"
#include "police/nnlp_lp.hpp"
#include "police/nnlp_marabou.hpp"
#include "police/nnlp_smt.hpp"
#include "police/smt_z3.hpp"
#include "police/storage/ffnn.hpp"
#include "police/storage/flat_state.hpp"
#include "police/successor_generator/successor_generator.hpp"
#include "police/verifiers/epic3/cube.hpp"
#include "police/verifiers/epic3/engines/sat_based/generalizer/greedy_generalizer.hpp"
#include "police/verifiers/epic3/engines/sat_based/generalizer/unsat_core_generalizer.hpp"
#include "police/verifiers/epic3/engines/sat_based/pic3sat_label.hpp"
#include "police/verifiers/epic3/engines/sat_based/pic3sat_singleton.hpp"
#include "police/verifiers/epic3/epic3.hpp"
#include "police/verifiers/epic3/start_generator.hpp"

#include <catch2/catch.hpp>

namespace {

using namespace police;

// Counter target of the solvable tests (the unsolvable tests use
// COUNT_TO + 1). Only the test macros read it, so it is [[maybe_unused]]
// for builds in which no test is instantiated.
[[maybe_unused]] constexpr unsigned COUNT_TO{4};

// Shape parameters handed to FFNNGenerator in get_safe_net. The generated
// network is evaluated on one input (the counter value) and its first two
// outputs are compared, matching VARIABLES and ACTIONS; the precise meaning
// of LAYERS, NEURONS and DENSITY is defined by FFNNGenerator.
constexpr unsigned NEURONS{8};
constexpr unsigned LAYERS{2};
constexpr unsigned VARIABLES{1};
constexpr unsigned ACTIONS{2};
constexpr double DENSITY{1.0};

// Passed through to PIC3SatParameters by the test macros (presumably toggles
// edge pruning -- confirm against PIC3SatParameters).
[[maybe_unused]] constexpr bool PRUNE_EDGES{true};

/// Searches the FFNNGenerator sequence for a network that, read as a counter
/// controller (output 0 >= output 1 means "increment", as in
/// SuccessorGenerator), increments on every value in [0, plan_length - 2] and
/// prefers output 1 at value plan_length - 1. Starting from 0 the counter
/// therefore reaches plan_length - 1 but never plan_length.
///
/// The generator is advanced until a matching network is found, so the call
/// does not return if FFNNGenerator never yields one.
///
/// The most recent result is cached per `plan_length`, because the search can
/// be expensive and several test cases request the same network.
police::FeedForwardNeuralNetwork<> get_safe_net(int plan_length)
{
    static police::FeedForwardNeuralNetwork<> net;
    // An explicit flag instead of a sentinel length: with the previous
    // sentinel 0, a first request for plan_length == 0 returned the
    // default-constructed network without searching.
    static bool has_cached = false;
    static int cached_plan_length = 0;
    if (has_cached && cached_plan_length == plan_length) {
        return net;
    }

    FFNNGenerator gen(VARIABLES, ACTIONS, LAYERS, NEURONS, DENSITY);
    for (;;) {
        // Bind the current candidate once instead of re-fetching it per query.
        auto&& candidate = gen.get();
        bool sat = true;
        for (int val = 0; sat && val + 1 < plan_length; ++val) {
            vector<real_t> state({static_cast<real_t>(val)});
            const auto out = candidate(state);
            sat = out[0] >= out[1];
        }
        if (sat) {
            vector<real_t> state({static_cast<real_t>(plan_length) - 1});
            const auto out = candidate(state);
            if (out[0] < out[1]) {
                net = candidate;
                cached_plan_length = plan_length;
                has_cached = true;
                return net;
            }
        }
        gen.next();
    }
}

// Network under which the counter does reach `plan_length`: it is the "safe"
// network for `plan_length + 1`, so it keeps incrementing on every value
// below `plan_length` and only turns around at `plan_length` itself.
police::FeedForwardNeuralNetwork<> get_unsafe_net(int plan_length)
{
    const int turnaround_length = plan_length + 1;
    return get_safe_net(turnaround_length);
}

// Policies wrapping the counter networks. The constructor arguments after the
// network ({0u}, {0u, 1u}, 2) are passed through unchanged; presumably they
// are the input variable, the output labels and the label count -- confirm
// against NeuralNetworkPolicy.
[[maybe_unused]]
police::NeuralNetworkPolicy<> get_unsafe_policy(int len)
{
    auto network = get_unsafe_net(len);
    return {std::move(network), {0u}, {0u, 1u}, 2};
}

[[maybe_unused]]
police::NeuralNetworkPolicy<> get_safe_policy(int len)
{
    auto network = get_safe_net(len);
    return {std::move(network), {0u}, {0u, 1u}, 2};
}

/// Explicit-state successor function for the counter, driven directly by the
/// network. The network is evaluated on variable 0.
///
/// If output 0 is at least output 1, the counter is incremented; otherwise it
/// is decremented. There is no successor when incrementing from a value that
/// is not below `count`, or when decrementing from a value that is not above 0.
///
/// Only variable 0 changes. The successor is paired with 0 (presumably the
/// action label -- confirm against successor_generator::Successor).
struct SuccessorGenerator {
    vector<successor_generator::Successor>
    operator()(const flat_state& state) const
    {
        vector<real_t> input;
        input.push_back(static_cast<real_t>(state[0]));
        const auto scores = (*net)(input);
        const bool count_up = scores[0] >= scores[1];

        // Guard clause: at the bound in the chosen direction there is no move.
        const bool at_bound =
            count_up ? !(state[0] < Value(count)) : !(state[0] > Value(0));
        if (at_bound) {
            return {};
        }

        flat_state next(state);
        const int_t step = count_up ? 1 : -1;
        next[0] = Value(static_cast<int_t>(state[0]) + step);
        return {{std::move(next), 0}};
    }

    /// Network deciding the direction; not owned.
    const FeedForwardNeuralNetwork<>* net;
    /// Upper bound for incrementing the counter.
    int count;
};

} // namespace

// Registers a test that drives the SAT engine's blocking queries
// (`is_blocked` / `set_blocked`) and the generalizer `Gen` by hand, frame by
// frame, on the simple counter model under the unsafe policy. `Gen` is an
// expression that may refer to `model` and `count_to`. When `Skip` is true,
// the test body returns immediately.
//
// Tags are appended by string-literal concatenation rather than `#Tags`.
// Stringizing an argument that is already a string literal keeps its quote
// characters, which would then end up inside the tag string. The inner loop
// counters are named `value` so they do not shadow the `start` generator.
#define GEN_TEST_(Factory, Tags, Silent, Params, Name, Gen, Skip)              \
    TEST_CASE(                                                                 \
        "Reason gen and SAT refinement - silent " #Silent Name " - " #Factory, \
        "[ic3][reason][simple_counter]" Tags)                                  \
    {                                                                          \
        if (Skip) return;                                                      \
        const int count_to = COUNT_TO;                                         \
        const bool silent = Silent;                                            \
        auto model = simple_counter_model(count_to, silent);                   \
        auto goal = simple_counter_goal(count_to);                             \
        auto policy = get_unsafe_policy(count_to);                             \
                                                                               \
        std::shared_ptr<Z3SMTFactory> smt = std::make_shared<Z3SMTFactory>();  \
        auto nnlp = Factory();                                                 \
        auto sat = Params;                                                     \
                                                                               \
        auto gen = Gen;                                                        \
                                                                               \
        auto start_smt = smt->make_shared();                                   \
        start_smt->add_variables(model.variables);                             \
        start_smt->add_constraint(                                             \
            expressions::Variable(0) == Value(0) &&                            \
            expressions::Variable(1) == Value(0));                             \
        auto start = epic3::StartGenerator(std::move(start_smt));              \
                                                                               \
        for (int frame = 1; frame < count_to; ++frame) {                       \
            epic3::Cube state;                                                 \
            state.set(0, epic3::Interval(Value(count_to - frame)));            \
            state.set(1, epic3::Interval(Value(0)));                           \
            REQUIRE(!sat.is_blocked(state, frame).first);                      \
            for (int value = 0; value < count_to - frame; ++value) {           \
                state.set(0, epic3::Interval(Value(value)));                   \
                REQUIRE(sat.is_blocked(state, frame).first);                   \
                gen(sat, state, frame);                                        \
                REQUIRE(state.get(0).ub <= count_to - frame - 1);              \
                state.set(0, epic3::Interval(Value(value)));                   \
                state.set(1, epic3::Interval(Value(0)));                       \
                sat.set_blocked(state, frame);                                 \
            }                                                                  \
            if (Silent) {                                                      \
                for (int value = 0; value < count_to - frame - 1; ++value) {   \
                    state.set(0, epic3::Interval(Value(value)));               \
                    REQUIRE(!sat.is_blocked(state, frame + 1).first);          \
                }                                                              \
                state.set(1, epic3::Interval(Value(1)));                       \
                for (int value = 0; value < count_to - frame; ++value) {       \
                    state.set(0, epic3::Interval(Value(value)));               \
                    REQUIRE(sat.is_blocked(state, frame).first);               \
                    gen(sat, state, frame);                                    \
                    REQUIRE(state.get(0).ub <= count_to - frame - 1);          \
                    state.set(0, epic3::Interval(Value(value)));               \
                    state.set(1, epic3::Interval(Value(1)));                   \
                    sat.set_blocked(state, frame);                             \
                }                                                              \
            }                                                                  \
        }                                                                      \
    }

// Registers an end-to-end test on the simple counter model under the unsafe
// policy. Explicit-state IC3, using the SAT engine built from `Params` and a
// greedy generalizer, must return a path to a state where variable 0 equals
// `count_to`.
//
// Tags are appended by string-literal concatenation, not `#Tags`, because
// stringizing a string literal would keep its quote characters in the tag
// string.
#define SOLVABLE_TEST__(Factory, Tags, Silent, Params, Name)                   \
    TEST_CASE(                                                                 \
        "Simple counter model - silent " #Silent Name                          \
        " - has a plan: " #Factory,                                            \
        "[ic3][simple_counter]" Tags)                                          \
    {                                                                          \
        const int count_to = COUNT_TO;                                         \
        const bool silent = Silent;                                            \
        auto model = simple_counter_model(count_to, silent);                   \
        auto goal = simple_counter_goal(count_to);                             \
        auto policy = get_unsafe_policy(count_to);                             \
                                                                               \
        std::shared_ptr<Z3SMTFactory> smt = std::make_shared<Z3SMTFactory>();  \
        auto nnlp = Factory();                                                 \
        auto sat = Params;                                                     \
                                                                               \
        auto gen = epic3::GreedyGeneralizer(                                   \
            &model.variables,                                                  \
            [](const auto& cube) {                                             \
                return !cube.has(0) ||                                         \
                       static_cast<int_t>(cube.get(0).ub) >= count_to;         \
            },                                                                 \
            {0, 1},                                                            \
            {true, true});                                                     \
                                                                               \
        auto start_smt = smt->make_shared();                                   \
        start_smt->add_variables(model.variables);                             \
        start_smt->add_constraint(                                             \
            expressions::Variable(0) == Value(0) &&                            \
            expressions::Variable(1) == Value(0));                             \
        auto start = epic3::StartGenerator(std::move(start_smt));              \
                                                                               \
        auto ic3 = epic3::ExplicitStateIC3smt<                                 \
            SuccessorGenerator,                                                \
            decltype(sat),                                                     \
            std::decay_t<decltype(gen)>>(                                      \
            &model.variables,                                                  \
            expressions::Variable(0) == Value(count_to),                       \
            start,                                                             \
            SuccessorGenerator(&policy.get_nn(), count_to),                    \
            sat,                                                               \
            gen,                                                               \
            false);                                                            \
        auto path = ic3();                                                     \
        REQUIRE(path.has_value());                                             \
    }

// For one back end and one SAT-engine parameterization, registers three
// tests:
//   - the end-to-end "has a plan" test;
//   - a reasoning test with the greedy generalizer;
//   - a reasoning test with the unsat-core generalizer, which is a no-op when
//     SKIP_CORE is true.
// The predicate given to both generalizers uses `count_to`, which is declared
// inside GEN_TEST_'s test body.
#define SOLVABLE_TEST_(FACTORY, TAGS, SILENT, PARAMS, NAME, SKIP_CORE)          \
    SOLVABLE_TEST__(FACTORY, TAGS, SILENT, PARAMS, NAME)                        \
    GEN_TEST_(FACTORY, TAGS, SILENT, PARAMS, NAME " - greedy",                  \
              epic3::GreedyGeneralizer(                                         \
                  &model.variables,                                             \
                  [](const auto& cube) {                                        \
                      return !cube.has(0) ||                                    \
                             static_cast<int_t>(cube.get(0).ub) >= count_to;    \
                  },                                                            \
                  {0, 1},                                                       \
                  {true, true}),                                                \
              false)                                                            \
    GEN_TEST_(FACTORY, TAGS, SILENT, PARAMS, NAME " - core",                    \
              epic3::UnsatCoreGeneralizer(                                      \
                  &model.variables,                                             \
                  [](const auto& cube) {                                        \
                      return !cube.has(0) ||                                    \
                             static_cast<int_t>(cube.get(0).ub) >= count_to;    \
                  }),                                                           \
              SKIP_CORE)

// Instantiates SOLVABLE_TEST_ twice with the edge-individual PIC3 SAT engine.
// The two instances differ only in the third PIC3SatParameters argument:
// false (no name suffix) and true (" - app_filter" suffix). The arguments use
// `smt`, `nnlp`, `model`, `policy` and `goal`, which every generated test
// body declares.
#define SOLVABLE_TEST(FACTORY, TAGS, SILENT, SKIP_CORE)                         \
    SOLVABLE_TEST_(FACTORY, TAGS, SILENT,                                       \
                   epic3::PIC3SatEdgeIndividual(                                \
                       epic3::PIC3SatParameters(                                \
                           smt, nnlp, false, true, PRUNE_EDGES, false),         \
                       model, policy, {}, goal),                                \
                   "", SKIP_CORE)                                               \
    SOLVABLE_TEST_(FACTORY, TAGS, SILENT,                                       \
                   epic3::PIC3SatEdgeIndividual(                                \
                       epic3::PIC3SatParameters(                                \
                           smt, nnlp, true, true, PRUNE_EDGES, false),          \
                       model, policy, {}, goal),                                \
                   " - app_filter", SKIP_CORE)

// Solvable tests with the singleton PIC3 SAT engine. These tests carry the
// extra "[ic3sat-singleton]" tag and a " - singleton" name suffix, and their
// unsat-core generalizer test is never skipped.
#define SOLVABLE_TEST_SINGLETON(FACTORY, TAGS, SILENT)                          \
    SOLVABLE_TEST_(FACTORY, TAGS "[ic3sat-singleton]", SILENT,                  \
                   epic3::PIC3SatSingleton(                                     \
                       epic3::PIC3SatSingletonParameters(                       \
                           smt, nnlp, false, false),                            \
                       model, policy, {}, goal),                                \
                   " - singleton", false)

// Registers an end-to-end test on the simple counter model that counts to
// COUNT_TO + 1 under the safe policy. Explicit-state IC3, using the SAT engine
// built from `Params` and a greedy generalizer, must report that no path to
// variable 0 == count_to exists.
//
// Tags are appended by string-literal concatenation, not `#Tags`, because
// stringizing a string literal would keep its quote characters in the tag
// string. The test also carries "[simple_counter]", like the other tests on
// this model.
#define UNSOLVABLE_TEST_(Factory, Tags, Silent, Params, Name)                  \
    TEST_CASE(                                                                 \
        "Simple counter model - silent - " #Silent Name                        \
        " - doesn't have a plan: " #Factory,                                   \
        "[ic3][simple_counter]" Tags)                                          \
    {                                                                          \
        const int count_to = COUNT_TO + 1;                                     \
        const bool silent = Silent;                                            \
        auto model = simple_counter_model(count_to, silent);                   \
        auto goal = simple_counter_goal(count_to);                             \
        auto policy = get_safe_policy(count_to);                               \
                                                                               \
        std::shared_ptr<Z3SMTFactory> smt = std::make_shared<Z3SMTFactory>();  \
        auto nnlp = Factory();                                                 \
        auto sat = Params;                                                     \
                                                                               \
        auto gen = epic3::GreedyGeneralizer(                                   \
            &model.variables,                                                  \
            [](const auto& cube) {                                             \
                return !cube.has(0) ||                                         \
                       static_cast<int_t>(cube.get(0).ub) >= count_to;         \
            },                                                                 \
            {0, 1},                                                            \
            {true, true});                                                     \
                                                                               \
        auto start_smt = smt->make_shared();                                   \
        start_smt->add_variables(model.variables);                             \
        start_smt->add_constraint(                                             \
            expressions::Variable(0) == Value(0) &&                            \
            expressions::Variable(1) == Value(0));                             \
        auto start = epic3::StartGenerator(std::move(start_smt));              \
                                                                               \
        auto ic3 = epic3::ExplicitStateIC3smt<                                 \
            SuccessorGenerator,                                                \
            decltype(sat),                                                     \
            std::decay_t<decltype(gen)>>(                                      \
            &model.variables,                                                  \
            expressions::Variable(0) == Value(count_to),                       \
            start,                                                             \
            SuccessorGenerator(&policy.get_nn(), count_to),                    \
            sat,                                                               \
            gen,                                                               \
            false);                                                            \
        auto path = ic3();                                                     \
        REQUIRE(!path.has_value());                                            \
    }

// Instantiates UNSOLVABLE_TEST_ twice with the edge-individual PIC3 SAT
// engine. The two instances differ only in the third PIC3SatParameters
// argument: false (no name suffix) and true (" - app_filter" suffix).
#define UNSOLVABLE_TEST(FACTORY, TAGS, SILENT)                                  \
    UNSOLVABLE_TEST_(FACTORY, TAGS, SILENT,                                     \
                     epic3::PIC3SatEdgeIndividual(                              \
                         epic3::PIC3SatParameters(                              \
                             smt, nnlp, false, true, PRUNE_EDGES, false),       \
                         model, policy, {}, goal),                              \
                     "")                                                        \
    UNSOLVABLE_TEST_(FACTORY, TAGS, SILENT,                                     \
                     epic3::PIC3SatEdgeIndividual(                              \
                         epic3::PIC3SatParameters(                              \
                             smt, nnlp, true, true, PRUNE_EDGES, false),        \
                         model, policy, {}, goal),                              \
                     " - app_filter")

// Unsolvable test with the singleton PIC3 SAT engine. It carries the extra
// "[ic3sat-singleton]" tag and a " - singleton" name suffix.
#define UNSOLVABLE_TEST_SINGLETON(FACTORY, TAGS, SILENT)                        \
    UNSOLVABLE_TEST_(FACTORY, TAGS "[ic3sat-singleton]", SILENT,                \
                     epic3::PIC3SatSingleton(                                   \
                         epic3::PIC3SatSingletonParameters(                     \
                             smt, nnlp, false, false),                          \
                         model, policy, {}, goal),                              \
                     " - singleton")

// Full test matrix for one NNLP back end. Solvable and unsolvable tests are
// registered first with simple_counter_model's `silent` flag off, then on.
#define TESTS(FACTORY, TAGS, SKIP_CORE)                                         \
    SOLVABLE_TEST(FACTORY, TAGS, false, SKIP_CORE)                              \
    UNSOLVABLE_TEST(FACTORY, TAGS, false)                                       \
    SOLVABLE_TEST(FACTORY, TAGS, true, SKIP_CORE)                               \
    UNSOLVABLE_TEST(FACTORY, TAGS, true)

// The same matrix, built on the singleton PIC3 SAT engine.
#define TESTS_SINGLETON(FACTORY, TAGS)                                          \
    SOLVABLE_TEST_SINGLETON(FACTORY, TAGS, false)                               \
    UNSOLVABLE_TEST_SINGLETON(FACTORY, TAGS, false)                             \
    SOLVABLE_TEST_SINGLETON(FACTORY, TAGS, true)                                \
    UNSOLVABLE_TEST_SINGLETON(FACTORY, TAGS, true)

#if POLICE_Z3
namespace {
std::shared_ptr<NNLPSMTFactory> Z3_NNLP()
{
    static Z3SMTFactory z3;
    return std::make_shared<NNLPSMTFactory>(&z3, false);
}

std::shared_ptr<NNLPLPFactory> Z3_NNLP_LP()
{
    static Z3LPFactory z3;
    return std::make_shared<NNLPLPFactory>(&z3, false);
}
} // namespace

TESTS(Z3_NNLP, "[smt][z3]", false)
TESTS(Z3_NNLP_LP, "[lp][z3]", false)

#if POLICE_GUROBI
namespace {
std::shared_ptr<NNLPLPFactory> Gurobi_NNLP()
{
    static GurobiLPFactory gurobi;
    return std::make_shared<NNLPLPFactory>(&gurobi, false);
}
#if POLICE_MARABOU
std::shared_ptr<NNLPLPFactory> Gurobi_NNLP_Preprocessed()
{
    static GurobiLPFactory gurobi;
    return std::make_shared<NNLPLPFactory>(&gurobi, true);
}
#endif
} // namespace

TESTS(Gurobi_NNLP, "[lp][gurobi]", false)
TESTS_SINGLETON(Gurobi_NNLP, "[lp][gurobi]")

#if POLICE_MARABOU
TESTS(Gurobi_NNLP_Preprocessed, "[lp][gurobi][marabou]", true)
#endif

#endif

#if POLICE_MARABOU
namespace {
std::shared_ptr<NNLPFactory> Marabou_NNLP()
{
    static MarabouLPFactory factory;
    return std::make_shared<NNLPBranchNBoundFactory>(&factory);
}
} // namespace

TESTS(Marabou_NNLP, "[marabou]", true);
#endif
#endif

// End the scope of the test-generating helper macros (listed in definition
// order).
#undef GEN_TEST_
#undef SOLVABLE_TEST__
#undef SOLVABLE_TEST_
#undef SOLVABLE_TEST
#undef SOLVABLE_TEST_SINGLETON
#undef UNSOLVABLE_TEST_
#undef UNSOLVABLE_TEST
#undef UNSOLVABLE_TEST_SINGLETON
#undef TESTS
#undef TESTS_SINGLETON
