#pragma once

#include "police/base_types.hpp"
#include "police/macros.hpp"
#include "police/model.hpp"
#include "police/storage/flat_state.hpp"
#include "police/storage/unordered_set.hpp"
#include "police/storage/variable_space.hpp"
#include "police/storage/vector.hpp"
#include "police/utils/hash.hpp"
#include "police/verifiers/ic3/cube.hpp"

#include <algorithm>
#include <cassert>
#include <limits>
#include <utility>

// File-local debugging switches. Both macros are #undef'd again at the end of
// this header, so they do not leak into including translation units.
//
// EXPENSIVE_ASSERT simply forwards to POLICE_ASSERT.
#define EXPENSIVE_ASSERT(x) POLICE_ASSERT(x)
// NOTE(review): this is defined unconditionally, so every
// `#ifdef EXPENSIVE_CHECKS` section below is always compiled in. Whether the
// checks cost anything in optimized builds depends on how POLICE_ASSERT
// expands -- confirm this is intended outside of debugging sessions.
#define EXPENSIVE_CHECKS 1

namespace police::ic3::syntactic {

class FramesStorage {
    struct CubeIdHasher {
        explicit CubeIdHasher(const std::vector<Cube>* cubes)
            : cubes(cubes)
        {
        }

        [[nodiscard]]
        std::size_t operator()(size_t id) const
        {
            return police::get_hash(cubes->at(id));
        }

        const std::vector<Cube>* cubes;
    };

    /// Equality functor for the id set: two ids are equal iff the cubes they
    /// refer to compare equal.
    struct CubeIdEqual {
        explicit CubeIdEqual(const std::vector<Cube>* cubes)
            : cubes(cubes)
        {
        }

        /// Compares the two referenced cubes (both looked up bounds-checked).
        /// Returns bool: this is a predicate, not a size.
        [[nodiscard]]
        bool operator()(size_t i, size_t j) const
        {
            return cubes->at(i) == cubes->at(j);
        }

        /// Non-owning pointer to the cube vector used to resolve ids.
        const std::vector<Cube>* cubes;
    };

    /// Index from a bound value to the ids of the cubes using that value.
    ///
    /// `cube_refs` holds (value, ids) entries. insert() places new entries at
    /// their lower_bound position, so entries stay ordered by ascending value
    /// as long as they started out ordered; ids inside one entry are expected
    /// to be appended in increasing order (asserted in insert()).
    struct ValueInfo {
        using container = std::vector<std::pair<int, std::vector<size_t>>>;
        using const_iterator = container::const_iterator;
        using entry = container::value_type;

        container cube_refs;

        /// Ordering of an entry against a plain value (for lower_bound).
        static bool key_below(const entry& e, int value)
        {
            return e.first < value;
        }

        /// Ordering of a plain value against an entry (for upper_bound).
        static bool value_below_key(int value, const entry& e)
        {
            return value < e.first;
        }

        void clear() { cube_refs.clear(); }

        /// Rewrites every stored id through `id_map`. Ids that map to
        /// FramesStorage::DELETED are dropped, and entries that end up
        /// without any id are removed. Relative order is preserved.
        void replace_ids(const std::vector<size_t>& id_map)
        {
            auto keep = cube_refs.begin();
            for (auto cur = cube_refs.begin(); cur != cube_refs.end(); ++cur) {
                auto& ids = cur->second;
                // Compact in place: `out` never runs ahead of the element
                // currently being read.
                auto out = ids.begin();
                for (const size_t old_id : ids) {
                    const size_t mapped = id_map[old_id];
                    if (mapped != FramesStorage::DELETED) {
                        *out = mapped;
                        ++out;
                    }
                }
                ids.erase(out, ids.end());
                if (ids.empty()) {
                    continue;
                }
                if (keep != cur) {
                    *keep = std::move(*cur);
                }
                ++keep;
            }
            cube_refs.erase(keep, cube_refs.end());
        }

        /// First entry whose value is >= `value` (mutable iterator).
        [[nodiscard]]
        auto lb(int value)
        {
            return std::lower_bound(
                cube_refs.begin(),
                cube_refs.end(),
                value,
                key_below);
        }

        /// First entry whose value is > `value` (mutable iterator).
        [[nodiscard]]
        auto ub(int value)
        {
            return std::upper_bound(
                cube_refs.begin(),
                cube_refs.end(),
                value,
                value_below_key);
        }

