#include "../ffnn_generator.hpp"
#include "./models/binary_flip.hpp"

#include "police/base_types.hpp"
#include "police/lp_gurobi.hpp"
#include "police/lp_z3.hpp"
#include "police/nn_policy.hpp"
#include "police/nnlp_bnb.hpp"
#include "police/nnlp_lp.hpp"
#include "police/nnlp_smt.hpp"
#include "police/smt_z3.hpp"
#include "police/storage/ffnn.hpp"
#include "police/storage/unordered_set.hpp"
#include "police/storage/vector.hpp"
#include "police/utils/ranges.hpp"
#include "police/verifiers/epic3/engines/sat_based/epic3_smt.hpp"
#include "police/verifiers/epic3/engines/sat_based/generalizer/greedy_generalizer.hpp"
#include "police/verifiers/epic3/engines/sat_based/generalizer/unsat_core_generalizer.hpp"
#include "police/verifiers/epic3/engines/sat_based/pic3sat_label.hpp"
#include "police/verifiers/epic3/engines/sat_based/pic3sat_singleton.hpp"
#include "police/verifiers/epic3/start_generator.hpp"

#include <algorithm>
#include <cassert>
#include <catch2/catch.hpp>
#include <iterator>
#include <memory>
#include <numeric>
#include <optional>
#include <type_traits>
#include <utility>

