#include "police/base_types.hpp"
#include "police/cg_relaxation.hpp"
#include "police/compute_graph_factory.hpp"
#include "police/layer_bounds.hpp"
#include "police/storage/ffnn.hpp"

#include <catch2/catch.hpp>
#include <iostream>

namespace {
using namespace police;

// Runs the compute-graph relaxation of `net` on the input box `bounds`,
// prints the input box and every layer's bounds (from index 1 on) to stdout,
// and returns the bounds of the last layer, i.e. the network output.
//
// NOTE(review): result[0] is not printed; it presumably holds the input
// layer, which is printed separately as "in" — confirm against
// ComputeGraphRelaxation::compute_layer_bounds.
LayerBounds<real_t> analyze(
    const FeedForwardNeuralNetwork<>& net,
    const LayerBounds<real_t>& bounds)
{
    auto cg = cg::from_ffnn(net);
    cg::ComputeGraphRelaxation relax(
        cg,
        std::make_shared<cg::NodeRelaxerFactory>());
    auto result = relax.compute_layer_bounds(bounds);
    std::cout << "in: " << bounds << "\n";
    for (auto l = 1u; l < result.size(); ++l) {
        std::cout << "L" << l << ": " << result[l] << "\n";
    }
    std::cout << std::flush;
    // back() on an empty result would be undefined behaviour; turn that into
    // a clean test failure instead.
    REQUIRE(result.size() > 0u);
    return result.back();
}

// Builds a 1-input / 1-output network with a single hidden neuron.
// weights[i] holds the incoming weights of neuron i of a layer:
//   hidden = 1 * x + 0
//   out    = 1 * hidden + 1
// The activation between the layers is whatever cg::from_ffnn inserts; the
// test names ("active"/"inactive"/"unstable") suggest a ReLU — TODO confirm.
FeedForwardNeuralNetwork<> network_1to1_2l()
{
    FeedForwardNeuralNetwork<> net;
    net.layers.resize(2);

    // Hidden layer: identity pre-activation.
    auto& l0 = net.layers[0];
    l0.biases.resize(1, 0);
    l0.weights.resize(1, {1.});

    // Output layer: shifts the hidden value by +1.
    auto& l1 = net.layers[1];
    l1.biases.resize(1, 1);
    l1.weights.resize(1, {1.});

    return net;
}

// Builds a 2-input / 1-output network with two hidden neurons:
//   h0  = x0 + 2 * x1
//   h1  = 2 * x0 + x1
//   out = h0 - h1          (applied to the activated hidden values)
FeedForwardNeuralNetwork<> network_2to1_2l()
{
    FeedForwardNeuralNetwork<> net;
    net.layers.resize(2);

    auto& l0 = net.layers[0];
    l0.biases.resize(2, 0);
    l0.weights.resize(2);
    l0.weights[0] = {1., 2.};
    l0.weights[1] = {2., 1.};

    auto& l1 = net.layers[1];
    l1.biases.resize(1, 0);
    l1.weights.resize(1, {1., -1.});

    return net;
}

// Builds the 2-2-2-1 network used by the "NeurIPS example" test (presumably
// a worked example from a NeurIPS paper, per the test name). All biases are 0:
//   h0  = 2 * x0 + x1,     h1 = -3 * x0 + 4 * x1
//   g0  = 4 * h0 - 2 * h1, g1 = 2 * h0 + h1     (on activated h)
//   out = -2 * g0 + g1                          (on activated g)
FeedForwardNeuralNetwork<> network_3l_neurips()
{
    FeedForwardNeuralNetwork<> net;
    net.layers.resize(3);

    auto& l0 = net.layers[0];
    l0.biases.resize(2, 0);
    l0.weights.resize(2);
    l0.weights[0] = {2., 1.};
    l0.weights[1] = {-3., 4.};

    auto& l1 = net.layers[1];
    l1.biases.resize(2, 0);
    l1.weights.resize(2);
    l1.weights[0] = {4., -2.};
    l1.weights[1] = {2., 1.};

    auto& l2 = net.layers[2];
    l2.biases.resize(1, 0);
    l2.weights.resize(1);
    l2.weights[0] = {-2., 1.};

    return net;
}

} // namespace