        [[nodiscard]]
        auto rend()
        {
            return cube_refs.rend();
        }

        [[nodiscard]]
        auto rbegin()
        {
            return cube_refs.rbegin();
        }

        // Note: the non-const begin()/end() deliberately hand out
        // const_iterators, exactly like their const counterparts.
        [[nodiscard]]
        const_iterator end()
        {
            return cube_refs.end();
        }

        [[nodiscard]]
        const_iterator begin()
        {
            return cube_refs.begin();
        }

        /// First entry whose value is >= `value`.
        [[nodiscard]]
        const_iterator lb(int value) const
        {
            return std::lower_bound(
                cube_refs.begin(),
                cube_refs.end(),
                value,
                key_below);
        }

        /// First entry whose value is > `value`.
        [[nodiscard]]
        const_iterator ub(int value) const
        {
            return std::upper_bound(
                cube_refs.begin(),
                cube_refs.end(),
                value,
                value_below_key);
        }

        [[nodiscard]]
        auto rend() const
        {
            return cube_refs.rend();
        }

        [[nodiscard]]
        auto rbegin() const
        {
            return cube_refs.rbegin();
        }

        [[nodiscard]]
        const_iterator end() const
        {
            return cube_refs.end();
        }

        [[nodiscard]]
        const_iterator begin() const
        {
            return cube_refs.begin();
        }

        /// Removes `cube_id` from the entry found for `value`; the entry
        /// itself is erased once it holds no more ids. The id is located by
        /// binary search, so the id list must be sorted.
        void remove(int value, size_t cube_id)
        {
            const auto slot = lb(value);
            POLICE_ASSERT(slot != end());
            auto& ids = slot->second;
            const auto hit = std::lower_bound(ids.begin(), ids.end(), cube_id);
            POLICE_ASSERT(hit != ids.end() && *hit == cube_id);
            ids.erase(hit);
            if (ids.empty()) {
                cube_refs.erase(slot);
            }
        }

        /// Registers `cube_id` under `value`, creating the entry if needed.
        /// For an existing entry the id is appended, so it must be larger
        /// than every id already stored there (asserted).
        void insert(int value, size_t cube_id)
        {
            const auto slot = lb(value);
            const bool known_value = slot != end() && slot->first == value;
            if (known_value) {
                POLICE_ASSERT(!slot->second.empty());
                POLICE_ASSERT(cube_id > slot->second.back());
                slot->second.push_back(cube_id);
                return;
            }
            cube_refs.insert(
                slot,
                entry(value, std::vector<size_t>(1, cube_id)));
        }
    };

    /// Per-variable index: one ValueInfo per kind of bound.
    struct VarInfo {
        // Lower-bound value -> ids of cubes with that lower bound (filled in
        // FramesStorage::insert from vals.lb).
        ValueInfo lb_to_cube;
        // Upper-bound value -> ids of cubes with that upper bound (filled in
        // FramesStorage::insert from vals.ub).
        ValueInfo ub_to_cube;
    };

public:
    /// Flag value with only the most significant bit of size_t set.
    /// It is OR-ed into sizes_/left_ by remove() to tag removed cubes and is
    /// used as the "no new id" marker in defragment()'s id map.
    /// The shift is performed on a size_t operand: shifting the int literal
    /// `1` by (width of size_t - 1) would exceed the width of int, which is
    /// undefined behaviour and not a valid constant expression.
    constexpr static size_t DELETED =
        size_t{1} << (std::numeric_limits<size_t>::digits - 1);

    /// Closed interval [begin_frame, end_frame] of frame indices.
    /// A default-constructed Range accepts every frame.
    struct Range {
        size_t begin_frame = 0;
        size_t end_frame = std::numeric_limits<size_t>::max();

        /// True iff `f` lies inside the interval (both ends inclusive).
        [[nodiscard]]
        bool contains(size_t f) const
        {
            if (f < begin_frame) {
                return false;
            }
            return !(end_frame < f);
        }
    };

    template <typename Callback>
    struct _RangeConditionedCallback {
        _RangeConditionedCallback(
            const vector<size_t>* frame,
            Callback callback,
            Range range)
            : frame(frame)
            , callback(std::move(callback))
            , range(std::move(range))
        {
        }

        void operator()(size_t cube_id)
        {
            POLICE_ASSERT(frame->size() > cube_id);
            if (range.contains(frame->at(cube_id))) {
                callback(cube_id);
            }
        }