namespace {
using namespace police;

// Width (number of bits) of the binary-flip model used by the generated test
// cases. Only referenced from the test macros, which may all be disabled by
// the back-end feature flags, hence [[maybe_unused]].
[[maybe_unused]] constexpr unsigned BITS{5};

// Architecture of the randomly generated candidate networks: number of
// layers, neurons per layer and connection density handed to FFNNGenerator.
constexpr unsigned LAYERS{2};
constexpr unsigned NEURONS{8};
constexpr double DENSITY{1.0};

/// Enumerates networks from an FFNNGenerator (`bits` inputs, `2 * bits`
/// outputs, LAYERS x NEURONS hidden structure, DENSITY) and returns the first
/// one for which `condition` holds. Does not terminate if no generated network
/// ever satisfies `condition`.
template <typename Condition>
police::FeedForwardNeuralNetwork<> get_net(int bits, Condition condition)
{
    FFNNGenerator generator(bits, 2 * bits, LAYERS, NEURONS, DENSITY);
    while (!condition(generator.get())) {
        generator.next();
    }
    return generator.get();
}

[[maybe_unused]]
bool flips_correct_bits_filtered(
    const FeedForwardNeuralNetwork<>& net,
    vector<real_t> state,
    int n)
{
    for (unsigned i = n; i < (state.size() + 1) / 2; ++i) {
        const auto output = net(state);
        vector<std::pair<double, unsigned>> ranked;
        for (unsigned j = 0; j < output.size(); ++j) {
            ranked.emplace_back(-output[j], j);
        }
        std::sort(ranked.begin(), ranked.end());
        for (const auto& [_, a] : ranked) {
            if (a % 2) {
                unsigned b = (a - 1) / 2;
                if (state[b]) {
                    return false;
                }
            } else {
                unsigned b = a / 2;
                if (!state[b]) {
                    if (b % 2) return false;
                    state[b] = 1;
                    break;
                }
            }
        }
    }
    return true;
}

/// Simulates the policy `net` without filtering: each step applies the
/// arg-max action. Action `a` even means "set bit a/2"; `a` odd means "clear
/// bit (a-1)/2".
///
/// @param state start state (entry > 0 = bit set); copied and mutated.
/// @param n     first step index; the loop runs from n up to
///              (state.size() + 1) / 2.
/// @return true iff every step sets a currently-unset even-indexed bit;
///         false on the first step that does anything else.
[[maybe_unused]]
bool flips_correct_bits(
    const FeedForwardNeuralNetwork<>& net,
    vector<real_t> state,
    int n)
{
    const auto steps = (state.size() + 1) / 2;
    for (unsigned step = n; step < steps; ++step) {
        const auto scores = net(state);
        const unsigned action = std::distance(
            scores.begin(),
            std::max_element(scores.begin(), scores.end()));
        const unsigned bit = action / 2;
        const bool is_set_action = action % 2 == 0;
        // Acceptable only: a "set" action on an even bit that is still unset
        // (equivalently, action % 4 == 0 and the bit is clear).
        if (!is_set_action || bit % 2 != 0 || state[bit] > 0.) {
            return false;
        }
        state[bit] = 1.;
    }
    return true;
}

template <typename Cond>
bool flips_correct_bits2(
    const FeedForwardNeuralNetwork<>& net,
    int bits,
    Cond&& cond)
{
    vector<real_t> state(bits, 0);
    if (!cond(net, state, 0)) {
        return false;
    }
    for (int i = 0; i < bits; i += 2) {
        state[i] = 1;
        if (!cond(net, state, 1)) {
            return false;
        }
        state[i] = 0;
    }
    return true;
}

/// Returns the first generated `bits`-bit network accepted by
/// flips_correct_bits (unfiltered) on every probe state of
/// flips_correct_bits2.
///
/// The result is memoized in a single-entry cache keyed by `bits`. The key is
/// recorded only after the search has produced a network, so an exception
/// escaping get_net cannot leave a stale/default net cached. A `std::optional`
/// key avoids confusing "nothing cached yet" with `bits == 0`.
[[maybe_unused]]
police::FeedForwardNeuralNetwork<> get_safe_net(int bits)
{
    static std::optional<int> cached_bits;
    static FeedForwardNeuralNetwork<> net;
    if (cached_bits != bits) {
        net = get_net(bits, [bits](const auto& candidate) {
            return flips_correct_bits2(candidate, bits, flips_correct_bits);
        });
        cached_bits = bits;
    }
    return net;
}

/// Returns the first generated `bits`-bit network accepted by
/// flips_correct_bits_filtered on every probe state of flips_correct_bits2.
///
/// The result is memoized in a single-entry cache keyed by `bits`. The key is
/// recorded only after the search has produced a network, so an exception
/// escaping get_net cannot leave a stale/default net cached. A `std::optional`
/// key avoids confusing "nothing cached yet" with `bits == 0`.
[[maybe_unused]]
police::FeedForwardNeuralNetwork<> get_safe_net_filtered(int bits)
{
    static std::optional<int> cached_bits;
    static FeedForwardNeuralNetwork<> net;
    if (cached_bits != bits) {
        net = get_net(bits, [bits](const auto& candidate) {
            return flips_correct_bits2(
                candidate,
                bits,
                flips_correct_bits_filtered);
        });
        cached_bits = bits;
    }
    return net;
}

[[maybe_unused]]
bool flips_incorrect_bits_filtered(
    const FeedForwardNeuralNetwork<>& net,
    vector<real_t> state,
    int n)
{
    unordered_set<vector<bool>> trace;
    auto to_bool = [](const auto& state) {
        vector<bool> vec(state.size());
        std::transform(
            state.begin(),
            state.end(),
            vec.begin(),
            [](const real_t val) { return val > 0.; });
        return vec;
    };
    for (unsigned i = n; i < (state.size() + 1) / 2; ++i) {
        if (!trace.insert(to_bool(state)).second) {
            return false;
        }
        const auto output = net(state);
        vector<std::pair<double, unsigned>> ranked;
        for (unsigned j = 0; j < output.size(); ++j) {
            ranked.emplace_back(-output[j], j);
        }
        std::sort(ranked.begin(), ranked.end());
        for (const auto& [_, a] : ranked) {
            if (a % 2) {
                unsigned b = (a - 1) / 2;
                if (state[b]) {
                    if (b + 1 < state.size()) {
                        return true;
                    }
                    state[b] = 0.;
                    break;
                }
            } else {
                unsigned b = a / 2;
                if (!state[b]) {
                    if (b % 2) {
                        return true;
                    }
                    state[b] = 1;
                    break;
                }
            }
        }
    }
    return false;
}

/// Simulates the policy `net` without filtering (the arg-max action is applied
/// in every step) and reports whether it performs a wrong flip.
/// Action `a` even means "set bit a/2"; `a` odd means "clear bit (a-1)/2".
///
/// A wrong flip is setting an odd-indexed bit, or clearing any bit other than
/// the last one. Correct flips (setting an even bit, clearing the last bit)
/// are applied and the simulation continues with the next step.
///
/// @param state start state (non-zero entry = bit set); copied and mutated.
/// @param n     first step index; the loop runs from n up to
///              (state.size() + 1) / 2.
/// @return true iff a wrong flip is reached; false if the arg-max action is
///         not applicable, a state repeats, or the steps run out.
[[maybe_unused]]
bool flips_incorrect_bits(
    const FeedForwardNeuralNetwork<>& net,
    vector<real_t> state,
    int n)
{
    unordered_set<vector<bool>> trace;
    auto to_bool = [](const auto& state) {
        vector<bool> vec(state.size());
        std::transform(
            state.begin(),
            state.end(),
            vec.begin(),
            [](const real_t val) { return val > 0.; });
        return vec;
    };
    for (unsigned i = n; i < (state.size() + 1) / 2; ++i) {
        if (!trace.insert(to_bool(state)).second) {
            return false; // revisited a state: the policy is looping
        }
        const auto output = net(state);
        auto it = std::max_element(output.begin(), output.end());
        unsigned a = std::distance(output.begin(), it);
        if (a % 2) {
            unsigned b = (a - 1) / 2;
            if (!state[b]) {
                return false; // "clear" action not applicable
            }
            if (b + 1 < state.size()) {
                return true; // clearing a non-last bit
            }
            state[b] = 0.;
        } else {
            unsigned b = a / 2;
            if (state[b]) {
                return false; // "set" action not applicable
            }
            if (b % 2) {
                return true; // setting an odd-indexed bit
            }
            state[b] = 1;
            // No `break` here: this is a single-action step (unlike the
            // filtered variant's ranking loop), so a `break` would end the
            // whole simulation after the first correct flip.
        }
    }
    return false;
}

template <typename Cond>
bool flips_incorrect_bits2(
    const FeedForwardNeuralNetwork<>& net,
    int bits,
    Cond&& cond)
{
    vector<real_t> state(bits, 0);
    if (cond(net, state, 0)) {
        return true;
    }
    for (int i = 0; i < bits; i += 2) {
        state[i] = 1;
        if (cond(net, state, 1)) {
            return true;
        }
        state[i] = 0;
    }
    return false;
}

/// Returns the first generated `bits`-bit network for which flips_incorrect_bits
/// (unfiltered) reports a wrong flip on some probe state of
/// flips_incorrect_bits2.
///
/// The result is memoized in a single-entry cache keyed by `bits`. The key is
/// recorded only after the search has produced a network, so an exception
/// escaping get_net cannot leave a stale/default net cached. A `std::optional`
/// key avoids confusing "nothing cached yet" with `bits == 0`.
[[maybe_unused]]
police::FeedForwardNeuralNetwork<> get_unsafe_net(int bits)
{
    static std::optional<int> cached_bits;
    static FeedForwardNeuralNetwork<> net;
    if (cached_bits != bits) {
        net = get_net(bits, [bits](const auto& candidate) {
            return flips_incorrect_bits2(candidate, bits, flips_incorrect_bits);
        });
        cached_bits = bits;
    }
    return net;
}

/// Returns the first generated `bits`-bit network for which
/// flips_incorrect_bits_filtered reports a wrong flip on some probe state of
/// flips_incorrect_bits2.
///
/// The result is memoized in a single-entry cache keyed by `bits`. The key is
/// recorded only after the search has produced a network, so an exception
/// escaping get_net cannot leave a stale/default net cached. A `std::optional`
/// key avoids confusing "nothing cached yet" with `bits == 0`.
[[maybe_unused]]
police::FeedForwardNeuralNetwork<> get_unsafe_net_filtered(int bits)
{
    static std::optional<int> cached_bits;
    static FeedForwardNeuralNetwork<> net;
    if (cached_bits != bits) {
        net = get_net(bits, [bits](const auto& candidate) {
            return flips_incorrect_bits2(
                candidate,
                bits,
                flips_incorrect_bits_filtered);
        });
        cached_bits = bits;
    }
    return net;
}

/// Wraps `net` into a NeuralNetworkPolicy using identity index lists
/// 0..k-1 for both its inputs and its outputs, with one action per output.
/// Asserts (debug builds only) that the network has twice as many outputs as
/// inputs.
/// NOTE(review): presumably input i <-> state variable i and output j <->
/// action j; confirm against NeuralNetworkPolicy.
[[maybe_unused]]
police::NeuralNetworkPolicy<> to_policy(FeedForwardNeuralNetwork<>&& net)
{
    const auto num_inputs = net.input.size();
    const auto num_outputs = net.layers.back().size();
    assert(num_inputs * 2 == num_outputs);

    vector<police::size_t> input(num_inputs);
    vector<police::size_t> output(num_outputs);
    std::iota(input.begin(), input.end(), 0);
    std::iota(output.begin(), output.end(), 0);

    const police::size_t num_actions = num_outputs;
    return {std::move(net), std::move(input), std::move(output), num_actions};
}

} // namespace

