#pragma once

#include "police/base_types.hpp"
#include "police/storage/segmented_vector.hpp"
#include "police/storage/unordered_set.hpp"
#include "police/utils/hash.hpp"

#include <algorithm>
#include <functional>
#include <iterator>
#include <limits>
#include <optional>

namespace police {

/// Interning container that assigns every distinct value an integer id.
///
/// Values are appended to a segmented_vector (`data_`); the id of a value is
/// its index in that container. An unordered_set of indices, hashed and
/// compared through the values they refer to, provides lookup by value.
///
/// erase() only drops an index from the set: the value stays in `data_` and
/// the ids of all other elements are unchanged. defragment() and destroy()
/// compact `data_` and therefore renumber ids; both return the new->old id
/// mapping so callers can remap ids they hold.
///
/// @tparam T     stored value type
/// @tparam Hash  hash functor applied to values of type T
/// @tparam Equal equality functor applied to values of type T
template <
    typename T,
    typename Hash = police::hash<T>,
    typename Equal = std::equal_to<>>
class id_map {
public:
    /// Backing storage; a value's id is its index in this container.
    using container_type = segmented_vector<T>;

private:
    /// Hashes an index by hashing the value stored at that index in `data`.
    struct HashIdx {
        HashIdx(const container_type* data, Hash hash)
            : data(data)
            , hash(std::move(hash))
        {
        }

        [[nodiscard]]
        std::size_t operator()(size_t idx) const
        {
            return hash(data->at(idx));
        }

        const container_type* data;
        Hash hash;
    };

    /// Compares two indices by comparing the values stored at them in `data`.
    struct EqualIdx {
        EqualIdx(const container_type* data, Equal equal)
            : data(data)
            , equal(std::move(equal))
        {
        }

        [[nodiscard]]
        bool operator()(size_t i, size_t j) const
        {
            return equal(data->at(i), data->at(j));
        }

        const container_type* data;
        Equal equal;
    };

    using index_set_type = unordered_set<size_t, HashIdx, EqualIdx>;

    /// Minimum size hint passed to the index set's constructor.
    constexpr static size_t RESERVE_SET_SIZE = 1024;

    /// Value of fragmentation_ratio() at or above which defragment()
    /// compacts without being forced.
    constexpr static double FRAGMENTATION_TOLERANCE = 1.2;

public:
    /// Forward iterator over the live elements of the map.
    ///
    /// Dereferencing yields a `(value, id)` pair by value whose first member
    /// is a const reference into the map's storage. Iteration follows the
    /// index set's order, not ascending id order.
    class iterator {
    public:
        using value_type = std::pair<const T&, size_t>;
        using difference_type = int_t;

        /// Proxy returned by operator->. It owns the (value, id) pair so that
        /// `it->first` / `it->second` work although operator* returns a
        /// temporary.
        struct pointer_type {
            explicit pointer_type(value_type value)
                : value(std::move(value))
            {
            }

            [[nodiscard]]
            value_type* operator->()
            {
                return &value;
            }

            [[nodiscard]]
            const value_type* operator->() const
            {
                return &value;
            }

            value_type value;
        };

        iterator() = default;

        [[nodiscard]]
        bool operator==(const iterator& other) const
        {
            return it_ == other.it_;
        }

        iterator& operator++()
        {
            ++it_;
            return *this;
        }

        iterator operator++(int)
        {
            iterator temp(*this);
            ++*this;
            return temp;
        }

        [[nodiscard]]
        value_type operator*() const
        {
            return value_type(data_->at(*it_), *it_);
        }

        [[nodiscard]]
        pointer_type operator->() const
        {
            return pointer_type(**this);
        }

    private:
        friend class id_map;

        iterator(const container_type* data, index_set_type::const_iterator it)
            : data_(data)
            , it_(std::move(it))
        {
        }

        const container_type* data_ = nullptr;
        index_set_type::const_iterator it_{};
    };

    static_assert(std::forward_iterator<iterator>);

    /// Creates an empty map using the given value hash and equality functors.
    explicit id_map(Hash hash = Hash(), Equal equal = Equal())
        : idx_set_(
              RESERVE_SET_SIZE,
              HashIdx(&data_, std::move(hash)),
              EqualIdx(&data_, std::move(equal)))
    {
    }

    /// Deep copy. Ids are preserved, including the storage slots of erased
    /// values. The copy's index functors point at the copy's own storage.
    id_map(const id_map& other)
        : data_(other.data_)
        , idx_set_(
              std::max<size_t>(RESERVE_SET_SIZE, other.size()),
              HashIdx(&data_, other.idx_set_.hash_function().hash),
              EqualIdx(&data_, other.idx_set_.key_eq().equal))
    {
        idx_set_.insert(other.idx_set_.begin(), other.idx_set_.end());
    }

