#include "../ffnn_generator.hpp"

#include "./models/binary_shifter.hpp"
#include "police/base_types.hpp"
#include "police/expressions/variable.hpp"
#include "police/lp_gurobi.hpp"
#include "police/lp_z3.hpp"
#include "police/nn_policy.hpp"
#include "police/nnlp_bnb.hpp"
#include "police/nnlp_lp.hpp"
#include "police/nnlp_smt.hpp"
#include "police/smt_z3.hpp"
#include "police/storage/ffnn.hpp"
#include "police/storage/flat_state.hpp"
#include "police/successor_generator/successor_generator.hpp"
#include "police/verifiers/epic3/engines/sat_based/epic3_smt.hpp"
#include "police/verifiers/epic3/engines/sat_based/generalizer/greedy_generalizer.hpp"
#include "police/verifiers/epic3/engines/sat_based/generalizer/unsat_core_generalizer.hpp"
#include "police/verifiers/epic3/engines/sat_based/pic3sat_label.hpp"
#include "police/verifiers/epic3/engines/sat_based/pic3sat_singleton.hpp"
#include "police/verifiers/epic3/epic3.hpp"
#include "verifiers/epic3/start_generator.hpp"

#include <algorithm>
#include <cassert>
#include <catch2/catch.hpp>
#include <iterator>
#include <memory>
#include <numeric>
#include <type_traits>

