#pragma once

#include "police/storage/vector.hpp"
#include "police/utils/bit_operations.hpp"
#include "police/utils/type_traits.hpp"

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <strings.h>

#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Warray-bounds"
#else
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
#endif

namespace police {

template <
    typename Bin = std::uint32_t,
    typename FloatType = double,
    typename IntType = std::int32_t>
class bounded_range_bit_compressor {
    static_assert(sizeof(FloatType) % sizeof(Bin) == 0);
    static_assert(sizeof(IntType) <= 4);

public:
    using bin_t = Bin;
    using size_t = std::uint32_t;
    using int_t = IntType;
    using float_t = FloatType;

    template <typename BoundIterator>
    constexpr bounded_range_bit_compressor(
        size_t num_floats,
        BoundIterator begin_int,
        BoundIterator end_int)
        : num_floats_(num_floats)
    {
        num_bins_ = get_array_index_for_float(num_floats_);
        std::uint32_t offset = 0;
        std::uint32_t total_width = sizeof(float_t) * 8 * num_floats_;
        for (; begin_int != end_int; ++begin_int) {
            const int_t lb = begin_int->lower_bound;
            const int_t ub = begin_int->upper_bound;
            assert(lb <= ub);
            const std::uint32_t represented_max =
                static_cast<std::uint32_t>(static_cast<std::int64_t>(ub) - lb);
            const std::uint32_t width =
                (represented_max != 0u ? police::bits::get_msb(represented_max)
                                       : 0u) +
                1;
            const auto bin_idx = num_bins_ - (offset > 0);
            infos_.emplace_back(bin_idx, offset, width, lb, ub);
            total_width += width;
            offset = total_width & bits::pattern<std::uint32_t>(
                                       bits::get_msb(sizeof(bin_t) * 8));
            num_bins_ = (total_width >> bits::get_msb(sizeof(bin_t) * 8)) +
                        (offset > 0);
        }
    }

    constexpr size_t num_floats() const { return num_floats_; }

    constexpr size_t num_bins() const { return num_bins_; }

    constexpr float_t get_float(const bin_t* bucket, size_t index) const
    {
        assert(index < num_floats_);
        auto bin = aligned_read<sizeof(float_t)>(
            bucket + get_array_index_for_float(index));
        return as_float(bin);
    }

    constexpr int_t get_int(const bin_t* bucket, size_t index) const
    {
        assert(index >= num_floats_ && index - num_floats_ < infos_.size());
        const auto& type_info = infos_[index - num_floats_];
        const auto bin = unaligned_read<sizeof(int_t)>(
            bucket + type_info.array_index,
            type_info.offset,
            type_info.width);
        const auto value =
            static_cast<int_t>(static_cast<std::int64_t>(bin) + type_info.lb);
        assert(value >= type_info.lb && value <= type_info.ub);
        return value;
    }

    constexpr void set_float(bin_t* bucket, size_t index, float_t value) const
    {
        assert(index < num_floats_);
        aligned_store<sizeof(float_t)>(
            bucket + get_array_index_for_float(index),
            as_block(value));
        assert(get_float(bucket, index) == value);
    }

    constexpr void set_int(bin_t* bucket, size_t index, int_t value) const
    {
        assert(index >= num_floats_ && index - num_floats_ < infos_.size());
        const auto& type_info = infos_[index - num_floats_];
        assert(value >= type_info.lb && value <= type_info.ub);
        const auto bin = static_cast<block_t<sizeof(int_t)>>(
            static_cast<std::int64_t>(value) - type_info.lb);
        unaligned_store<sizeof(int_t)>(
            bucket + type_info.array_index,
            bin,
            type_info.offset,
            type_info.width);
        assert(get_int(bucket, index) == value);
    }

    constexpr size_t size() const { return num_floats_ + infos_.size(); }

private:
    template <size_t W>
    using block_t = police::ite_t<
        W == 1,
        std::uint8_t,
        police::ite_t<
            W == 2,
            std::uint16_t,
            police::ite_t<
                W == 4,
                std::uint32_t,
                police::ite_t<W == 8, std::uint64_t, void>>>>;

    template <size_t W>
    constexpr static block_t<W> ones = ~static_cast<block_t<W>>(0);

    using float_block_t = block_t<sizeof(float_t)>;

    union Float {
        float_t number;
        float_block_t binary;
    };

    constexpr static float_t as_float(float_block_t block)
    {
        Float f;
        f.binary = block;
        return f.number;
    }

    constexpr static float_block_t as_block(float_t val)
    {
        Float f;
        f.number = val;
        return f.binary;
    }