// Defines one Catch2 TEST_CASE. It builds the binary-flip model with `Bits`
// bits, wraps the network returned by `GenNet(Bits)` into a policy, runs the
// explicit-state IC3 engine (ExplicitStateIC3smt) with a
// BFlipSuccessorGenerator, and passes the result of `ic3()` to `Verify`.
//
// The test name is  TestName ": " #Factory " - " #Generalizer  and the tags are
// "[ic3][binary_flip]" followed by `Tags`.
//
// The macro is deliberately unhygienic. `Params` is pasted verbatim into the
// test body, and the callers in this file refer to the locals `smt`, `nnlp`,
// `model`, `policy`, `not_goal` and `avoid` defined below. Renaming any of
// those locals breaks every caller.
//
//   TestName    - string literal (may be several adjacent literals)
//   Tags        - extra Catch2 tag literal(s)
//   GenNet      - callable taking the bit count, returning a network
//   Bits        - number of bits of the model
//   Factory     - callable whose result is stored as `nnlp`
//   Verify      - callable invoked with the value returned by `ic3()`
//   Params      - expression stored as `sat`; its type becomes the second
//                 template argument of ExplicitStateIC3smt
//   AppFilter   - bool forwarded to BFlipSuccessorGenerator (presumably
//                 toggles applicability filtering -- confirm)
//   Generalizer - type constructed from (&model.variables, BFlipCheckAvoid())
//
// Note: `goal` is computed but not referenced by this macro's own body.
#define CREATE_TEST_INTERNAL(                                                  \
    TestName,                                                                  \
    Tags,                                                                      \
    GenNet,                                                                    \
    Bits,                                                                      \
    Factory,                                                                   \
    Verify,                                                                    \
    Params,                                                                    \
    AppFilter,                                                                 \
    Generalizer)                                                               \
    TEST_CASE(                                                                 \
        TestName ": " #Factory " - " #Generalizer,                             \
        "[ic3][binary_flip]" Tags)                                             \
    {                                                                          \
        auto model = binary_flip_model(Bits);                                  \
        auto init = binary_flip_initial_state(Bits);                           \
        auto goal = binary_flip_goal(Bits);                                    \
        auto not_goal = binary_flip_negated_goal(Bits);                        \
        auto avoid = binary_flip_avoid(Bits);                                  \
        auto policy = to_policy(GenNet(Bits));                                 \
                                                                               \
        std::shared_ptr<Z3SMTFactory> smt = std::make_shared<Z3SMTFactory>();  \
        auto nnlp = Factory();                                                 \
        auto sat = Params;                                                     \
                                                                               \
        auto gen = Generalizer(&model.variables, BFlipCheckAvoid());           \
                                                                               \
        auto start_smt = smt->make_shared();                                   \
        start_smt->add_variables(model.variables);                             \
        start_smt->add_constraint(init.as_expression());                       \
        auto start = epic3::StartGenerator(std::move(start_smt));              \
                                                                               \
        auto ic3 = epic3::ExplicitStateIC3smt<                                 \
            BFlipSuccessorGenerator,                                           \
            decltype(sat),                                                     \
            std::decay_t<decltype(gen)>>(                                      \
            &model.variables,                                                  \
            avoid.as_expression(),                                             \
            std::move(start),                                                  \
            BFlipSuccessorGenerator(&policy.get_nn(), Bits, AppFilter),        \
            std::move(sat),                                                    \
            std::move(gen),                                                    \
            false);                                                            \
                                                                               \
        auto path = ic3();                                                     \
        Verify(path);                                                          \
    }