namespace {

using namespace police;

// Width of the binary shifter used by every generated test case. Marked
// [[maybe_unused]] because all of its uses are inside the POLICE_* guarded
// test instantiations at the end of the file.
[[maybe_unused]] constexpr unsigned BITS{8};

// Shape parameters forwarded to FFNNGenerator when searching for networks
// (passed as: inputs, ACTIONS, LAYERS, NEURONS, DENSITY).
constexpr unsigned NEURONS{8};
constexpr unsigned LAYERS{2};
constexpr unsigned ACTIONS{2};
constexpr double DENSITY{1.0};

bool shifts_forward(const FeedForwardNeuralNetwork<>& net, int bit, int bits)
{
    vector<real_t> state(bits, 0);
    state[bit] = 1;
    const auto out = net(state);
    return out[0] < out[1];
}

/// Returns true iff `net` shifts forward from every position 0 .. bit-1
/// (vacuously true when `bit <= 0`). Stops at the first position that does
/// not shift forward.
bool shifts_forward_to(const FeedForwardNeuralNetwork<>& net, int bit, int bits)
{
    bool all_forward = true;
    for (int pos = 0; all_forward && pos < bit; ++pos) {
        all_forward = shifts_forward(net, pos, bits);
    }
    return all_forward;
}

/// Steps an FFNNGenerator (configured with `bits` inputs and the file-level
/// shape constants) until `condition` accepts the current network, then
/// returns that network.
///
/// NOTE(review): the search has no bound; if the generator never yields an
/// accepted network this does not terminate.
template <typename Condition>
police::FeedForwardNeuralNetwork<> get_net(int bits, Condition condition)
{
    FFNNGenerator generator(bits, ACTIONS, LAYERS, NEURONS, DENSITY);
    while (!condition(generator.get())) {
        generator.next();
    }
    return generator.get();
}

/// Returns a generated network for which the shifter, started at bit 0,
/// never sets the last bit.
///
/// Positions 0 .. bits-3 shift forward, while position bits-2 does not.
/// The set bit therefore never moves from bits-2 to bits-1.
///
/// The result for the most recent `bits` value is cached in function-local
/// statics, so repeated calls skip the (unbounded) search in get_net. The
/// cache key starts at 0, so `bits == 0` returns a default-constructed
/// network.
[[maybe_unused]]
police::FeedForwardNeuralNetwork<> get_safe_net(int bits)
{
    static int cached = 0;
    static FeedForwardNeuralNetwork<> net;
    if (cached != bits) {
        // The predicate's parameter is named `candidate` so it does not
        // shadow the cached static `net`.
        net = get_net(bits, [bits](const auto& candidate) {
            return shifts_forward_to(candidate, bits - 2, bits) &&
                   !shifts_forward(candidate, bits - 2, bits);
        });
        // Update the key only after the search succeeded, so an exception
        // from get_net cannot leave a key that does not match `net`.
        cached = bits;
    }
    return net;
}

/// Returns a generated network under which the shifter, started at bit 0,
/// reaches the last bit: every position 0 .. bits-2 shifts forward.
///
/// The result for the most recent `bits` value is cached in function-local
/// statics, so repeated calls skip the (unbounded) search in get_net. The
/// cache key starts at 0, so `bits == 0` returns a default-constructed
/// network.
[[maybe_unused]]
police::FeedForwardNeuralNetwork<> get_unsafe_net(int bits)
{
    static int cached = 0;
    static FeedForwardNeuralNetwork<> net;
    if (cached != bits) {
        // The predicate's parameter is named `candidate` so it does not
        // shadow the cached static `net`.
        net = get_net(bits, [bits](const auto& candidate) {
            return shifts_forward_to(candidate, bits - 1, bits);
        });
        // Update the key only after the search succeeded, so an exception
        // from get_net cannot leave a key that does not match `net`.
        cached = bits;
    }
    return net;
}

/// Wraps `net` into a NeuralNetworkPolicy.
///
/// - Input variable list: 0 .. n-1, where n is the number of network inputs.
/// - Output list: {0, 1}.
/// - Trailing constant: 2, passed through unchanged. Presumably this is the
///   number of actions; TODO confirm against NeuralNetworkPolicy.
[[maybe_unused]]
police::NeuralNetworkPolicy<> to_policy(FeedForwardNeuralNetwork<>&& net)
{
    // The input count must be read before `net` is moved into the policy.
    const auto num_inputs = net.input.size();
    vector<police::size_t> input_vars(num_inputs);
    std::iota(input_vars.begin(), input_vars.end(), 0);
    return {std::move(net), std::move(input_vars), {0u, 1u}, 2};
}

/// Explicit-state successor function for the binary shifter.
///
/// The first `bits` entries of a state are the shifter's bit vector; the
/// current bit is the first of those entries equal to 1. The network `net`
/// is evaluated on these entries:
/// - If out[0] < out[1], the set bit moves one position up (action 0).
/// - Otherwise it moves one position down (action 1).
///
/// Moving past either end yields no successor. All other state entries are
/// copied unchanged.
///
/// NOTE(review): the action labels presumably match the action order of
/// binary_shifter_model; confirm there.
struct SuccessorGenerator {
    vector<successor_generator::Successor>
    operator()(const flat_state& state) const
    {
        flat_state succ(state);
        const auto first = state.begin();
        const auto last = state.begin() + bits;

        vector<real_t> in;
        std::transform(first, last, std::back_inserter(in), [](const auto& val) {
            return static_cast<real_t>(val);
        });

        const auto bit = std::distance(first, std::find(first, last, Value(1)));
        assert(bit < bits);
        if (bit >= bits) {
            // No set bit among the shifter entries, so this is not a valid
            // shifter state. Without this guard, builds with NDEBUG would
            // write to succ[bits] / succ[bits + 1] below.
            return {};
        }

        const auto out = (*net)(in);
        police::size_t action = 0;
        if (out[0] < out[1]) {
            if (bit == bits - 1) {
                return {};
            }
            succ[bit] = Value(0);
            succ[bit + 1] = Value(1);
        } else {
            if (bit == 0) {
                return {};
            }
            succ[bit] = Value(0);
            succ[bit - 1] = Value(1);
            action = 1;
        }
        return {{std::move(succ), action}};
    }

    /// Network that decides the shift direction (not owned).
    const FeedForwardNeuralNetwork<>* net;
    /// Number of leading state entries that encode the shifter bits.
    int bits;
};

/// Builds the conjunction
///   Variable(0) == [bit == 0] && ... && Variable(num_bits - 1) == [bit == num_bits - 1]
/// where each right-hand side is an int Value 0 or 1. This describes the
/// one-hot shifter state with only `bit` set.
///
/// All variables, including the last one, are constrained uniformly. So
/// bit == num_bits - 1 yields the state with the last bit set rather than
/// the all-zero vector. If bit >= num_bits, every variable is constrained
/// to 0.
///
/// Requires num_bits > 0.
[[maybe_unused]]
expressions::Expression state_expression(unsigned bit, unsigned num_bits)
{
    assert(num_bits > 0u);
    expressions::Expression result =
        expressions::Variable(0) == Value(static_cast<int>(bit == 0u));
    for (unsigned var = 1; var < num_bits; ++var) {
        result = result && expressions::Variable(var) ==
                               Value(static_cast<int>(bit == var));
    }
    return result;
}

/// Returns the index sequence 0, 1, ..., num - 1.
[[maybe_unused]]
vector<police::size_t> varsequence(unsigned num)
{
    vector<police::size_t> indices(num);
    for (unsigned idx = 0; idx < num; ++idx) {
        indices[idx] = idx;
    }
    return indices;
}

} // namespace