    template <size_t Bytes>
    constexpr static block_t<Bytes> aligned_read(const bin_t* bucket)
    {
        static_assert(Bytes < sizeof(bin_t) || Bytes % sizeof(bin_t) == 0);
        if constexpr (sizeof(bin_t) >= Bytes) {
            return static_cast<block_t<Bytes>>(*bucket);
        } else {
            return (static_cast<block_t<Bytes>>(aligned_read<Bytes / 2>(bucket))
                    << (Bytes * 4)) |
                   (static_cast<block_t<Bytes>>(aligned_read<Bytes / 2>(
                       bucket + Bytes / (2 * sizeof(bin_t)))));
        }
    }

    template <size_t Bytes>
    constexpr static void aligned_store(bin_t* dest, block_t<Bytes> block)
    {
        static_assert(Bytes < sizeof(bin_t) || Bytes % sizeof(bin_t) == 0);
        if constexpr (sizeof(bin_t) >= Bytes) {
            *dest = static_cast<bin_t>(block);
        } else {
            aligned_store<Bytes / 2>(
                dest,
                static_cast<block_t<Bytes / 2>>(block >> (Bytes * 4)));
            aligned_store<Bytes / 2>(
                dest + Bytes / (2 * sizeof(bin_t)),
                static_cast<block_t<Bytes / 2>>(block));
        }
    }

    template <size_t Bytes>
    constexpr static block_t<Bytes>
    unaligned_read(const bin_t* bucket, size_t offset, size_t width)
    {
        assert(Bytes * 8 >= width);
        assert(offset < sizeof(bin_t) * 8);
        if (offset + width <= 8 * sizeof(bin_t)) {
            assert(Bytes <= sizeof(bin_t));
            return (static_cast<block_t<Bytes>>(*bucket) >>
                    static_cast<block_t<Bytes>>(offset)) &
                   bits::pattern<block_t<Bytes>>(width);
        } else {
            block_t<Bytes> result = static_cast<block_t<Bytes>>(*bucket) >>
                                    static_cast<block_t<Bytes>>(offset);
            size_t read = sizeof(bin_t) * 8 - offset;
            width -= read;
            ++bucket;
            if constexpr (Bytes > sizeof(bin_t)) {
                while (width > 8 * sizeof(bin_t)) {
                    const auto bin = static_cast<block_t<Bytes>>(*bucket);
                    result = result | (bin << read);
                    width -= 8 * sizeof(bin_t);
                    read += 8 * sizeof(bin_t);
                    ++bucket;
                }
            }
            if (width) {
                assert(width <= sizeof(bin_t) * 8);
                assert(read < Bytes * 8);
                const auto bin = static_cast<block_t<Bytes>>(*bucket);
                result = result |
                         ((bits::pattern<block_t<Bytes>>(width) & bin) << read);
            }
            return result;
        }
    }

    template <size_t Bytes>
    constexpr static void unaligned_store(
        bin_t* dest,
        block_t<Bytes> block,
        size_t offset,
        size_t width)
    {
        assert(width <= sizeof(block_t<Bytes>) * 8);
        assert((~bits::pattern<block_t<Bytes>>(width) & block) == 0u);
        if (offset + width <= 8 * sizeof(bin_t)) {
            assert(Bytes <= sizeof(bin_t));
            assert(width == Bytes * 8 || !(block >> width));
            *dest = (*dest & ~bits::pattern<bin_t>(width, offset)) |
                    (static_cast<bin_t>(block) << static_cast<bin_t>(offset));
        } else {
            *dest = (*dest & bits::pattern<bin_t>(offset)) |
                    (static_cast<bin_t>(
                        block << static_cast<block_t<Bytes>>(offset)));
            block >>= static_cast<block_t<Bytes>>(sizeof(bin_t) * 8 - offset);
            width -= sizeof(bin_t) * 8 - offset;
            ++dest;
            if constexpr (Bytes > sizeof(bin_t)) {
                while (width > sizeof(bin_t) * 8) {
                    *dest = static_cast<bin_t>(block);
                    width -= sizeof(bin_t) * 8;
                    block >>= static_cast<block_t<Bytes>>(sizeof(bin_t) * 8);
                    ++dest;
                }
            }
            assert(width <= sizeof(bin_t) * 8);
            assert(width == 0 || (~bits::pattern<bin_t>(width) & block) == 0u);
            if (width > 0) {
                *dest = (*dest & ~bits::pattern<bin_t>(width)) |
                        static_cast<bin_t>(block);
            }
        }
    }

    constexpr static size_t get_array_index_for_float(size_t idx)
    {
        return std::max(sizeof(float_t) / sizeof(bin_t), 1ul) * idx;
    }

    struct type_infos {
        std::uint32_t array_index;
        std::uint32_t offset;
        std::uint32_t width;
        IntType lb;
        IntType ub;
    };

    vector<type_infos> infos_;
    size_t num_floats_ = 0;
    size_t num_bins_ = 0;
};

} // namespace police

#if defined(__clang__)
#pragma clang diagnostic pop
#else
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
#endif
