#include "police/storage/bounded_number_bit_compression.hpp"

#include <algorithm>
#include <catch2/catch.hpp>
#include <cstddef>
#include <limits>
#include <ranges>
#include <type_traits>
#include <utility>
#include <vector>

namespace {
using namespace police;

constexpr int INTMIN = std::numeric_limits<int>::min();
constexpr int INTMAX = std::numeric_limits<int>::max();

struct Bounds {
    int lower_bound = 0;
    int upper_bound = 0;
};

using bin_t = bounded_range_bit_compressor<>::bin_t;

void require_equal(const bounded_range_bit_compressor<>&, bin_t*, std::size_t)
{
}

template <typename T, typename... C>
void require_equal(
    const bounded_range_bit_compressor<>& compressor,
    bin_t* bin,
    std::size_t idx,
    T&& c,
    C&&... d)
{
    if constexpr (std::is_same_v<T, double>) {
        REQUIRE(compressor.get_float(bin, idx) == c);
    } else {
        REQUIRE(compressor.get_int(bin, idx) == c);
    }
    require_equal(compressor, bin, idx + 1, std::forward<C>(d)...);
}

} // namespace

TEST_CASE("Bounded number bit compression: one bool", "[storage][compression]")
{
    // A single [0, 1] variable: the smallest non-trivial layout.
    const std::vector<Bounds> bounds{Bounds{0, 1}};
    bounded_range_bit_compressor<> compressor(0, bounds.begin(), bounds.end());

    // Catch2 re-enters the test case from the top for every SECTION, so each
    // section below starts with this buffer zeroed.
    bin_t buffer[1]{};

    REQUIRE(compressor.size() == 1);
    REQUIRE(compressor.num_bins() == 1);

    SECTION("Set & get 0")
    {
        compressor.set_int(buffer, 0, 0);
        REQUIRE(compressor.get_int(buffer, 0) == 0);
    }

    SECTION("Set & get 1")
    {
        compressor.set_int(buffer, 0, 1);
        REQUIRE(compressor.get_int(buffer, 0) == 1);
    }
}

TEST_CASE(
    "Bounded number bit compression: one integer",
    "[storage][compression]")
{
    // A single signed variable in [-1, 1].
    const std::vector<Bounds> bounds{Bounds{-1, 1}};
    bounded_range_bit_compressor<> compressor(0, bounds.begin(), bounds.end());

    // Catch2 re-enters the test case from the top for every SECTION, so each
    // section below starts with this buffer zeroed.
    bin_t buffer[1]{};

    REQUIRE(compressor.size() == 1);
    REQUIRE(compressor.num_bins() == 1);

    SECTION("Set & get 0")
    {
        compressor.set_int(buffer, 0, 0);
        REQUIRE(compressor.get_int(buffer, 0) == 0);
    }

    SECTION("Set & get 1")
    {
        compressor.set_int(buffer, 0, 1);
        REQUIRE(compressor.get_int(buffer, 0) == 1);
    }

    SECTION("Set & get -1")
    {
        compressor.set_int(buffer, 0, -1);
        REQUIRE(compressor.get_int(buffer, 0) == -1);
    }

    SECTION("Double set & get")
    {
        // The second write must fully replace the first.
        compressor.set_int(buffer, 0, -1);
        compressor.set_int(buffer, 0, 1);
        REQUIRE(compressor.get_int(buffer, 0) == 1);
    }

    SECTION("Double get")
    {
        // Reading must not disturb the stored value.
        compressor.set_int(buffer, 0, 1);
        REQUIRE(compressor.get_int(buffer, 0) == 1);
        REQUIRE(compressor.get_int(buffer, 0) == 1);
    }
}

TEST_CASE(
    "Bounded number bit compression: one int limit",
    "[storage][compression]")
{
    // A single variable spanning the entire int range; its extremes INTMIN
    // and INTMAX are the interesting round-trip values. Uses the file-level
    // INTMIN/INTMAX constants rather than redefining local copies.
    const std::vector<Bounds> bounds({Bounds{INTMIN, INTMAX}});
    bounded_range_bit_compressor<> compressor(0, bounds.begin(), bounds.end());

    REQUIRE(compressor.size() == 1);
    REQUIRE(compressor.num_bins() == 1);

    SECTION("Set & get 0")
    {
        bin_t a[1]{};
        compressor.set_int(a, 0, 0);
        REQUIRE(compressor.get_int(a, 0) == 0);
    }

    SECTION("Set & get INTMAX")
    {
        bin_t a[1]{};
        compressor.set_int(a, 0, INTMAX);
        REQUIRE(compressor.get_int(a, 0) == INTMAX);
    }

    SECTION("Set & get INTMIN")
    {
        bin_t a[1]{};
        compressor.set_int(a, 0, INTMIN);
        REQUIRE(compressor.get_int(a, 0) == INTMIN);
    }
}