// Defines one Catch2 TEST_CASE that runs epic3::ExplicitStateIC3smt on the
// binary shifter model.
//
// Parameters:
//   TestName  - string literal; the full test name is
//               TestName ": " #Factory " - " #Generator.
//   Tags      - string literal appended to "[ic3][binary_shifter]".
//   GenNet    - callable taking Bits and returning the network that is
//               wrapped by to_policy (e.g. get_safe_net / get_unsafe_net).
//   Bits      - number of shifter bits.
//   Factory   - nullary callable producing the NNLP factory (local `nnlp`).
//   Verify    - callable invoked with the value returned by ic3().
//   Params    - expression producing the SAT-based engine (local `sat`). It
//               may refer to the locals `smt`, `nnlp`, `model`, `policy` and
//               `goal` declared in the generated test body.
//   Generator - generalizer type, constructed from &model.variables and a
//               cube predicate.
//
// The start states are given by state_expression(0, Bits), solved with a Z3
// SMT instance. The engine's target condition is
// `Variable(Bits - 1) == true`.
//
// The cube predicate accepts cubes that either do not constrain variable
// Bits - 1 or whose upper bound on it is >= 1.
// NOTE(review): the precise contract of that predicate, and of the trailing
// `false` engine argument, is defined in the generalizer/engine headers;
// confirm there.
#define CREATE_TEST_(                                                          \
    TestName,                                                                  \
    Tags,                                                                      \
    GenNet,                                                                    \
    Bits,                                                                      \
    Factory,                                                                   \
    Verify,                                                                    \
    Params,                                                                    \
    Generator)                                                                 \
    TEST_CASE(                                                                 \
        TestName ": " #Factory " - " #Generator,                               \
        "[ic3][binary_shifter]" Tags)                                          \
    {                                                                          \
        auto model = binary_shifter_model(Bits);                               \
        auto goal = binary_shifter_goal(Bits);                                 \
        auto policy = to_policy(GenNet(Bits));                                 \
                                                                               \
        std::shared_ptr<Z3SMTFactory> smt = std::make_shared<Z3SMTFactory>();  \
        auto nnlp = Factory();                                                 \
        auto sat = Params;                                                     \
                                                                               \
        auto gen = Generator(&model.variables, [](const auto& cube) {          \
            return !cube.has(Bits - 1) ||                                      \
                   static_cast<int_t>(cube.get(Bits - 1).ub) >= 1;             \
        });                                                                    \
                                                                               \
        auto start_smt = smt->make_shared();                                   \
        start_smt->add_variables(model.variables);                             \
        start_smt->add_constraint(state_expression(0, Bits));                  \
        auto start = epic3::StartGenerator(std::move(start_smt));              \
                                                                               \
        auto ic3 = epic3::ExplicitStateIC3smt<                                 \
            SuccessorGenerator,                                                \
            decltype(sat),                                                     \
            std::decay_t<decltype(gen)>>(                                      \
            &model.variables,                                                  \
            expressions::Variable(Bits - 1) == Value(true),                    \
            std::move(start),                                                  \
            SuccessorGenerator(&policy.get_nn(), Bits),                        \
            std::move(sat),                                                    \
            std::move(gen),                                                    \
            false);                                                            \
        auto path = ic3();                                                     \
        Verify(path);                                                          \
    }

// Defines two test cases via CREATE_TEST_, both using the
// epic3::PIC3SatEdgeIndividual engine with the given Generator.
//
// They differ only in the third argument of epic3::PIC3SatParameters
// (false vs. true). The second test's name gets the suffix " (app-filter)".
// The remaining engine arguments are the locals `smt`, `nnlp`, `model`,
// `policy` and `goal` declared inside CREATE_TEST_'s test body.
#define CREATE_TEST(TestName, Tags, GenNet, Bits, Factory, Verify, Generator)  \
    CREATE_TEST_(                                                              \
        TestName,                                                              \
        Tags,                                                                  \
        GenNet,                                                                \
        Bits,                                                                  \
        Factory,                                                               \
        Verify,                                                                \
        epic3::PIC3SatEdgeIndividual(                                          \
            epic3::PIC3SatParameters(smt, nnlp, false, true, false, false),    \
            model,                                                             \
            policy,                                                            \
            {},                                                                \
            goal),                                                             \
        Generator)                                                             \
                                                                               \
    CREATE_TEST_(                                                              \
        TestName " (app-filter)",                                              \
        Tags,                                                                  \
        GenNet,                                                                \
        Bits,                                                                  \
        Factory,                                                               \
        Verify,                                                                \
        epic3::PIC3SatEdgeIndividual(                                          \
            epic3::PIC3SatParameters(smt, nnlp, true, true, false, false),     \
            model,                                                             \
            policy,                                                            \
            {},                                                                \
            goal),                                                             \
        Generator)