        const vector<size_t>* frame;
        Callback callback;
        Range range;
    };

    /// Creates an empty storage with one (empty) index per variable.
    ///
    /// @param num_vars     number of variables; sizes the per-variable index.
    /// @param frame_index  optional non-owning pointer to the cube-id -> frame
    ///                     table; it is dereferenced by get_frame(),
    ///                     get_frame_rejection_cube() and the Range-filtered
    ///                     queries, so it has to be set before those are used.
    explicit FramesStorage(
        size_t num_vars,
        const vector<size_t>* frame_index = nullptr)
        : var_infos_(num_vars)
        // The hasher and the equality functor both keep the address of
        // cubes_ so that ids can be compared through the cubes they denote.
        // The first argument (10240) is presumably an initial bucket count --
        // confirm against police::unordered_set.
        , cube_ids_(10240, CubeIdHasher(&cubes_), CubeIdEqual(&cubes_))
        , frame_index_(frame_index)
    {
    }

    // NOTE(review): copying/moving is disabled; presumably because cube_ids_
    // stores pointers to this object's cubes_ (see above), which a copied or
    // moved set would keep referring to.
    FramesStorage(const FramesStorage&) = delete;

    FramesStorage(FramesStorage&&) = delete;

    /// Replaces the (non-owning) pointer to the cube-id -> frame table used
    /// by the frame-aware queries. The pointed-to vector is not copied.
    void set_frame_indexes(const vector<size_t>* frames)
    {
        frame_index_ = frames;
    }

    /// Stores `cube` unless an equal cube is already present.
    ///
    /// @return {id, true} if the cube was added under the new id `id`;
    ///         {id of the already stored equal cube, false} otherwise.
    std::pair<size_t, bool> insert(Cube cube)
    {
        // The id set resolves ids through cubes_, so the candidate has to be
        // appended before the duplicate lookup and rolled back on a hit.
        const size_t candidate_id = cubes_.size();
        cubes_.push_back(std::move(cube));
        const auto lookup = cube_ids_.insert(candidate_id);
        if (!lookup.second) {
            cubes_.pop_back();
            return {*lookup.first, false};
        }

        // Register every bound of the new cube in the per-variable index and
        // remember how many bounds the cube has.
        size_t num_bounds = 0;
        for (const auto& [var, vals] : cubes_.back()) {
            POLICE_ASSERT(var < var_infos_.size());
            auto& index = var_infos_[var];
            if (vals.has_lb()) {
                index.lb_to_cube.insert(
                    static_cast<int>(vals.lb),
                    candidate_id);
                ++num_bounds;
            }
            if (vals.has_ub()) {
                index.ub_to_cube.insert(
                    static_cast<int>(vals.ub),
                    candidate_id);
                ++num_bounds;
            }
        }
        POLICE_ASSERT(num_bounds > 0);
        left_.push_back(0);
        sizes_.push_back(num_bounds);
        return {candidate_id, true};
    }

    /// Marks the cube with id `cube_id` as removed (orphaned).
    /// The cube is taken out of the id set and tagged with DELETED, but its
    /// slot in cubes_ and its entries in the per-variable index stay in place
    /// until clean_orphan_references() or defragment() is run.
    void remove(size_t cube_id)
    {
        const auto entry = cube_ids_.find(cube_id);
        POLICE_ASSERT(entry != cube_ids_.end());
        cube_ids_.erase(entry);
        sizes_[cube_id] |= DELETED;
        left_[cube_id] |= DELETED;
        orphans_.push_back(cube_id);
    }

    /// Number of ids currently recorded in orphans_, i.e. cubes passed to
    /// remove() since orphans_ was last cleared.
    [[nodiscard]]
    size_t num_orphans() const
    {
        return orphans_.size();
    }

    /// Number of cube slots in cubes_. Slots of removed cubes are still
    /// counted until defragment() compacts the storage.
    [[nodiscard]]
    size_t num_cubes() const
    {
        return cubes_.size();
    }

    [[nodiscard]]
    bool is_fragmented() const
    {
        return !orphans_.empty();
    }

    /// Unchecked access to the cube stored under `id`.
    [[nodiscard]]
    const Cube& operator[](size_t id) const
    {
        return cubes_[id];
    }