// Defines a pair of test cases for one engine configuration, both on a
// BITS-bit model:
//  * "Binary flip model has a plan" + Suffix runs on the network produced by
//    `U` and REQUIREs that IC3 returns a path.
//  * "Binary flip model doesn't have a plan" + Suffix runs on the network
//    produced by `S` and REQUIREs that it returns none. If a path is found,
//    its states and actions are printed before the assertion fails.
// Factory, Tags, Params, AppFilter and Generalizer are forwarded unchanged to
// CREATE_TEST_INTERNAL (see that macro for the locals `Params` may use).
#define CREATE_TESTS_INTERNAL(                                                 \
    Factory,                                                                   \
    Tags,                                                                      \
    U,                                                                         \
    S,                                                                         \
    Params,                                                                    \
    Suffix,                                                                    \
    AppFilter,                                                                 \
    Generalizer)                                                               \
    CREATE_TEST_INTERNAL(                                                      \
        "Binary flip model has a plan" Suffix,                                 \
        Tags,                                                                  \
        U,                                                                     \
        (BITS),                                                                \
        Factory,                                                               \
        [](const auto& path) { REQUIRE(path.has_value()); },                   \
        Params,                                                                \
        AppFilter,                                                             \
        Generalizer)                                                           \
                                                                               \
    CREATE_TEST_INTERNAL(                                                      \
        "Binary flip model doesn't have a plan" Suffix,                        \
        Tags,                                                                  \
        S,                                                                     \
        (BITS),                                                                \
        Factory,                                                               \
        [](const auto& path) {                                                 \
            if (path.has_value()) {                                            \
                for (const auto& [state, label] : path.value()) {              \
                    ranges::printer()(state)                                   \
                        << "\n action: " << label << std::endl;                \
                }                                                              \
            }                                                                  \
            REQUIRE(!path.has_value());                                        \
        },                                                                     \
        Params,                                                                \
        AppFilter,                                                             \
        Generalizer)