// Defines two test cases via CREATE_TEST_, both using the
// epic3::PIC3SatSingleton engine with epic3::UnsatCoreGeneralizer.
//
// They differ only in the third argument of
// epic3::PIC3SatSingletonParameters (false vs. true). The test names get the
// suffixes " - singleton" and " - singleton (app-filter)", and
// "[ic3sat-singleton]" is appended to Tags.
#define CREATE_TEST_SINGLETON(TestName, Tags, GenNet, Bits, Factory, Verify)   \
    CREATE_TEST_(                                                              \
        TestName " - singleton",                                               \
        Tags "[ic3sat-singleton]",                                             \
        GenNet,                                                                \
        Bits,                                                                  \
        Factory,                                                               \
        Verify,                                                                \
        epic3::PIC3SatSingleton(                                               \
            epic3::PIC3SatSingletonParameters(smt, nnlp, false, false),        \
            model,                                                             \
            policy,                                                            \
            {},                                                                \
            goal),                                                             \
        epic3::UnsatCoreGeneralizer)                                           \
                                                                               \
    CREATE_TEST_(                                                              \
        TestName " - singleton (app-filter)",                                  \
        Tags "[ic3sat-singleton]",                                             \
        GenNet,                                                                \
        Bits,                                                                  \
        Factory,                                                               \
        Verify,                                                                \
        epic3::PIC3SatSingleton(                                               \
            epic3::PIC3SatSingletonParameters(smt, nnlp, true, false),         \
            model,                                                             \
            policy,                                                            \
            {},                                                                \
            goal),                                                             \
        epic3::UnsatCoreGeneralizer)

// Edge-engine tests on a BITS-bit shifter driven by get_unsafe_net; they
// require that the engine returns a path.
#define SOLVABLE_TEST(Factory, Tags, Generalizer)                              \
    CREATE_TEST(                                                               \
        "Binary shifter model has a plan",                                     \
        Tags,                                                                  \
        get_unsafe_net,                                                        \
        BITS,                                                                  \
        Factory,                                                               \
        [](const auto& path) { REQUIRE(path.has_value()); },                   \
        Generalizer)

// Edge-engine tests on a BITS-bit shifter driven by get_safe_net; they
// require that the engine returns no path.
#define UNSOLVABLE_TEST(Factory, Tags, Generalizer)                            \
    CREATE_TEST(                                                               \
        "Binary shifter model does not have a plan",                           \
        Tags,                                                                  \
        get_safe_net,                                                          \
        (BITS),                                                                \
        Factory,                                                               \
        [](const auto& path) { REQUIRE(!path.has_value()); },                  \
        Generalizer)

// Singleton-engine counterpart of SOLVABLE_TEST.
#define SOLVABLE_TEST_SINGLETON(Factory, Tags)                                 \
    CREATE_TEST_SINGLETON(                                                     \
        "Binary shifter model has a plan",                                     \
        Tags,                                                                  \
        get_unsafe_net,                                                        \
        BITS,                                                                  \
        Factory,                                                               \
        [](const auto& path) { REQUIRE(path.has_value()); })

// Singleton-engine counterpart of UNSOLVABLE_TEST.
#define UNSOLVABLE_TEST_SINGLETON(Factory, Tags)                               \
    CREATE_TEST_SINGLETON(                                                     \
        "Binary shifter model does not have a plan",                           \
        Tags,                                                                  \
        get_safe_net,                                                          \
        (BITS),                                                                \
        Factory,                                                               \
        [](const auto& path) { REQUIRE(!path.has_value()); })