TEST_CASE(
    "Bounded number bit compression: multiple ints, one bin",
    "[storage][compression]")
{
    // Three small ranges whose packed layout is expected to occupy one bin.
    const std::vector<Bounds> bounds{
        Bounds{0, 1},
        Bounds{-1, 1},
        Bounds{-10, 10},
    };
    bounded_range_bit_compressor<> compressor(0, bounds.begin(), bounds.end());
    bin_t buffer[1]{};

    REQUIRE(compressor.size() == 3);
    REQUIRE(compressor.num_bins() == 1);

    // Writes slots 0, 1 and 2, in that order.
    const auto store = [&](int first, int second, int third) {
        compressor.set_int(buffer, 0, first);
        compressor.set_int(buffer, 1, second);
        compressor.set_int(buffer, 2, third);
    };

    SECTION("Set & get 0")
    {
        store(0, 0, 0);
        require_equal(compressor, buffer, 0, 0, 0, 0);
    }

    SECTION("Set & get min")
    {
        store(0, -1, -10);
        require_equal(compressor, buffer, 0, 0, -1, -10);
    }

    SECTION("Set & get max")
    {
        store(1, 1, 10);
        require_equal(compressor, buffer, 0, 1, 1, 10);
    }

    SECTION("Interleaved set & get")
    {
        // Single-slot writes must leave the neighbouring slots untouched.
        store(0, 0, 0);

        compressor.set_int(buffer, 0, 1);
        require_equal(compressor, buffer, 0, 1, 0, 0);

        compressor.set_int(buffer, 2, 1);
        require_equal(compressor, buffer, 0, 1, 0, 1);

        compressor.set_int(buffer, 0, 0);
        require_equal(compressor, buffer, 0, 0, 0, 1);

        compressor.set_int(buffer, 1, -1);
        require_equal(compressor, buffer, 0, 0, -1, 1);
    }
}

TEST_CASE(
    "Bounded number bit compression: multiple ints, two bins",
    "[storage][compression]")
{
    // The middle variable spans the full int range; this layout is expected
    // to occupy two bins.
    const std::vector<Bounds> bounds{
        Bounds{0, 1},
        Bounds{INTMIN, INTMAX},
        Bounds{-10, 10},
    };
    bounded_range_bit_compressor<> compressor(0, bounds.begin(), bounds.end());
    bin_t buffer[2]{};

    REQUIRE(compressor.size() == 3);
    REQUIRE(compressor.num_bins() == 2);

    // Writes slots 0, 1 and 2, in that order.
    const auto store = [&](int first, int second, int third) {
        compressor.set_int(buffer, 0, first);
        compressor.set_int(buffer, 1, second);
        compressor.set_int(buffer, 2, third);
    };

    SECTION("Set & get 0")
    {
        store(0, 0, 0);
        require_equal(compressor, buffer, 0, 0, 0, 0);
    }

    SECTION("Set & get min")
    {
        store(0, INTMIN, -10);
        require_equal(compressor, buffer, 0, 0, INTMIN, -10);
    }

    SECTION("Set & get max")
    {
        store(1, INTMAX, 10);
        require_equal(compressor, buffer, 0, 1, INTMAX, 10);
    }

    SECTION("Interleaved set & get")
    {
        // Single-slot writes must leave the neighbouring slots untouched.
        store(0, 0, 0);

        compressor.set_int(buffer, 0, 1);
        require_equal(compressor, buffer, 0, 1, 0, 0);

        compressor.set_int(buffer, 2, 1);
        require_equal(compressor, buffer, 0, 1, 0, 1);

        compressor.set_int(buffer, 0, 0);
        require_equal(compressor, buffer, 0, 0, 0, 1);

        compressor.set_int(buffer, 1, -1);
        require_equal(compressor, buffer, 0, 0, -1, 1);
    }
}