    /// Same as operator[]: returns the cube stored under `id`.
    /// Note that, despite the name, no bounds check is performed.
    [[nodiscard]]
    const Cube& at(size_t id) const
    {
        return (*this)[id];
    }

    /// Frame recorded for cube `id` in the external frame table
    /// (bounds-checked lookup). The frame table pointer must be set.
    [[nodiscard]]
    size_t get_frame(size_t id) const
    {
        const auto& frames = *frame_index_;
        return frames.at(id);
    }

    /// True iff cube `id` has been removed, i.e. its size entry carries the
    /// DELETED flag.
    [[nodiscard]]
    bool is_orphaned(size_t id) const
    {
        return (sizes_[id] & DELETED) != 0;
    }

    /// Drops all cubes and all bookkeeping. The number of variables (and
    /// thereby the size of the per-variable index) and the frame table
    /// pointer are left untouched.
    void clear()
    {
        for (auto& index : var_infos_) {
            index.lb_to_cube.clear();
            index.ub_to_cube.clear();
        }
        orphans_.clear();
        left_.clear();
        sizes_.clear();
        cube_ids_.clear();
        cubes_.clear();
    }

    /// Compacts the storage by dropping the slots of all removed cubes.
    ///
    /// Surviving cubes keep their relative order and are renumbered
    /// 0, 1, 2, ...; the id set and the per-variable index are rebuilt or
    /// remapped accordingly, and index entries of removed cubes disappear.
    ///
    /// @return for every new id (position in the returned vector) the id the
    ///         cube had before the call.
    vector<size_t> defragment()
    {
        vector<size_t> old_id_list;
        old_id_list.reserve(cubes_.size());
        // old id -> new id; DELETED marks ids that have no successor.
        vector<size_t> id_remap(cubes_.size(), DELETED);
        size_t new_id = 0;
        for (size_t old_id = 0; old_id < cubes_.size(); ++old_id) {
            if (!is_orphaned(old_id)) {
                id_remap[old_id] = new_id;
                if (old_id != new_id) {
                    cubes_[new_id] = std::move(cubes_[old_id]);
                    sizes_[new_id] = sizes_[old_id];
                }
                old_id_list.push_back(old_id);
                ++new_id;
            }
        }
        cubes_.erase(cubes_.begin() + new_id, cubes_.end());
        sizes_.erase(sizes_.begin() + new_id, sizes_.end());
        // left_ is scratch space; only its length has to match cubes_.
        left_.erase(left_.begin() + new_id, left_.end());
        // Ids changed, so the id set has to be rebuilt from scratch.
        cube_ids_.clear();
        for (size_t i = 0; i < cubes_.size(); ++i) {
            cube_ids_.insert(i);
        }
        for (size_t var = 0; var < var_infos_.size(); ++var) {
            auto& var_info = var_infos_[var];
            var_info.lb_to_cube.replace_ids(id_remap);
            var_info.ub_to_cube.replace_ids(id_remap);
        }
        // No removed cube is left, and the recorded orphan ids refer to the
        // old numbering. Keeping them would leave is_fragmented() stuck at
        // true and make a later clean_orphan_references() touch the wrong
        // (or non-existent) cubes.
        orphans_.clear();
        return old_id_list;
    }

    void clean_orphan_references()
    {
        for (size_t i : orphans_) {
            for (const auto& [var, vals] : cubes_[i]) {
                auto& var_info = var_infos_[var];
                if (vals.has_lb()) {
                    var_info.lb_to_cube.remove(static_cast<int>(vals.lb), i);
                }
                if (vals.has_ub()) {
                    var_info.ub_to_cube.remove(static_cast<int>(vals.ub), i);
                }
            }
        }
        orphans_.clear();
    }

    /// Among all cubes reported by forall_subsumed(state, ...), selects the
    /// one with the highest frame; ties are broken in favour of the smaller
    /// cube id.
    ///
    /// @return {cube id, frame} of the selected cube, or
    ///         {static_cast<size_t>(-1), 0} if no cube was reported.
    [[nodiscard]]
    std::pair<size_t, size_t>
    get_frame_rejection_cube(const flat_state& state) const
    {
        size_t best_id = static_cast<size_t>(-1);
        size_t best_frame = 0;
        const auto& frames = *frame_index_;
        forall_subsumed(state, [&](size_t cube_id) {
            POLICE_ASSERT(cube_id < frames.size());
            const size_t cube_frame = frames[cube_id];
            const bool later_frame = cube_frame > best_frame;
            const bool same_frame_smaller_id =
                cube_frame == best_frame && cube_id < best_id;
            if (later_frame || same_frame_smaller_id) {
                best_id = cube_id;
                best_frame = cube_frame;
            }
        });
        return std::pair<size_t, size_t>(best_id, best_frame);
    }