// Instantiates CREATE_TESTS_INTERNAL for the PIC3SatEdgeIndividual engine in
// four configurations. They differ in the first two boolean arguments of
// PIC3SatParameters, the test-name suffix, the AppFilter flag and the network
// generators:
//   (false, true )  ""                               get_*_net,          AppFilter=false
//   (false, false)  " (no-determinize)"              get_*_net,          AppFilter=false
//   (true,  true )  " (app-filter)"                  get_*_net_filtered, AppFilter=true
//   (true,  false)  " (app-filter, no-determinize)"  get_*_net_filtered, AppFilter=true
// NOTE(review): the meaning of the two flags is inferred from the suffixes
// only; confirm against PIC3SatParameters.
// `smt`, `nnlp`, `model`, `policy`, `not_goal` and `avoid` below are locals of
// the test body generated by CREATE_TEST_INTERNAL.
#define CREATE_TESTS(Factory, Tags, Generalizer)                               \
    CREATE_TESTS_INTERNAL(                                                     \
        Factory,                                                               \
        Tags,                                                                  \
        get_unsafe_net,                                                        \
        get_safe_net,                                                          \
        epic3::PIC3SatEdgeIndividual(                                          \
            epic3::PIC3SatParameters(smt, nnlp, false, true, false, false),    \
            model,                                                             \
            policy,                                                            \
            {not_goal},                                                        \
            avoid),                                                            \
        "",                                                                    \
        false,                                                                 \
        Generalizer)                                                           \
                                                                               \
    CREATE_TESTS_INTERNAL(                                                     \
        Factory,                                                               \
        Tags,                                                                  \
        get_unsafe_net,                                                        \
        get_safe_net,                                                          \
        epic3::PIC3SatEdgeIndividual(                                          \
            epic3::PIC3SatParameters(smt, nnlp, false, false, false, false),   \
            model,                                                             \
            policy,                                                            \
            {not_goal},                                                        \
            avoid),                                                            \
        " (no-determinize)",                                                   \
        false,                                                                 \
        Generalizer)                                                           \
                                                                               \
    CREATE_TESTS_INTERNAL(                                                     \
        Factory,                                                               \
        Tags,                                                                  \
        get_unsafe_net_filtered,                                               \
        get_safe_net_filtered,                                                 \
        epic3::PIC3SatEdgeIndividual(                                          \
            epic3::PIC3SatParameters(smt, nnlp, true, true, false, false),     \
            model,                                                             \
            policy,                                                            \
            {not_goal},                                                        \
            avoid),                                                            \
        " (app-filter)",                                                       \
        true,                                                                  \
        Generalizer)                                                           \
                                                                               \
    CREATE_TESTS_INTERNAL(                                                     \
        Factory,                                                               \
        Tags,                                                                  \
        get_unsafe_net_filtered,                                               \
        get_safe_net_filtered,                                                 \
        epic3::PIC3SatEdgeIndividual(                                          \
            epic3::PIC3SatParameters(smt, nnlp, true, false, false, false),    \
            model,                                                             \
            policy,                                                            \
            {not_goal},                                                        \
            avoid),                                                            \
        " (app-filter, no-determinize)",                                       \
        true,                                                                  \
        Generalizer)