TEST_CASE(
    "Bounded number bit compression: multiple ints",
    "[storage][compression]")
{
    // Sixteen variables: one [0, 1], nine [0, 4], five [0, 40] and a final
    // constant [0, 0] slot; the layout is expected to occupy two bins.
    const std::vector<Bounds> bounds({
        Bounds{0, 1},
        Bounds{0, 4},
        Bounds{0, 4},
        Bounds{0, 4},
        Bounds{0, 4},
        Bounds{0, 4},
        Bounds{0, 4},
        Bounds{0, 4},
        Bounds{0, 4},
        Bounds{0, 4},

        Bounds{0, 40},
        Bounds{0, 40},
        Bounds{0, 40},
        Bounds{0, 40},
        Bounds{0, 40},

        Bounds{0, 0},
    });
    bounded_range_bit_compressor<> compressor(0, bounds.begin(), bounds.end());
    bounded_range_bit_compressor<>::bin_t a[] = {0, 0};

    REQUIRE(compressor.size() == 16);
    REQUIRE(compressor.num_bins() == 2);

    SECTION("Set 0")
    {
        for (std::size_t idx = 0; idx < bounds.size(); ++idx) {
            compressor.set_int(a, idx, 0);
        }
        // Check each slot individually so a failure reports its index
        // (an all_of over the range would only report "false").
        for (std::size_t idx = 0; idx < bounds.size(); ++idx) {
            CAPTURE(idx);
            REQUIRE(compressor.get_int(a, idx) == 0);
        }
    }

    SECTION("Random")
    {
        // Expected value per slot. Slots 4 and 5 hold the upper bound of
        // their [0, 4] range; slots 10 and 14 hold interior [0, 40] values.
        const std::vector<int> expected{
            0, 0, 0, 0, 4, 4, 0, 0, 0, 0, 11, 0, 0, 0, 10, 0};
        REQUIRE(expected.size() == bounds.size());

        // Write every slot first, then read them all back, so later writes
        // that clobbered earlier slots would be detected.
        for (std::size_t idx = 0; idx < expected.size(); ++idx) {
            compressor.set_int(a, idx, expected[idx]);
        }
        for (std::size_t idx = 0; idx < expected.size(); ++idx) {
            CAPTURE(idx);
            REQUIRE(compressor.get_int(a, idx) == expected[idx]);
        }
    }
}

TEST_CASE("Bounded number bit compression: float", "[storage][compression]")
{
    // No bounded integers at all; the single slot (index 0) is accessed as a
    // float. NOTE(review): the leading constructor argument is 1 here and 0
    // in the int-only tests, so it presumably counts float slots — confirm.
    const std::vector<Bounds> bounds{};
    bounded_range_bit_compressor<> compressor(1, bounds.begin(), bounds.end());
    bin_t buffer[2]{};

    REQUIRE(compressor.size() == 1);
    REQUIRE(compressor.num_bins() == 2);

    SECTION("Set & get 0")
    {
        compressor.set_float(buffer, 0, 0);
        REQUIRE(compressor.get_float(buffer, 0) == 0.0);
    }

    SECTION("Set & get 2")
    {
        compressor.set_float(buffer, 0, 2.0);
        REQUIRE(compressor.get_float(buffer, 0) == 2.0);
    }

    SECTION("Set & get 8.125, 4.0, 4.0, 32.5")
    {
        // Overwrite the same slot repeatedly, including with an unchanged
        // value. Exact comparison is intended: every value here is exactly
        // representable in binary floating point.
        for (const double value : {8.125, 4.0, 4.0, 32.5}) {
            compressor.set_float(buffer, 0, value);
            REQUIRE(compressor.get_float(buffer, 0) == value);
        }
    }
}

TEST_CASE(
    "Bounded number bit compression: float + int",
    "[storage][compression]")
{
    // Slot 0 is accessed as a float, slot 1 as an int bounded to [-100, 100].
    const std::vector<Bounds> bounds{Bounds{-100, 100}};
    bounded_range_bit_compressor<> compressor(1, bounds.begin(), bounds.end());
    bin_t buffer[3]{};

    REQUIRE(compressor.size() == 2);
    REQUIRE(compressor.num_bins() == 3);

    SECTION("Set & get 0")
    {
        compressor.set_float(buffer, 0, 0);
        compressor.set_int(buffer, 1, 0);
        require_equal(compressor, buffer, 0, 0.0, 0);
    }

    SECTION("Set & get 2")
    {
        compressor.set_float(buffer, 0, 2.0);
        compressor.set_int(buffer, 1, 2);
        require_equal(compressor, buffer, 0, 2.0, 2);
    }

    SECTION("Set & get interleaved")
    {
        // Rewriting one slot must leave the other slot's value intact.
        compressor.set_float(buffer, 0, 2.0);
        compressor.set_int(buffer, 1, 2);
        compressor.set_int(buffer, 1, -10);
        require_equal(compressor, buffer, 0, 2.0, -10);

        compressor.set_float(buffer, 0, 4.0);
        require_equal(compressor, buffer, 0, 4.0, -10);
    }
}