    template <typename Callback>
    void forall_subsumed(
        const flat_state& state,
        Range frame_range,
        Callback callback) const
    {
        forall_subsumed(
            state,
            _RangeConditionedCallback(
                frame_index_,
                std::move(callback),
                std::move(frame_range)));
    }

    /// Invokes callback(id) for every live cube that contains `state`.
    ///
    /// Every cube starts with a counter equal to its number of bounds; each
    /// bound that `state` satisfies decrements the counter, and the cube is
    /// reported when the counter reaches zero. Removed cubes carry the
    /// DELETED flag in their counter and therefore never reach zero.
    /// The counters live in the shared scratch vector left_, so the callback
    /// must not start another query on this object.
    template <typename Callback>
    void forall_subsumed(const flat_state& state, Callback callback) const
    {
#ifdef EXPENSIVE_CHECKS
        vector<bool> is_subsumed(cubes_.size(), false);
#endif
        POLICE_ASSERT(state.size() >= var_infos_.size());
        left_.assign(sizes_.begin(), sizes_.end());

        // Accounts for one satisfied bound of `cube_id`; reports the cube
        // once all of its bounds are satisfied.
        const auto satisfy_bound = [&](size_t cube_id) {
            if (--left_[cube_id] != 0) {
                return;
            }
            EXPENSIVE_ASSERT(cubes_[cube_id].contains(state));
            callback(cube_id);
#ifdef EXPENSIVE_CHECKS
            is_subsumed[cube_id] = true;
#endif
        };

        for (size_t var = 0; var < var_infos_.size(); ++var) {
            const auto& index = var_infos_[var];
            const int val = static_cast<int>(state[var]);
            // Lower bounds in ascending order, as long as they are <= val.
            for (auto entry = index.lb_to_cube.begin();
                 entry != index.lb_to_cube.end();
                 ++entry) {
                if (entry->first > val) {
                    break;
                }
                for (const auto& cube_id : entry->second) {
                    satisfy_bound(cube_id);
                }
            }
            // Upper bounds in descending order, as long as they are >= val.
            for (auto entry = index.ub_to_cube.rbegin();
                 entry != index.ub_to_cube.rend();
                 ++entry) {
                if (entry->first < val) {
                    break;
                }
                for (const auto& cube_id : entry->second) {
                    satisfy_bound(cube_id);
                }
            }
        }
#ifdef EXPENSIVE_CHECKS
        // Cross-check the reported set against a brute-force scan.
        for (size_t i = 0; i < is_subsumed.size(); ++i) {
            POLICE_ASSERT(
                is_orphaned(i) ||
                is_subsumed[i] == cubes_[i].contains(state));
        }
#endif
    }

    template <typename Callback>
    void forall_subsumed(const Cube& cube, Range frame_range, Callback callback)
    {
        forall_subsumed(
            cube,
            _RangeConditionedCallback(
                frame_index_,
                std::move(callback),
                std::move(frame_range)));
    }

    template <typename Callback>
    void forall_subsumed(const Cube& cube, Callback callback) const
    {
        if (cube.empty()) {
            return;
        }
#ifdef EXPENSIVE_CHECKS
        vector<bool> is_subsumed(cubes_.size(), false);
#endif
        left_.assign(sizes_.begin(), sizes_.end());
        for (auto it = cube.begin(); it != cube.end(); ++it) {
            const auto& var = it->first;
            const auto& vals = it->second;
            POLICE_ASSERT(var < var_infos_.size());
            const auto& var_info = var_infos_[var];
            if (vals.has_lb()) {
                const int lb = static_cast<int>(vals.lb);
                for (auto ref_it = var_info.lb_to_cube.begin();
                     ref_it != var_info.lb_to_cube.end() && ref_it->first <= lb;
                     ++ref_it) {
                    for (const auto& cube_id : ref_it->second) {
                        if (--left_[cube_id] == 0) {
                            EXPENSIVE_ASSERT(cube <= cubes_[cube_id]);
                            callback(cube_id);
#ifdef EXPENSIVE_CHECKS
                            is_subsumed[cube_id] = true;
#endif
                        }
                    }
                }
            }
            if (vals.has_ub()) {
                const int ub = static_cast<int>(vals.ub);
                for (auto ref_it = var_info.ub_to_cube.rbegin();
                     ref_it != var_info.ub_to_cube.rend() &&
                     ref_it->first >= ub;
                     ++ref_it) {
                    for (const auto& cube_id : ref_it->second) {
                        if (--left_[cube_id] == 0) {
                            EXPENSIVE_ASSERT(cube <= cubes_[cube_id]);
                            callback(cube_id);
#ifdef EXPENSIVE_CHECKS
                            is_subsumed[cube_id] = true;
#endif
                        }
                    }
                }
            }
        }
#ifdef EXPENSIVE_CHECKS
        for (size_t i = 0; i < is_subsumed.size(); ++i) {
            POLICE_ASSERT(is_orphaned(i) || is_subsumed[i] == cubes_[i] >= cube);
        }
#endif
    }