TEST_CASE("Test CG LIPA 2L 1x1 inactive", "[lipa][cg]")
{
    // The hidden neuron's pre-activation is x itself (weight 1, bias 0), so
    // x in [-1, 0] keeps it non-positive: the (presumed) ReLU is inactive and
    // the network output is exactly 0 + 1 = 1.
    auto net = network_1to1_2l();
    LayerBounds<real_t> in(1);
    in.set_bounds(0, -1, 0);
    auto out = analyze(net, in);
    std::cout << out.lb(0) << " <= o <= " << out.ub(0) << std::endl;
    // Soundness: the computed interval must contain the exact output 1.
    REQUIRE(out.lb(0) <= Approx(1.));
    REQUIRE(out.ub(0) >= Approx(1.));
    // Tightness: the lower bound must not be needlessly loose.
    REQUIRE(out.lb(0) >= 0);
}

TEST_CASE("Test CG LIPA 2L 1x1 active", "[lipa][cg]")
{
    // The hidden neuron's pre-activation is x itself (weight 1, bias 0), so
    // x in [0, 1] keeps it non-negative: the (presumed) ReLU is active and
    // the network output is x + 1, i.e. exactly the range [1, 2].
    auto net = network_1to1_2l();
    LayerBounds<real_t> in(1);
    in.set_bounds(0, 0, 1);
    auto out = analyze(net, in);
    std::cout << out.lb(0) << " <= o <= " << out.ub(0) << std::endl;
    // Soundness: the computed interval must contain the exact range [1, 2].
    REQUIRE(out.lb(0) <= Approx(1.));
    REQUIRE(out.ub(0) >= Approx(2.));
    // Tightness: the lower bound must not be needlessly loose.
    REQUIRE(out.lb(0) >= 0);
}

TEST_CASE("Test CG LIPA 2L 1x1 unstable", "[lipa][cg]")
{
    // x in [-1, 1] straddles 0, so the (presumed) ReLU on the hidden neuron
    // is unstable and must be relaxed. The exact output relu(x) + 1 ranges
    // over [1, 2].
    auto net = network_1to1_2l();
    LayerBounds<real_t> in(1);
    in.set_bounds(0, -1, 1);
    auto out = analyze(net, in);
    std::cout << out.lb(0) << " <= o <= " << out.ub(0) << std::endl;
    // Soundness: the computed interval must contain the exact range [1, 2].
    REQUIRE(out.lb(0) <= Approx(1.));
    REQUIRE(out.ub(0) >= Approx(2.));
    // Tightness: the relaxed lower bound must not drop below 0.
    REQUIRE(out.lb(0) >= 0);
}

TEST_CASE("Test CG LIPA 2L 2x1", "[lipa][cg]")
{
    // Exact output relu(x0 + 2*x1) - relu(2*x0 + x1) over [-1, 1]^2 (assuming
    // a ReLU hidden layer) ranges over [-1.5, 1.5]. The minimum is attained at
    // x = (1, -0.5) and, by symmetry, the maximum at x = (-0.5, 1).
    auto net = network_2to1_2l();
    LayerBounds<real_t> in(2);
    in.set_bounds(0, -1, 1);
    in.set_bounds(1, -1, 1);
    auto out = analyze(net, in);
    std::cout << out.lb(0) << " <= o <= " << out.ub(0) << std::endl;
    // Soundness: the computed interval must contain the exact range.
    REQUIRE(out.lb(0) <= Approx(-1.5));
    REQUIRE(out.ub(0) >= Approx(1.5));
    // Tightness: the relaxation may lose at most 0.25 on the lower bound.
    REQUIRE(out.lb(0) >= -1.75);
}

TEST_CASE("Test CG LIPA 3L (NeurIPS example)", "[lipa][cg]")
{
    // Input box: x0 in [-2, 2], x1 in [-1, 3].
    LayerBounds<real_t> input_box(2);
    input_box.set_bounds(0, -2, 2);
    input_box.set_bounds(1, -1, 3);

    const auto network = network_3l_neurips();
    auto output = analyze(network, input_box);
    std::cout << output.lb(0) << " <= o <= " << output.ub(0) << std::endl;

    // The computed output interval must enclose [-36, 24].
    REQUIRE(output.ub(0) >= 24.);
    REQUIRE(output.lb(0) <= -36.);
}