// Solvable and unsolvable edge-engine tests for one NNLP factory and one
// generalizer.
#define TESTS(Factory, Tags, Generalizer)                                      \
    SOLVABLE_TEST(Factory, Tags, Generalizer)                                  \
    UNSOLVABLE_TEST(Factory, Tags, Generalizer)

// TESTS with the unsat-core generalizer.
#define TESTS_UC(Factory, Tags)                                                \
    TESTS(Factory, Tags, epic3::UnsatCoreGeneralizer)

// TESTS with the greedy generalizer.
#define TESTS_G(Factory, Tags) TESTS(Factory, Tags, epic3::GreedyGeneralizer)

// TESTS with both the greedy and the unsat-core generalizer.
#define TESTS_x(Factory, Tags) TESTS_G(Factory, Tags) TESTS_UC(Factory, Tags)

// Solvable and unsolvable singleton-engine tests for one NNLP factory.
#define TESTS_SINGLETON(Factory, Tags)                                         \
    SOLVABLE_TEST_SINGLETON(Factory, Tags)                                     \
    UNSOLVABLE_TEST_SINGLETON(Factory, Tags)

#if POLICE_Z3
namespace {
std::shared_ptr<NNLPSMTFactory> Z3_NNLP()
{
    static Z3SMTFactory z3;
    return std::make_shared<NNLPSMTFactory>(&z3, false);
}

std::shared_ptr<NNLPLPFactory> Z3_NNLP_LP()
{
    static Z3LPFactory z3;
    return std::make_shared<NNLPLPFactory>(&z3, false);
}
} // namespace

// Z3-backed instantiations:
// - SMT- and LP-based NNLP, each with both the greedy and unsat-core
//   generalizers.
// - The singleton engine on the LP-based NNLP.
TESTS_x(Z3_NNLP, "[smt][z3]")
TESTS_x(Z3_NNLP_LP, "[lp][z3]")
TESTS_SINGLETON(Z3_NNLP_LP, "[lp][z3]")

#if POLICE_GUROBI
namespace {

/// NNLP factory over a Gurobi LP backend (second NNLPLPFactory argument
/// `false`). The backend is a function-local static, so the pointer handed
/// to the returned factory stays valid.
std::shared_ptr<NNLPLPFactory> Gurobi_NNLP()
{
    static GurobiLPFactory lp_backend;
    return std::make_shared<NNLPLPFactory>(&lp_backend, false);
}

#if POLICE_MARABOU
/// Like Gurobi_NNLP, but with the second NNLPLPFactory argument set to
/// `true`. Presumably this enables preprocessing, as the name suggests;
/// TODO confirm in NNLPLPFactory.
std::shared_ptr<NNLPLPFactory> Gurobi_NNLP_Preprocessed()
{
    static GurobiLPFactory lp_backend;
    return std::make_shared<NNLPLPFactory>(&lp_backend, true);
}
#endif

} // namespace

TESTS_x(Gurobi_NNLP, "[lp][gurobi]")
TESTS_SINGLETON(Gurobi_NNLP, "[lp][gurobi]")

#if POLICE_MARABOU
TESTS_G(Gurobi_NNLP_Preprocessed, "[lp][gurobi][marabou]")
#endif
#endif

#if POLICE_MARABOU
namespace {

/// Branch-and-bound NNLP factory (NNLPBranchNBoundFactory) over a
/// MarabouLPFactory. The backend is a function-local static, so the pointer
/// handed to the returned factory stays valid.
std::shared_ptr<NNLPFactory> Marabou_NNLP()
{
    static MarabouLPFactory lp_backend;
    return std::make_shared<NNLPBranchNBoundFactory>(&lp_backend);
}

} // namespace

// No trailing ';'. The macro expands to complete TEST_CASE definitions, and
// an extra ';' at namespace scope is an empty declaration that
// -Wpedantic / -Wextra-semi warn about.
TESTS_G(Marabou_NNLP, "[marabou]")
#endif
#endif

// Drop every helper macro defined above so none of them leaks past this
// translation unit's tests (e.g. in unity/jumbo builds). Listed in
// definition order.
#undef CREATE_TEST_
#undef CREATE_TEST
#undef CREATE_TEST_SINGLETON
#undef SOLVABLE_TEST
#undef UNSOLVABLE_TEST
#undef SOLVABLE_TEST_SINGLETON
#undef UNSOLVABLE_TEST_SINGLETON
#undef TESTS
#undef TESTS_UC
#undef TESTS_G
#undef TESTS_x
#undef TESTS_SINGLETON
