#include "police/verifiers/ic3/syntactic/abstraction.hpp"

#include <algorithm>

namespace police::ic3::syntactic {

// Creates an empty abstraction (no nodes, no arcs).
//
// num_vars is forwarded unchanged to the post-condition store; presumably it
// is the number of state variables a post-condition cube may mention --
// confirm against the store's declaration.
SyntacticAbstraction::SyntacticAbstraction(size_t num_vars)
    : post_conditions_(num_vars)
{
}

// Registers one additional node. The node receives the next free index in
// every per-node table: it starts with an h-value of 0, no outgoing arcs and
// no arcs that list it as a successor.
void SyntacticAbstraction::notify_new_node()
{
    hmax_.push_back(0);
    as_successor_.emplace_back();
    node_arcs_.emplace_back();
}

// Registers the hyper-arc `arc`, leaving node `src`, together with its
// post-condition `post`.
//
// The arc is assigned the next free arc id and is recorded
//   * in the per-arc tables (successor count and source node),
//   * in the outgoing-arc list of `src`,
//   * in the incoming list of every node named in arc.successors,
//   * in the arc list of its post-condition.
// `post` is handed to the post-condition store; a fresh arc list is appended
// whenever the store reports that the cube was newly inserted.
void SyntacticAbstraction::add_arc(size_t src, HyperArc arc, Cube post)
{
    // Per-arc bookkeeping; the arc id is the index into these tables.
    arc.arc_id = num_successors_.size();
    num_successors_.emplace_back(arc.successors.size());
    source_.emplace_back(src);

    // Node -> arc indices, in both directions.
    node_arcs_[src].push_back(arc.arc_id);
    for (const auto& target : arc.successors) {
        as_successor_[target].push_back(arc.arc_id);
    }

    // Post-condition -> arc index.
    auto [cube_id, is_new_cube] = post_conditions_.insert(std::move(post));
    if (is_new_cube) {
        post_to_arcs_.emplace_back();
    }
    post_to_arcs_[cube_id].push_back(arc.arc_id);
    arc.post_id = cube_id;

    arcs_.push_back(std::move(arc));
}

void SyntacticAbstraction::add_successor(const HyperArc& arc, size_t succ)
{
    as_successor_[succ].push_back(arc.arc_id);
    ++num_successors_[arc.arc_id];
    const_cast<HyperArc&>(arc).successors.push_back(succ);
}

// Compacts the abstraction to the nodes listed in `new_ids` and renumbers
// nodes and arcs densely.
//
// new_ids[j] is the current id of the node that becomes node j; entries must
// satisfy new_ids[j] >= j (asserted), so a node never moves to a larger id.
// Every node that is not listed is dropped.
//
// Effects:
//   * node tables (outgoing arcs, h-values) are moved to the new ids and
//     truncated to new_ids.size();
//   * only arcs found in the outgoing list of a surviving node are kept; they
//     are renumbered in order of (new source id, position in that list);
//   * successors that refer to dropped nodes are removed from the kept arcs;
//   * the incoming-arc lists, successor counts and arc sources are rebuilt;
//   * post-condition arc lists are translated to the new arc ids and re-sorted;
//     post-conditions left without any arc are removed from the store.
void SyntacticAbstraction::apply_id_remap(const vector<size_t>& new_ids)
{
#ifndef NDEBUG
    // h-values under the old numbering, kept for the sanity check below.
    vector<size_t> old_hmax(hmax_);
#endif

    // Remove dangling nodes and relocate node ids getting rid of any
    // fragmentation
    constexpr size_t PRUNED = static_cast<size_t>(-1);
    vector<size_t> id_map(as_successor_.size(), PRUNED); // old id -> new id
    for (size_t j = 0; j < new_ids.size(); ++j) {
        size_t k = new_ids[j];
        // Nodes only move towards smaller ids, so slot j has already been
        // consumed (or is the node itself) when it gets overwritten.
        assert(k >= j);
        id_map[k] = j;
        if (j != k) {
            node_arcs_[j] = std::move(node_arcs_[k]);
            hmax_[j] = hmax_[k];
        }
    }
    node_arcs_.erase(node_arcs_.begin() + new_ids.size(), node_arcs_.end());
    hmax_.erase(hmax_.begin() + new_ids.size(), hmax_.end());

    // Update arcs, removing the ones sourced in nodes removed in the
    // previous step. Again relocate arc ids to get rid of any
    // fragmentation.
    as_successor_.clear();
    as_successor_.resize(new_ids.size());
    num_successors_.clear();
    source_.clear();
    vector<HyperArc> new_arcs;
    new_arcs.reserve(arcs_.size());
    vector<size_t> new_arc_ids(arcs_.size(), PRUNED); // old arc id -> new arc id
    size_t new_id = 0;
    for (size_t i = 0; i < node_arcs_.size(); ++i) {
        for (auto& arc_id : node_arcs_[i]) {
            new_arcs.push_back(std::move(arcs_[arc_id]));
            new_arc_ids[arc_id] = new_id;

            HyperArc& arc = new_arcs.back();
            arc.arc_id = new_id;
            arc_id = new_id; // the node's outgoing list now holds the new id

#ifndef NDEBUG
            // Largest successor h-value before the successors are remapped.
            size_t old_h = 0;
            for (const size_t old_succ : arc.successors) {
                old_h = std::max(old_h, old_hmax[old_succ]);
            }
#endif

            // Translate the surviving successors in place and drop the rest.
            size_t k = 0;
            for (size_t j = 0; j < arc.successors.size(); ++j) {
                size_t l = arc.successors[j];
                if (id_map[l] != PRUNED) {
                    arc.successors[k] = id_map[l];
                    as_successor_[id_map[l]].push_back(new_id);
                    ++k;
                }
            }
            arc.successors.erase(
                arc.successors.begin() + k,
                arc.successors.end());

            num_successors_.push_back(k);
            source_.push_back(i);

            ++new_id;

#ifndef NDEBUG
            // Sanity check: the largest successor h-value of the arc must not
            // have decreased through the remap.
            size_t cur_h = 0;
            for (const size_t new_succ : arc.successors) {
                cur_h = std::max(cur_h, hmax_[new_succ]);
            }
            assert(cur_h >= old_h);
#endif
        }
    }
    arcs_.swap(new_arcs);

    // delete references to deleted arcs from post-to-arc mapping
    for (size_t post_id = 0; post_id < post_to_arcs_.size(); ++post_id) {
        auto& arcs = post_to_arcs_[post_id];
        if (arcs.empty()) {
            continue;
        }
        size_t i = 0;
        for (size_t j = 0; j < arcs.size(); ++j) {
            const size_t old_arc = arcs[j];
            const size_t remapped = new_arc_ids[old_arc];
            if (remapped != PRUNED) {
                arcs[i] = remapped;
                ++i;
            }
        }
        arcs.erase(arcs.begin() + i, arcs.end());
        // New arc ids follow source-node order, not the old arc-id order, so
        // the translated list is not necessarily ascending any more. Restore
        // the ordering that mark_arc_pruned's binary search depends on.
        std::sort(arcs.begin(), arcs.end());
        if (arcs.empty()) {
            post_conditions_.remove(post_id);
            clean_post_conditions_ = true;
        }
    }
    // Also covers removals flagged before this call (e.g. by mark_arc_pruned).
    if (clean_post_conditions_) {
        post_conditions_.clean_orphan_references();
        clean_post_conditions_ = false;
    }
}

// Looks up the post-condition cube stored under arc.post_id.
const Cube& SyntacticAbstraction::get_post_condition(const HyperArc& arc)
{
    const auto& cube_id = arc.post_id;
    return post_conditions_.at(cube_id);
}

void SyntacticAbstraction::recompute_hmax()
{
    std::fill(hmax_.begin(), hmax_.end(), INF_FRAME);
    vector<size_t> queue;
    vector<size_t> open(num_successors_);
    for (size_t i = 0; i < open.size(); ++i) {
        if (open[i] == 0u) {
            auto src = source_[i];
            if (hmax_[src] != 1u) {
                queue.push_back(src);
                hmax_[src] = 1u;
            }
        }
    }
    for (size_t i = 0; i < queue.size(); ++i) {
        const size_t h = hmax_[queue[i]] + 1;
        const auto& ts = as_successor_[queue[i]];
        for (const auto& arc : ts) {
            if (--open[arc] == 0u) {
                auto src = source_[arc];
                if (hmax_[src] == INF_FRAME) {
                    hmax_[src] = h;
                    queue.push_back(src);
                }
            }
        }
    }
}

void SyntacticAbstraction::mark_arc_pruned(size_t arc_id)
{
    ++num_successors_[arc_id];
    auto& arc = arcs_[arc_id];
    auto& post_id = arc.post_id;
    if (post_id < post_to_arcs_.size()) {
        auto& post_arcs = post_to_arcs_[post_id];
        auto entry =
            std::lower_bound(post_arcs.begin(), post_arcs.end(), arc_id);
        assert(entry != post_arcs.end() && *entry == arc_id);
        post_arcs.erase(entry);
        if (post_arcs.empty()) {
            post_conditions_.remove(post_id);
            clean_post_conditions_ = true;
        }
        post_id = -1;
    }
}

size_t SyntacticAbstraction::recompute_h_value(size_t node)
{
    const auto& arcs = node_arcs_[node];
    size_t new_value = INF_FRAME;
    for (const auto& arc_id : arcs) {
        const auto& arc = arcs_[arc_id];
        size_t arc_value = 0;
        for (const auto& succ : arc.successors) {
            arc_value = std::max(arc_value, hmax_[succ]);
        }
        new_value = std::min(new_value, arc_value);
    }
    if (new_value != INF_FRAME) {
        ++new_value;
    }
    hmax_[node] = new_value;
    return new_value;
}

// Returns the h-value currently stored for `node`; nothing is recomputed.
size_t SyntacticAbstraction::get_h_value(size_t node) const
{
    const size_t stored = hmax_[node];
    return stored;
}

// Overwrites the h-value stored for `node` with `h_value`. Values of other
// nodes are left untouched.
void SyntacticAbstraction::set_h_value(size_t node, size_t h_value)
{
    auto& slot = hmax_[node];
    slot = h_value;
}

} // namespace police::ic3::syntactic