    /// Deleted: a moved index set would keep functors pointing at the
    /// moved-from object's `data_`.
    id_map(id_map&& other) = delete;

    [[nodiscard]]
    iterator begin() const
    {
        return {&data_, idx_set_.begin()};
    }

    [[nodiscard]]
    iterator end() const
    {
        return {&data_, idx_set_.end()};
    }

    [[nodiscard]]
    iterator cbegin() const
    {
        return {&data_, idx_set_.begin()};
    }

    [[nodiscard]]
    iterator cend() const
    {
        return {&data_, idx_set_.end()};
    }

    /// Number of live (non-erased) elements.
    [[nodiscard]]
    size_t size() const
    {
        return idx_set_.size();
    }

    [[nodiscard]]
    bool empty() const
    {
        return idx_set_.empty();
    }

    /// Inserts @p t if absent and returns its `(value, id)` pair.
    iterator::value_type operator[](T t) { return *insert(std::move(t)).first; }

    /// Removes all elements and all stored values.
    void clear()
    {
        data_.clear();
        idx_set_.clear();
    }

    /// Inserts @p t unless an equal value is already present.
    ///
    /// The value is appended to storage first so that it can be hashed and
    /// compared by index. It is popped again if an equal value already
    /// exists.
    ///
    /// @return iterator to the (new or existing) element, and whether an
    ///         insertion took place. A newly inserted value gets the id
    ///         data().size() - 1.
    std::pair<iterator, bool> insert(T t)
    {
        const size_t idx = data_.size();
        data_.push_back(std::move(t));
        auto res = idx_set_.insert(idx);
        if (!res.second) {
            data_.pop_back();
        }
        return {iterator(&data_, std::move(res.first)), res.second};
    }

    /// Constructs a T from @p args and inserts it. The temporary T is built
    /// even if an equal value is already present.
    template <typename... Args>
    std::pair<iterator, bool> emplace(Args&&... args)
    {
        return insert(T(std::forward<Args>(args)...));
    }

    /// Removes the element at @p pos from the index set. Its value remains in
    /// storage, so other ids are unaffected, until defragment() or destroy()
    /// compacts the storage.
    ///
    /// @return iterator to the element following @p pos.
    iterator erase(iterator pos)
    {
        auto new_pos = idx_set_.erase(pos.it_);
        return {&data_, std::move(new_pos)};
    }

    /// Range form of erase(iterator); the same storage caveat applies.
    iterator erase(iterator first, iterator last)
    {
        auto new_pos = idx_set_.erase(first.it_, last.it_);
        return {&data_, std::move(new_pos)};
    }

    /// Permanently removes the elements whose ids appear in [first, last),
    /// then compacts the storage and renumbers the surviving elements in
    /// ascending order of their old ids.
    ///
    /// The id range must be sorted in ascending order, because it is walked
    /// once in step with the sorted live ids. Ids that are not currently live
    /// are ignored. Values previously removed via erase() are dropped as
    /// well. Compaction happens even if no id matches.
    ///
    /// @return mapping from new id to old id: entry `i` is the former id of
    ///         the element that now has id `i`.
    template <typename Iterator, typename Sentinel>
    vector<size_t> destroy(Iterator first, Sentinel last)
    {
        vector<size_t> idxs(idx_set_.begin(), idx_set_.end());
        std::sort(idxs.begin(), idxs.end());
        idx_set_.clear();
        size_t new_idx = 0;
        // Read by value: idxs[new_idx] is overwritten below, and
        // new_idx never exceeds the current read position.
        for (const size_t old_idx : idxs) {
            while (first != last && *first < old_idx) {
                ++first;
            }
            if (first != last && *first == old_idx) {
                continue;
            }
            if (new_idx != old_idx) {
                data_[new_idx] = std::move(data_[old_idx]);
            }
            idxs[new_idx] = old_idx;
            // All set entries (< new_idx) are already in their final slots,
            // so hashing/comparing through data_ is valid here.
            idx_set_.insert(new_idx);
            ++new_idx;
        }
        data_.erase(data_.begin() + new_idx, data_.end());
        idxs.resize(new_idx);
        return idxs;
    }

