#pragma once

#include "police/storage/bounded_number_bit_compression.hpp"
#include "police/storage/compressed_state.hpp"
#include "police/storage/flat_state.hpp"
#include "police/storage/segmented_arrays.hpp"
#include "police/storage/unordered_set.hpp"
#include "police/utils/hash.hpp"

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

namespace police {

/// Index-addressed container of bit-compressed states with duplicate
/// detection.
///
/// Every inserted `flat_state` is compressed into `num_bins()` words of type
/// `Bin` (as reported by the compressor) and written into a slot of a
/// `segmented_arrays<Bin>` store. A set of slot indices, which hashes and
/// compares indices by the words stored in their slots, is used to recognise
/// a state that has been inserted before.
///
/// The container is neither copyable nor movable: the hash and equality
/// functors of the index set keep a pointer to `state_data_`, which would
/// dangle (or alias another container's storage) after a copy or move.
///
/// @tparam Bin  Word type used for the compressed representation.
template <typename Bin = std::uint32_t>
struct explicit_state_container {
private:
    using _state_data_t = segmented_arrays<Bin>;
    using _compressor_t = bounded_range_bit_compressor<Bin>;
    using _size_t = typename _state_data_t::size_t;

    struct _hash;
    struct _compare;
    using _hash_set = unordered_set<_size_t, _hash, _compare>;

public:
    using state_t = compressed_state<_compressor_t>;
    using size_t = _size_t;

    /// Builds an empty container.
    ///
    /// All three arguments are forwarded unchanged to the compressor.
    /// NOTE(review): presumably `num_floats` is the number of real-valued
    /// state variables and `[int_begin, int_end)` describes the bounded
    /// integer variables -- confirm against `bounded_range_bit_compressor`.
    template <typename Iterator>
    constexpr explicit_state_container(
        std::size_t num_floats,
        Iterator int_begin,
        Iterator int_end)
        : compression_(num_floats, int_begin, int_end)
        , state_data_(compression_.num_bins())
        , id_set_(
              // NOTE(review): presumably the initial bucket count / capacity
              // of the index set -- confirm against `police::unordered_set`.
              1u << 16u,
              _hash{&state_data_, compression_.num_bins()},
              _compare{&state_data_, compression_.num_bins()})
    {
    }

    // Copying or moving would leave the functors stored in `id_set_`
    // pointing at the source object's `state_data_`, so all four operations
    // are disabled explicitly.
    explicit_state_container(const explicit_state_container&) = delete;

    explicit_state_container(explicit_state_container&&) = delete;

    explicit_state_container&
    operator=(const explicit_state_container&) = delete;

    explicit_state_container& operator=(explicit_state_container&&) = delete;

    /// Returns the state stored in slot `index`.
    ///
    /// The returned object is constructed from the slot's data and a pointer
    /// to this container's compressor, so it should not be used after the
    /// container has been destroyed.
    [[nodiscard]]
    constexpr state_t operator[](size_t index) const
    {
        return state_t(state_data_[index], &compression_);
    }

    /// Inserts `state` unless an equal compressed state is already present.
    ///
    /// A fresh slot is always allocated and filled first; if the index set
    /// then reports an existing equal entry, the fresh slot is released
    /// again.
    ///
    /// @return `{index, true}` with the index of the newly stored state, or
    ///         `{existing_index, false}` if the state was already stored.
    constexpr std::pair<size_t, bool> insert(const flat_state& state)
    {
        auto idx = state_data_.allocate();
        auto* bin = state_data_[idx];
        // A freshly allocated slot is expected to be zeroed.
        // NOTE(review): presumably required because the compressor's setters
        // merge into the existing bits rather than overwrite them -- confirm.
        assert(
            std::all_of(
                bin,
                bin + compression_.num_bins(),
                [](const auto word) { return word == 0u; }));
        store(bin, state);
        return insert_id_or_deallocate(idx);
    }

    /// Number of indices currently held by the index set.
    constexpr size_t size() const { return id_set_.size(); }

private:
    /// Hashes a slot index by the `num_bins` words stored in that slot.
    struct _hash {
        [[nodiscard]]
        constexpr std::size_t operator()(_size_t index) const
        {
            return police::hash_array(data->at(index), num_bins);
        }

        const _state_data_t* data; // storage the indices refer to (not owned)
        std::size_t num_bins;      // number of words hashed per slot
    };

    /// Treats two slot indices as equal when their `num_bins` stored words
    /// are element-wise equal.
    struct _compare {
        [[nodiscard]]
        constexpr bool operator()(_size_t p, _size_t q) const
        {
            const auto* p_bin = data->at(p);
            const auto* q_bin = data->at(q);
            return std::equal(p_bin, p_bin + num_bins, q_bin);
        }

        const _state_data_t* data; // storage the indices refer to (not owned)
        std::size_t num_bins;      // number of words compared per slot
    };

    /// Compresses `state` into the words starting at `bin`.
    ///
    /// Entries `[0, num_floats())` are written with `set_float`, entries
    /// `[num_floats(), size())` with `set_int`.
    constexpr void store(Bin* bin, const flat_state& state) const
    {
        for (auto i = 0u; i < compression_.num_floats(); ++i) {
            compression_.set_float(bin, i, static_cast<real_t>(state[i]));
        }
        for (auto i = compression_.num_floats(); i < compression_.size(); ++i) {
            compression_.set_int(bin, i, static_cast<int_t>(state[i]));
        }
    }

    /// Registers slot `index` in the index set.
    ///
    /// If the set reports that an equal entry already exists, the slot is
    /// deallocated and the index of the existing entry is returned instead.
    constexpr std::pair<size_t, bool> insert_id_or_deallocate(size_t index)
    {
        auto pr = id_set_.insert(index);
        if (pr.second) {
            return {index, true};
        }
        // Duplicate: the slot filled by the caller is not needed.
        state_data_.deallocate(index);
        return {*pr.first, false};
    }

    // Declaration order matters: `state_data_` is constructed from
    // `compression_.num_bins()`, and `id_set_` receives functors that point
    // at `state_data_`.
    _compressor_t compression_;
    _state_data_t state_data_;
    _hash_set id_set_;
};

} // namespace police