// Shorthands that choose the generalizer passed to CREATE_TESTS:
//   _UC -> epic3::UnsatCoreGeneralizer
//   _G  -> epic3::GreedyGeneralizer
//   _x  -> both of the above
// Parameter names are irrelevant to the generated test names: stringification
// happens in CREATE_TEST_INTERNAL on the caller's original argument tokens.
#define CREATE_TESTS_UC(MakeNnlp, ExtraTags)                                   \
    CREATE_TESTS(MakeNnlp, ExtraTags, epic3::UnsatCoreGeneralizer)

#define CREATE_TESTS_G(MakeNnlp, ExtraTags)                                    \
    CREATE_TESTS(MakeNnlp, ExtraTags, epic3::GreedyGeneralizer)

#define CREATE_TESTS_x(MakeNnlp, ExtraTags)                                    \
    CREATE_TESTS_UC(MakeNnlp, ExtraTags)                                       \
    CREATE_TESTS_G(MakeNnlp, ExtraTags)

// Instantiates CREATE_TESTS_INTERNAL for the PIC3SatSingleton engine, always
// with epic3::UnsatCoreGeneralizer and the extra tag "[ic3sat-singleton]":
//   PIC3SatSingletonParameters(smt, nnlp, false, false) with the unfiltered
//     networks, suffix " - singleton", AppFilter=false;
//   PIC3SatSingletonParameters(smt, nnlp, true,  false) with the filtered
//     networks, suffix " - singleton (app-filter)", AppFilter=true.
// `smt`, `nnlp`, `model`, `policy`, `not_goal` and `avoid` below are locals of
// the test body generated by CREATE_TEST_INTERNAL.
#define CREATE_TESTS_SINGLETON(Factory, Tags)                                  \
    CREATE_TESTS_INTERNAL(                                                     \
        Factory,                                                               \
        Tags "[ic3sat-singleton]",                                             \
        get_unsafe_net,                                                        \
        get_safe_net,                                                          \
        epic3::PIC3SatSingleton(                                               \
            epic3::PIC3SatSingletonParameters(smt, nnlp, false, false),        \
            model,                                                             \
            policy,                                                            \
            {not_goal},                                                        \
            avoid),                                                            \
        " - singleton",                                                        \
        false,                                                                 \
        epic3::UnsatCoreGeneralizer)                                           \
                                                                               \
    CREATE_TESTS_INTERNAL(                                                     \
        Factory,                                                               \
        Tags "[ic3sat-singleton]",                                             \
        get_unsafe_net_filtered,                                               \
        get_safe_net_filtered,                                                 \
        epic3::PIC3SatSingleton(                                               \
            epic3::PIC3SatSingletonParameters(smt, nnlp, true, false),         \
            model,                                                             \
            policy,                                                            \
            {not_goal},                                                        \
            avoid),                                                            \
        " - singleton (app-filter)",                                           \
        true,                                                                  \
        epic3::UnsatCoreGeneralizer)

