#include "police/verifiers/ic3/frame.hpp"
#include "police/base_types.hpp"
#include "police/storage/flat_state.hpp"
#include <algorithm>
#include <cstdint>

namespace police::ic3 {

// Builds a cache over the cubes held in `cube_storage`. One row is created
// per cube that exists right now; operator() grows the rows on demand.
// The storage is only referenced (raw pointer), so it must outlive the cache.
CubeDatabase::SubsumptionCache::SubsumptionCache(
    const id_map<Cube>::container_type* cube_storage)
    : cache(cube_storage->size())
    , cubes(cube_storage)
{
}

// Returns whether cube `i` subsumes cube `j`, i.e. cubes[i] <= cubes[j].
// The verdict for an unordered pair {lo, hi} (lo < hi) is computed once and
// stored at cache[lo][hi - lo - 1].
[[nodiscard]]
bool CubeDatabase::SubsumptionCache::operator()(size_t i, size_t j) const
{
    assert(i < cubes->size() && j < cubes->size());
    assert(i != j);

    // Only the ordered pair (lo, hi) is stored. The argument order decides
    // which stored verdict counts as "i subsumes j".
    const size_t lo = std::min(i, j);
    const size_t hi = std::max(i, j);
    std::uint8_t wanted = I_SUBSUMES_J;
    if (i > j) {
        wanted = J_SUBSUMES_I;
    }

    // Lazily grow the table. Add rows for cubes added since the last resize,
    // then extend row `lo` to cover every current partner.
    if (lo >= cache.size()) {
        cache.resize(cubes->size());
    }
    auto& row = cache[lo];
    const size_t offset = hi - lo - 1;
    if (offset >= row.size()) {
        row.resize(cubes->size() - lo - 1, NOT_CACHED);
    }

    auto& verdict = row[offset];
    if (verdict == NOT_CACHED) {
        const auto& lo_cube = cubes->at(lo);
        const auto& hi_cube = cubes->at(hi);
        assert(lo_cube != hi_cube);
        if (lo_cube <= hi_cube) {
            verdict = I_SUBSUMES_J;
        } else if (hi_cube <= lo_cube) {
            verdict = J_SUBSUMES_I;
        } else {
            verdict = INCOMPARABLE;
        }
    }
    return verdict == wanted;
}

// Re-indexes the cache after cubes were destroyed and the storage compacted.
//
// idxs[i] is the pre-compaction index of the cube that now has index i. This
// is the same mapping scrub_frame_ids() applies to frame_idx_. The pair
// i < j is stored at cache[i][j - i - 1].
// - Row old_i is moved to row i.
// - Inside a non-empty row, the entries of surviving partners are shifted
//   left to their new offsets.
// - The row is then sized to its new length, with missing entries set to
//   NOT_CACHED. Empty rows stay empty; operator() fills them lazily.
// Afterwards the table has exactly idxs.size() rows.
//
// NOTE(review): the in-place shifting is only correct if idxs is strictly
// increasing, so a new position never overtakes an unread old one.
// collect_garbage() passes the result of id_map::destroy(); confirm that
// destroy() guarantees this ordering.
void CubeDatabase::SubsumptionCache::scrub(const vector<size_t>& idxs)
{
    for (auto i = 0u; i < idxs.size(); ++i) {
        const size_t old_i = idxs[i];
        if (old_i >= cache.size()) {
            // No row exists for this or any later (larger) old index. Drop the
            // remaining rows; the final resize below re-adds them empty.
            cache.resize(i);
            break;
        }
        auto& i_cache = cache[old_i];
        if (!i_cache.empty()) {
            for (auto j = i + 1; j < idxs.size(); ++j) {
                const size_t new_j = j - i - 1;
                const size_t old_j = idxs[j] - old_i - 1;
                if (old_j >= i_cache.size()) {
                    // Nothing was cached for this or any later partner. Cut the
                    // row here; the resize below pads it with NOT_CACHED.
                    i_cache.resize(j - i - 1);
                    break;
                }
                // new_j <= old_j, so this never overwrites an unread entry.
                i_cache[new_j] = i_cache[old_j];
            }
            i_cache.resize(idxs.size() - i - 1, NOT_CACHED);
        }
        if (i != old_i) {
            // The old contents of row i belong to an old index that was either
            // already moved in an earlier iteration or destroyed.
            cache[i] = std::move(i_cache);
        }
    }
    cache.resize(idxs.size());
}

CubeDatabase::CubeDatabase()
    : cubes_(std::make_unique<id_map<Cube>>())
    , subsumes_(cubes_->data())
{
}

void CubeDatabase::clear()
{
    frame_idx_.clear();
    dangling_cubes_.clear();
    cubes_.reset(new id_map<Cube>());
    subsumes_ = SubsumptionCache(cubes_->data());
}

// Iterator to the first entry of the underlying cube map.
CubeDatabase::iterator CubeDatabase::begin() const
{
    auto& storage = *cubes_;
    return storage.begin();
}

// Past-the-end iterator of the underlying cube map.
CubeDatabase::iterator CubeDatabase::end() const
{
    auto& storage = *cubes_;
    return storage.end();
}

// Number of entries in the underlying cube map.
size_t CubeDatabase::size() const
{
    auto& storage = *cubes_;
    return storage.size();
}

// Stores `cube` in the map and returns the map's (iterator, inserted) result.
// A newly stored cube gets a frame slot initialized to NO_FRAME, which keeps
// frame_idx_ parallel to the cube storage.
std::pair<CubeDatabase::iterator, bool> CubeDatabase::insert(Cube cube)
{
    auto outcome = cubes_->insert(std::move(cube));
    const bool is_new = outcome.second;
    if (is_new) {
        frame_idx_.push_back(NO_FRAME);
    }
    assert(frame_idx_.size() == cubes_->data()->size());
    return outcome;
}

// Detaches cube `cube_idx` from its frame and queues it for deletion by the
// next collect_garbage(). The cube stays in storage until then. If it is
// assigned to a frame again before that collection, it is kept
// (see filter_no_frame_cubes()).
void CubeDatabase::remove(size_t cube_idx)
{
    // Same bounds contract as get_frame(); the write below is unchecked.
    assert(cube_idx < frame_idx_.size());
    dangling_cubes_.push_back(cube_idx);
    frame_idx_[cube_idx] = NO_FRAME;
}

// Returns the cube stored at position `idx` of the underlying storage
// (looked up with at()).
const Cube& CubeDatabase::get_cube(size_t idx) const
{
    auto* storage = cubes_->data();
    return storage->at(idx);
}

// Returns the frame currently recorded for cube `idx`, or NO_FRAME if the cube
// is not assigned to any frame.
size_t CubeDatabase::get_frame(size_t idx) const
{
    assert(idx < frame_idx_.size());
    const size_t frame = frame_idx_[idx];
    return frame;
}

// Records `frame_idx` as the frame that cube `cube_idx` belongs to.
void CubeDatabase::set_frame(size_t cube_idx, size_t frame_idx)
{
    // Mirror get_frame()'s bounds check; an out-of-range id here would be a
    // silent out-of-bounds write.
    assert(cube_idx < frame_idx_.size());
    frame_idx_[cube_idx] = frame_idx;
}

// Whether cube `idx1` subsumes cube `idx2`. The answer is memoized by the
// subsumption cache.
bool CubeDatabase::subsumes(size_t idx1, size_t idx2) const
{
    const bool verdict = subsumes_(idx1, idx2);
    return verdict;
}

void CubeDatabase::filter_no_frame_cubes()
{
    size_t new_idx = 0;
    for (size_t old_idx = 0; old_idx < dangling_cubes_.size(); ++old_idx) {
        const size_t cube_id = dangling_cubes_[old_idx];
        // check if cube was reinsterted into some frame after its last removal
        if (frame_idx_[cube_id] == NO_FRAME) {
            dangling_cubes_[new_idx] = cube_id;
            ++new_idx;
        }
    }
    dangling_cubes_.resize(new_idx);
}

// Compacts frame_idx_ after garbage collection. Entry k receives the frame of
// the cube whose old index is cubes[k], and the vector shrinks to
// cubes.size() entries. (Written in place; assumes cubes[k] >= k.)
void CubeDatabase::scrub_frame_ids(const vector<size_t>& cubes)
{
    size_t new_idx = 0;
    for (const size_t old_idx : cubes) {
        frame_idx_[new_idx] = frame_idx_[old_idx];
        ++new_idx;
    }
    frame_idx_.resize(cubes.size());
}

// Physically deletes every cube that was passed to remove() and has not been
// assigned to a frame since, then compacts the storage. The subsumption cache
// and frame_idx_ are re-indexed accordingly.
//
// Returns, for each surviving cube, its index before compaction: entry i is
// the old index of the cube that now has index i. Frame contents are not
// touched here, so any cube ids held elsewhere still use the old numbering.
vector<size_t> CubeDatabase::collect_garbage()
{
    filter_no_frame_cubes();
    std::sort(dangling_cubes_.begin(), dangling_cubes_.end());
    // A cube can be removed, re-inserted into a frame and removed again
    // between two collections, which leaves its id in dangling_cubes_ more
    // than once. Hand each id to destroy() only once.
    const auto unique_end =
        std::unique(dangling_cubes_.begin(), dangling_cubes_.end());
    dangling_cubes_.resize(
        static_cast<size_t>(unique_end - dangling_cubes_.begin()));
    auto ids = cubes_->destroy(dangling_cubes_.begin(), dangling_cubes_.end());
    dangling_cubes_.clear();
    subsumes_.scrub(ids);
    scrub_frame_ids(ids);
    return ids;
}

// Creates an empty frame with index `index`, whose cube ids refer to cubes in
// `database`.
Frame::Frame(CubeDatabase* database, size_t index)
    : db_(database)
    , frame_idx_(index)
{
}

// Adds cube `cube_id` to this frame and records this frame for the cube in
// the database. cubes_ is kept sorted by id. The cube must not already be in
// the frame (asserted). Returns an iterator to the inserted id.
Frame::iterator Frame::insert(size_t cube_id)
{
    db_->set_frame(cube_id, frame_idx_);
    auto where = std::lower_bound(cubes_.begin(), cubes_.end(), cube_id);
    assert(where == cubes_.end() || *where > cube_id);
    return cubes_.insert(where, cube_id);
}

// Erases the cube id at `pos` and returns the iterator following it.
// The cube database is not notified.
Frame::iterator Frame::erase(iterator pos)
{
    auto following = cubes_.erase(pos);
    return following;
}

// Removes from this frame the cube `cube_id` itself and every cube c for which
// db_->subsumes(c, cube_id) holds. Each removed cube is reported to
// CubeDatabase::remove(). The remaining ids keep their relative (sorted)
// order.
void Frame::remove_if_subsuming(size_t cube_id)
{
    size_t kept = 0;
    // Writes only target positions at or before the current element, so
    // compacting during the traversal is safe.
    for (const auto member : cubes_) {
        const bool drop =
            member == cube_id || db_->subsumes(member, cube_id);
        if (drop) {
            db_->remove(member);
            continue;
        }
        cubes_[kept] = member;
        ++kept;
    }
    cubes_.resize(kept);
}

// Removes cube `cube_id`, which must be present (asserted), from this frame.
// Unlike remove_if_subsuming(), the cube database is not notified.
void Frame::remove(size_t cube_id)
{
    auto hit = std::lower_bound(cubes_.begin(), cubes_.end(), cube_id);
    assert(hit != cubes_.end() && *hit == cube_id);
    cubes_.erase(hit);
}

// Returns an iterator to `cube_id` in this frame, or cubes_.end() if the frame
// does not contain it. This is a binary search over the sorted ids. A non-end
// result therefore always refers to `cube_id` itself, never just to the
// position where it would be inserted.
Frame::const_iterator Frame::find(size_t cube_id) const
{
    const auto pos = std::lower_bound(cubes_.begin(), cubes_.end(), cube_id);
    if (pos != cubes_.end() && *pos == cube_id) {
        return pos;
    }
    return cubes_.end();
}

// Whether this frame holds `cube_id` itself or some cube that `cube_id`
// subsumes (per CubeDatabase::subsumes).
bool Frame::contains_subsumed(size_t cube_id) const
{
    for (const auto& other : cubes_) {
        if (other == cube_id) {
            return true;
        }
        if (db_->subsumes(cube_id, other)) {
            return true;
        }
    }
    return false;
}

// Whether this frame holds a cube c with `cube <= c`, checked directly on the
// cube value rather than through the subsumption cache.
bool Frame::contains_subsumed(const Cube& cube) const
{
    for (const auto& other_id : cubes_) {
        if (cube <= db_->get_cube(other_id)) {
            return true;
        }
    }
    return false;
}

bool Frame::contains(const flat_state& vec) const
{
    return std::any_of(cubes_.begin(), cubes_.end(), [&](auto&& cube2) {
        return db_->get_cube(cube2).contains(vec);
    });
}

// Binds the initializer to the cube database and frame sequence it resets.
FramesInitializer::FramesInitializer(
    CubeDatabase* database,
    segmented_vector<Frame>* frame_seq)
    : cube_db_(database)
    , frames_(frame_seq)
{
}

// Clears the cube database and the frame sequence, then creates a single
// frame with index 1 (frames are numbered from 1).
void FramesInitializer::operator()()
{
    cube_db_->clear();
    auto& frames = *frames_;
    frames.clear();
    frames.emplace_back(cube_db_, 1);
}

// Binds the checker to the frame sequence. The cube database argument is
// accepted for interface symmetry but is currently unused.
FramesChecker::FramesChecker(
    CubeDatabase* /* cube_db */,
    segmented_vector<Frame>* frames)
    : frames_(frames)
{
}

// Returns true iff no frame with (1-based) index >= `frame_idx` contains
// `state`.
bool FramesChecker::operator()(const flat_state& state, size_t frame_idx) const
{
    assert(frame_idx > 0u);
    const auto& frames = *frames_;
    const size_t frame_count = frames_->size();
    for (size_t idx = frame_idx; idx <= frame_count; ++idx) {
        if (frames[idx - 1].contains(state)) {
            return false;
        }
    }
    return true;
}

// Binds the inserter to the cube database and the frame sequence it updates.
CubeInserter::CubeInserter(
    CubeDatabase* database,
    segmented_vector<Frame>* frame_seq)
    : cube_db_(database)
    , frames_(frame_seq)
{
}

// Adds `cube` to frame `frame_idx` (1-based) and returns its cube id, as
// reported by CubeDatabase::insert.
// - Frames 1..frame_idx (including the target frame) first drop the cube
//   itself and every cube that subsumes it, via Frame::remove_if_subsuming.
// - In debug builds, it is asserted that frame frame_idx and all later frames
//   contain neither the cube nor any cube it subsumes.
size_t CubeInserter::operator()(Cube cube, size_t frame_idx)
{
    assert(frame_idx > 0u);
    // frames_[frame_idx - 1] is accessed unconditionally below.
    assert(frame_idx <= frames_->size());
    const auto cube_id = cube_db_->insert(std::move(cube)).first->second;
#ifndef NDEBUG
    for (auto i = frame_idx - 1; i < frames_->size(); ++i) {
        assert(!(*frames_)[i].contains_subsumed(cube_id));
    }
#endif
    // size_t counter, matching frame_idx, instead of a narrower unsigned int.
    for (size_t cur_frame = 0; cur_frame < frame_idx; ++cur_frame) {
        (*frames_)[cur_frame].remove_if_subsuming(cube_id);
    }
    (*frames_)[frame_idx - 1].insert(cube_id);
    return cube_id;
}

} // namespace police::ic3