    template <typename Callback>
    void
    forall_that_subsume(const Cube& cube, Range frame_range, Callback callback)
    {
        forall_that_subsume(
            cube,
            _RangeConditionedCallback(
                frame_index_,
                std::move(callback),
                std::move(frame_range)));
    }

    template <typename Callback>
    void forall_that_subsume(const Cube& cube, Callback callback) const
    {
        if (cube.empty()) {
            for (size_t cube_id = 0; cube_id < num_cubes(); ++cube_id) {
                if (!is_orphaned(cube_id)) {
                    callback(cube_id);
                }
            }
            return;
        }
#ifdef EXPENSIVE_CHECKS
        vector<bool> subsume(cubes_.size(), false);
#endif
        size_t n = 0;
        std::fill(left_.begin(), left_.end(), 0);
        auto it = cube.begin();
        for (; it + 1 != cube.end(); ++it) {
            const auto& var = it->first;
            const auto& vals = it->second;
            POLICE_ASSERT(var < var_infos_.size());
            const auto& var_info = var_infos_[var];
            POLICE_ASSERT(vals.has_ub() || vals.has_lb());
            if (vals.has_lb() && vals.has_ub()) {
                const int lb = static_cast<int>(vals.lb);
                const int ub = static_cast<int>(vals.ub);
                for (auto ref_it = var_info.lb_to_cube.lb(lb);
                     ref_it != var_info.lb_to_cube.end() && ref_it->first <= ub;
                     ++ref_it) {
                    for (const auto& cube_id : ref_it->second) {
                        ++left_[cube_id];
                    }
                }
                for (auto ref_it = var_info.ub_to_cube.lb(lb);
                     ref_it != var_info.ub_to_cube.end() && ref_it->first <= ub;
                     ++ref_it) {
                    for (const auto& cube_id : ref_it->second) {
                        ++left_[cube_id];
                    }
                }
                n += 2;
            } else if (vals.has_lb()) {
                const int lb = static_cast<int>(vals.lb);
                for (auto ref_it = var_info.lb_to_cube.lb(lb);
                     ref_it != var_info.lb_to_cube.end();
                     ++ref_it) {
                    for (const auto& cube_id : ref_it->second) {
                        ++left_[cube_id];
                    }
                }
                ++n;
            } else if (vals.has_ub()) {
                const int ub = static_cast<int>(vals.ub);
                for (auto ref_it = var_info.ub_to_cube.begin();
                     ref_it != var_info.ub_to_cube.end() && ref_it->first <= ub;
                     ++ref_it) {
                    for (const auto& cube_id : ref_it->second) {
                        ++left_[cube_id];
                    }
                }
                ++n;
            }
        }
        const auto& var = it->first;
        const auto& vals = it->second;
        POLICE_ASSERT(var < var_infos_.size());
        const auto& var_info = var_infos_[var];
        POLICE_ASSERT(vals.has_ub() || vals.has_lb());
        if (vals.has_lb() && vals.has_ub()) {
            const int lb = static_cast<int>(vals.lb);
            const int ub = static_cast<int>(vals.ub);
            for (auto ref_it = var_info.lb_to_cube.lb(lb);
                 ref_it != var_info.lb_to_cube.end() && ref_it->first <= ub;
                 ++ref_it) {
                for (const auto& cube_id : ref_it->second) {
                    ++left_[cube_id];
                }
            }
            n += 1;
            for (auto ref_it = var_info.ub_to_cube.lb(lb);
                 ref_it != var_info.ub_to_cube.end() && ref_it->first <= ub;
                 ++ref_it) {
                for (const auto& cube_id : ref_it->second) {
                    if (left_[cube_id] == n) {
                        EXPENSIVE_ASSERT(cube >= cubes_[cube_id]);
                        callback(cube_id);
#ifdef EXPENSIVE_CHECKS
                        subsume[cube_id] = true;
#endif
                    }
                }
            }
        } else if (vals.has_lb()) {
            const int lb = static_cast<int>(vals.lb);
            for (auto ref_it = var_info.lb_to_cube.lb(lb);
                 ref_it != var_info.lb_to_cube.end();
                 ++ref_it) {
                for (const auto& cube_id : ref_it->second) {
                    if (left_[cube_id] == n) {
                        EXPENSIVE_ASSERT(cube >= cubes_[cube_id]);
                        callback(cube_id);
#ifdef EXPENSIVE_CHECKS
                        subsume[cube_id] = true;
#endif
                    }
                }
            }
        } else if (vals.has_ub()) {
            const int ub = static_cast<int>(vals.ub);
            for (auto ref_it = var_info.ub_to_cube.begin();
                 ref_it != var_info.ub_to_cube.end() && ref_it->first <= ub;
                 ++ref_it) {
                for (const auto& cube_id : ref_it->second) {
                    if (left_[cube_id] == n) {
                        EXPENSIVE_ASSERT(cube >= cubes_[cube_id]);
                        callback(cube_id);
#ifdef EXPENSIVE_CHECKS
                        subsume[cube_id] = true;
#endif
                    }
                }
            }
        }
#ifdef EXPENSIVE_CHECKS
        for (size_t i = 0; i < subsume.size(); ++i) {
            POLICE_ASSERT(is_orphaned(i) || subsume[i] == cubes_[i] <= cube);
        }
#endif
    }