#if POLICE_Z3
namespace {
std::shared_ptr<NNLPSMTFactory> Z3_NNLP()
{
    static std::shared_ptr<Z3SMTFactory> z3 = std::make_shared<Z3SMTFactory>();
    return std::make_shared<NNLPSMTFactory>(z3.get(), false);
}

std::shared_ptr<NNLPLPFactory> Z3_NNLP_LP()
{
    static Z3LPFactory z3;
    return std::make_shared<NNLPLPFactory>(&z3, false);
}
} // namespace

CREATE_TESTS_x(Z3_NNLP, "[smt][z3]")
CREATE_TESTS_x(Z3_NNLP_LP, "[lp][z3]")

#if POLICE_GUROBI
    namespace
{
    std::shared_ptr<NNLPLPFactory> Gurobi_NNLP()
    {
        static GurobiLPFactory gurobi;
        return std::make_shared<NNLPLPFactory>(&gurobi, false);
    }

#if POLICE_MARABOU
    std::shared_ptr<NNLPLPFactory> Gurobi_NNLP_Preprocessed()
    {
        static GurobiLPFactory gurobi;
        return std::make_shared<NNLPLPFactory>(&gurobi, true);
    }
#endif
} // namespace

CREATE_TESTS_x(Gurobi_NNLP, "[lp][gurobi]")
CREATE_TESTS_SINGLETON(Gurobi_NNLP, "[lp][gurobi]")

#if POLICE_MARABOU
CREATE_TESTS_G(Gurobi_NNLP_Preprocessed, "[lp][gurobi][marabou]")
#endif
#endif

#if POLICE_MARABOU
            namespace
{
    std::shared_ptr<NNLPFactory> Marabou_NNLP()
    {
        static MarabouLPFactory factory;
        return std::make_shared<NNLPBranchNBoundFactory>(&factory);
    }
} // namespace

// The macro expands to complete TEST_CASE definitions, so no trailing ';' is
// needed. A stray one is an extra empty declaration (-Wextra-semi).
CREATE_TESTS_G(Marabou_NNLP, "[marabou]")
#endif

#endif

#undef CREATE_TESTS_SINGLETON
#undef CREATE_TESTS
#undef CREATE_TESTS_G
#undef CREATE_TESTS_UC
#undef CREATE_TESTS_x
#undef CREATE_TESTS_INTERNAL
#undef CREATE_TEST_INTERNAL