    /// Exchanges the contents of two maps, including their hash/equality
    /// functors.
    ///
    /// Swapping idx_set_ exchanges its HashIdx/EqualIdx functors together with
    /// the elements, as std::unordered_set::swap does. Their container
    /// pointers would then refer to the other map's `data_`. Both index sets
    /// are therefore rebuilt against their new owner's storage, which makes
    /// this O(size()) rather than O(1). The rebuild is also correct if the
    /// set does not swap functors.
    void swap(id_map& other)
    {
        if (this == &other) {
            return;
        }
        data_.swap(other.data_);
        idx_set_.swap(other.idx_set_);
        rebind_index_set();
        other.rebind_index_set();
    }

    /// Number of live elements equal to @p t (0 or 1).
    ///
    /// Temporarily appends @p t to the mutable storage to hash it by index,
    /// so concurrent calls, even on a const map, are not safe.
    [[nodiscard]]
    size_t count(T t) const
    {
        const size_t temp_idx = data_.size();
        data_.push_back(std::move(t));
        const size_t result = idx_set_.count(temp_idx);
        data_.pop_back();
        return result;
    }

    /// Whether a value equal to @p t is live in the map (see count()).
    [[nodiscard]]
    bool contains(T t) const
    {
        return count(std::move(t)) > 0;
    }

    /// Looks up @p t; returns end() if absent. Uses the same temporary-append
    /// technique as count() and is likewise unsafe to call concurrently.
    [[nodiscard]]
    iterator find(T t) const
    {
        const size_t temp_idx = data_.size();
        data_.push_back(std::move(t));
        auto pos = idx_set_.find(temp_idx);
        data_.pop_back();
        return {&data_, std::move(pos)};
    }

    /// Raw storage, indexed by id. Also contains values of erased elements
    /// until the next defragment()/destroy().
    [[nodiscard]]
    const container_type* data() const
    {
        return &data_;
    }

    /// Reserves capacity for @p count values in both storage and index set.
    void reserve(size_t count)
    {
        data_.reserve(count);
        idx_set_.reserve(count);
    }

    /// Compacts storage by dropping the values of erased elements and
    /// renumbering the live ones in ascending order of their old ids.
    ///
    /// Does nothing when there are no holes, or when @p force is false and
    /// fragmentation_ratio() is below fragmentation_tolerance().
    ///
    /// @return mapping from new id to old id (same format as destroy()), or
    ///         std::nullopt if no compaction took place.
    std::optional<vector<size_t>> defragment(bool force = false)
    {
        if (data_.size() == idx_set_.size() ||
            (!force && fragmentation_ratio() < fragmentation_tolerance())) {
            return std::nullopt;
        }
        vector<size_t> idxs(idx_set_.begin(), idx_set_.end());
        std::sort(idxs.begin(), idxs.end());
        idx_set_.clear();
        size_t new_idx = 0;
        for (const size_t old_idx : idxs) {
            if (new_idx != old_idx) {
                data_[new_idx] = std::move(data_[old_idx]);
            }
            idx_set_.insert(new_idx);
            ++new_idx;
        }
        data_.erase(data_.begin() + new_idx, data_.end());
        // idxs is sorted, so idxs[new] == old for every renumbered element.
        return {std::move(idxs)};
    }

    /// Ratio of stored values (live + erased) to live elements. 1.0 means no
    /// holes. An entirely empty map reports 1.0 rather than 0/0 = NaN. A map
    /// whose stored values are all erased reports +infinity.
    [[nodiscard]]
    double fragmentation_ratio() const
    {
        // data_ always holds at least one value per live id, so an empty
        // data_ implies an empty map.
        if (data_.size() == 0) {
            return 1.0;
        }
        if (idx_set_.empty()) {
            return std::numeric_limits<double>::infinity();
        }
        return static_cast<double>(data_.size()) /
               static_cast<double>(idx_set_.size());
    }

    /// Threshold used by defragment() when not forced.
    [[nodiscard]]
    double fragmentation_tolerance() const
    {
        return FRAGMENTATION_TOLERANCE;
    }

private:
    /// Rebuilds idx_set_ with functors that reference this object's `data_`.
    /// The user Hash/Equal are carried over from the current set.
    void rebind_index_set()
    {
        index_set_type rebound(
            std::max<size_t>(RESERVE_SET_SIZE, idx_set_.size()),
            HashIdx(&data_, idx_set_.hash_function().hash),
            EqualIdx(&data_, idx_set_.key_eq().equal));
        rebound.insert(idx_set_.begin(), idx_set_.end());
        // rebound's functors point at &data_; swapping hands them to idx_set_.
        idx_set_.swap(rebound);
    }

    /// Storage indexed by id. It is mutable so that const lookups can
    /// temporarily append a probe value.
    mutable container_type data_;
    /// Live ids; hashed and compared through data_.
    index_set_type idx_set_;
};

} // namespace police