    /// Like forall_cubes(callback), but `callback` is only invoked for cubes
    /// whose frame lies inside `frame_range`.
    template <typename Callback>
    void forall_cubes(Range frame_range, Callback callback)
    {
        using Filtered = _RangeConditionedCallback<Callback>;
        forall_cubes(Filtered(frame_index_, std::move(callback), frame_range));
    }

    /// Invokes callback(id) for every cube that has not been removed, in
    /// ascending id order.
    template <typename Callback>
    void forall_cubes(Callback callback)
    {
        for (size_t id = 0; id < cubes_.size(); ++id) {
            if (is_orphaned(id)) {
                continue;
            }
            callback(id);
        }
    }

    // Output helpers. They are only declared here; their definitions are not
    // part of this header.
    // NOTE(review): judging by the names, dump_frames presumably writes a
    // human-readable listing and the json overloads a JSON rendering, either
    // to a stream or to the file at `path` -- confirm against the
    // implementation.
    void dump_frames(std::ostream& out, const VariableSpace& variables) const;
    void json(std::ostream& out, const VariableSpace& variables) const;
    void json(std::string_view path, const VariableSpace& variables) const;
    void json(std::ostream& out, const Model& model) const;
    void json(std::string_view path, const Model& model) const;

private:
    // Per-variable index (one entry per variable, sized in the constructor):
    // maps bound values to the ids of the cubes having that bound.
    vector<VarInfo> var_infos_;
    // Cube storage; a cube's id is its position in this vector.
    // Must be declared before cube_ids_, whose functors point to it.
    vector<Cube> cubes_;
    // Ids of the live cubes, hashed and compared through the cubes they refer
    // to; used by insert() to detect duplicates.
    unordered_set<size_t, CubeIdHasher, CubeIdEqual> cube_ids_;
    // Number of bounds of each cube (as counted in insert()); the DELETED bit
    // is set for removed cubes.
    vector<size_t> sizes_;
    // Scratch counters for the forall_* queries; mutable so that the const
    // queries can use them. Kept at the same length as cubes_.
    mutable vector<size_t> left_;
    // Ids passed to remove() since the list was last cleared.
    vector<size_t> orphans_;
    // Non-owning pointer to the cube-id -> frame table (may be null until
    // set via the constructor or set_frame_indexes()).
    const vector<size_t>* frame_index_;
};

} // namespace police::ic3::syntactic

#undef EXPENSIVE_ASSERT
#undef EXPENSIVE_CHECKS
