#pragma once

#include "police/macros.hpp"
#include "police/storage/unordered_map.hpp"
#include "police/storage/vector.hpp"

#include <algorithm>
#include <cassert>
#include <iterator>
#include <optional>

#define DBG(x)
// #include "police/utils/io.hpp"
// #include <iostream>

namespace police {

// Enumerates the Cartesian product of the rows of `vec`.
//
// `callback` is invoked once per combination with a vector of pointers; entry
// d points at the element currently chosen from vec[d]. The pointer vector is
// reused between invocations, so the callback must copy whatever it wants to
// keep. Combinations are visited starting from the last element of every row,
// with the last row varying fastest. If any row is empty the callback is never
// invoked; if `vec` itself is empty it is invoked exactly once with an empty
// pointer vector.
template <typename T, typename Callback>
void product(const vector<vector<T>>& vec, Callback callback)
{
    const size_t dims = vec.size();

    // An empty factor makes the whole product empty.
    for (size_t d = 0; d < dims; ++d) {
        if (vec[d].empty()) {
            return;
        }
    }

    vector<size_t> cursor(dims);
    vector<const T*> tuple(dims, nullptr);

    // Puts every dimension >= from back onto the last element of its row.
    const auto rewind = [&](size_t from) {
        for (size_t d = from; d < dims; ++d) {
            cursor[d] = vec[d].size() - 1;
            tuple[d] = &vec[d].back();
        }
    };
    rewind(0);

    while (true) {
        callback(tuple);

        // Find the right-most dimension that can still step towards index 0.
        size_t d = dims;
        while (d > 0 && cursor[d - 1] == 0) {
            --d;
        }
        if (d == 0) {
            // every dimension is exhausted
            return;
        }
        --d;

        --cursor[d];
        tuple[d] = &vec[d][cursor[d]];
        rewind(d + 1);
    }
}

// Materialises the Cartesian product of `vec`, where each factor is a list of
// sequences. Every result entry is the concatenation (in factor order) of one
// sequence picked from each factor. Results appear in the order in which the
// callback-based overload of product() enumerates the combinations.
template <typename T>
vector<vector<T>> product(const vector<vector<vector<T>>>& vec)
{
    vector<vector<T>> combos;
    product(vec, [&combos](const vector<const vector<T>*>& parts) {
        // build the concatenation directly inside the output container
        combos.emplace_back();
        vector<T>& merged = combos.back();
        for (const vector<T>* part : parts) {
            merged.insert(merged.end(), part->begin(), part->end());
        }
    });
    return combos;
}

// Erases from `vec` every element for which `predicate` returns true. The
// surviving elements keep their relative order. The predicate is evaluated
// exactly once per element, front to back.
template <typename T, typename Predicate>
void inplace_remove_if(vector<T>& vec, Predicate predicate)
{
    // The wrapper holds the predicate by reference so that the algorithm's
    // internal copies all share the caller-supplied predicate object.
    const auto should_drop = [&predicate](auto&& item) {
        return !(!predicate(item));
    };
    vec.erase(std::remove_if(vec.begin(), vec.end(), should_drop), vec.end());
}

namespace details {
// Drops every option that contains at least one empty candidate set, keeping
// the remaining options in their original relative order. (Presumably an
// empty candidate needs no element to be selected, which makes its option
// satisfied from the start.)
template <typename T>
void delete_trivially_satisfied(vector<vector<vector<T>>>& options)
{
    const auto has_empty_candidate = [](const auto& option) {
        return std::any_of(
            option.begin(),
            option.end(),
            [](const auto& candidate) { return candidate.empty(); });
    };

    auto keep = options.begin();
    for (auto scan = options.begin(); scan != options.end(); ++scan) {
        if (has_empty_candidate(*scan)) {
            continue;
        }
        if (keep != scan) {
            std::swap(*keep, *scan);
        }
        ++keep;
    }
    options.erase(keep, options.end());
}

// Default id provider: hands out consecutive ids (0, 1, 2, ...) in order of
// first appearance and returns the previously assigned id for a value that
// has been seen before.
template <typename T>
struct GetIdHash {
    // Returns the id of `t`, registering it under the next free id if new.
    size_t operator()(const T& t)
    {
        // emplace leaves the map untouched when `t` is already registered
        return ids.emplace(t, ids.size()).first->second;
    }

    // value -> assigned id
    unordered_map<T, size_t> ids;
};

// Bookkeeping record for one candidate set.
template <typename T>
struct SetInfo {
    // Takes ownership of `members`; `owner` is the index of the option the
    // set belongs to. `open` starts at the number of stored elements.
    SetInfo(vector<T>&& members, size_t owner)
        : elements(std::move(members))
        , option(owner)
        , open(elements.size())
    {
    }

    vector<T> elements; // the elements of the set
    size_t option = 0;  // index of the owning option
    size_t open = 0;    // counter initialised to elements.size()
};

// Bookkeeping record for one element, addressed by its id.
template <typename T>
struct ElementInfo {
    // Creates a placeholder record that holds no value yet.
    ElementInfo() = default;

    // Creates a record holding `t`.
    explicit ElementInfo(T t)
        : value(std::move(t))
    {
    }

    std::optional<T> value = std::nullopt; // the element, once known
    vector<size_t> member_of;              // indices of sets containing it
    std::pair<int, int> rank = {0, 0};     // lexicographic selection priority
    bool locked = false;                   // flag, initially cleared
};

} // namespace details

// Greedily computes a selection of elements that satisfies every option.
//
// `options[o]` is a list of candidate sets. An option counts as satisfied as
// soon as all elements of at least one of its candidate sets are selected.
// Options with a single candidate force all elements of that candidate into
// the result. Options containing an empty candidate are dropped up front.
//
// The computation runs in three phases:
//   1. index construction: assign ids, record set memberships and ranks;
//   2. greedy selection: repeatedly pick the highest-ranked element until no
//      unsatisfied option remains;
//   3. pruning: walk over the selection in selection order and drop every
//      non-forced element whose removal leaves all options satisfied.
//
// Parameters:
//   options  consumed by the call: it is reordered, its sets are moved out,
//            and it is left empty on return. Every option must contain at
//            least one candidate (checked by assert only).
//            NOTE(review): with asserts compiled out, an option without
//            candidates can never be satisfied, so the selection loop would
//            presumably not terminate -- confirm callers never pass one.
//   get_id   maps an element to a size_t id. Ids are used directly as indices
//            into a vector, so they should be small and dense; equal elements
//            must map to the same id. The default assigns consecutive ids.
//
// Returns the selected elements (copies of the stored values), in selection
// order after pruning.
template <typename T, typename GetId = details::GetIdHash<T>>
vector<T>
greedy_hitting_set(vector<vector<vector<T>>>& options, GetId get_id = GetId())
{
    // precondition: no option without candidates
    assert(std::all_of(options.begin(), options.end(), [](const auto& o) {
        return !o.empty();
    }));

    // options with an empty candidate need no element at all
    details::delete_trivially_satisfied(options);

    const auto num_options = options.size();
    // one entry per candidate set of the multi-candidate options
    vector<details::SetInfo<T>> set_infos;
    // indexed by element id
    vector<details::ElementInfo<T>> element_infos;
    // per option: the distinct element ids occurring in any of its candidates
    // (only filled for multi-candidate options)
    vector<vector<size_t>> option_to_elements(num_options);

    // checked accessor for set_infos
    auto get_set_info = [&](size_t set_id) -> details::SetInfo<T>& {
        POLICE_ASSERT(set_id < set_infos.size());
        return set_infos[set_id];
    };

    // Grows element_infos as needed and (re)stores the value for the id. The
    // returned reference is invalidated by the next growing call.
    auto get_or_create_element_info = [&](size_t element_id,
                                          T t) -> details::ElementInfo<T>& {
        if (element_id >= element_infos.size()) {
            element_infos.resize(element_id + 1);
        }
        element_infos[element_id].value = std::move(t);
        return element_infos[element_id];
    };

    // checked accessor for an element that has been registered before
    auto get_element_info = [&](size_t element_id) -> details::ElementInfo<T>& {
        POLICE_ASSERT(element_id < element_infos.size());
        POLICE_ASSERT(element_infos[element_id].value.has_value());
        return element_infos[element_id];
    };

    // checked accessor for option_to_elements
    auto get_elements = [&](size_t option_id) -> vector<size_t>& {
        POLICE_ASSERT(option_id < option_to_elements.size());
        return option_to_elements[option_id];
    };

    // ids of the selected elements, in selection order
    vector<size_t> result;
    // number of options that are not yet satisfied
    size_t togo = num_options;

    // phase 1: build the indices and consume `options`
    {
        // counted[id]: id was already seen in the option currently processed
        vector<bool> counted;
        auto set_idx = 0u;
        for (auto opt = 0u; opt < options.size(); ++opt) {
            auto& candidates = options[opt];
            auto& elements = get_elements(opt);
            assert(!candidates.empty());
            DBG(std::cout << "candidates#" << opt << ": ";)
            if (candidates.size() == 1u) {
                // no choice: every element of the only candidate is forced
                // into the result and protected from pruning
                auto& set = candidates.front();
                DBG(std::cout << "[ ";)
                for (const T& t : set) {
                    const size_t id = get_id(t);
                    DBG(std::cout << id << " ";)
                    auto& element_info = get_or_create_element_info(id, t);
                    if (!element_info.locked) {
                        element_info.locked = true;
                        result.push_back(id);
                    }
                }
                DBG(std::cout << "]";)
                --togo;
            } else {
                counted.clear();
                for (auto cand = 0u; cand < candidates.size(); ++cand) {
                    auto& set = candidates[cand];
                    DBG(std::cout << (cand > 0 ? "; " : "") << "[ ";)
                    for (const T& t : set) {
                        const size_t id = get_id(t);
                        DBG(std::cout << id << " ";)
                        if (id >= counted.size()) {
                            counted.resize(id + 1, false);
                        }
                        if (!counted[id]) {
                            elements.push_back(id);
                        }
                        auto& element_info = get_or_create_element_info(id, t);
                        // rank.first counts options, rank.second counts
                        // individual candidate-set occurrences
                        element_info.rank.first += !counted[id];
                        ++element_info.rank.second;
                        element_info.member_of.push_back(set_idx);
                        counted[id] = 1;
                    }
                    DBG(std::cout << "]";)
                    set_infos.emplace_back(std::move(set), opt);
                    ++set_idx;
                }
            }
            DBG(std::cout << std::endl;)
            // release the memory of the processed option
            vector<vector<T>>().swap(candidates);
        }
        vector<vector<vector<T>>>().swap(options);
    }

    // per option: number of its candidate sets that are fully selected
    vector<int> sat_counter(num_options, 0);

    // phase 2: greedily select elements until all options are satisfied
    {
        // Marks element `id` as selected and updates all counters.
        auto propagate = [&](size_t id) {
            details::ElementInfo<T>& element_info = get_element_info(id);
            // one open slot less in every set containing the element
            for (const auto& set_idx : element_info.member_of) {
                auto& set_info = get_set_info(set_idx);
                // check if the set is now fully selected
                if (--set_info.open == 0) {
                    // update the satisfaction status of the owning option
                    const auto opt_idx = set_info.option;
                    POLICE_ASSERT(opt_idx < sat_counter.size());
                    if (++sat_counter[opt_idx] == 1) {
                        // the option became satisfied just now
                        --togo;
                        // the option no longer contributes to the rank of
                        // any of its elements
                        for (const auto& element_id : get_elements(opt_idx)) {
                            --get_element_info(element_id).rank.first;
                        }
                    }
                }
            }
            // a selected element must rank below every unselected one
            element_info.rank = {-1, -1};
        };

        // account for the elements forced by single-candidate options
        for (size_t j = 0; j < result.size(); ++j) {
            propagate(result[j]);
        }

        while (togo > 0) {
            // pick the element with the lexicographically largest rank, i.e.,
            // the one occurring in the most unsatisfied options
            auto max_ranked = std::max_element(
                element_infos.begin(),
                element_infos.end(),
                [](const auto& a, const auto& b) { return a.rank < b.rank; });
            POLICE_ASSERT(
                max_ranked != element_infos.end() &&
                max_ranked->rank != std::pair(-1, -1));
            const auto element_id =
                std::distance(element_infos.begin(), max_ranked);
            result.push_back(element_id);
            propagate(element_id);
        }
    }

    // phase 3: greedily drop selected elements as long as every option stays
    // satisfied; result is compacted in place (i is the write position)
    {
        size_t i = 0;
        for (size_t j = 0; j < result.size(); ++j) {
            const size_t id = result[j];
            auto& element_info = get_element_info(id);
            if (element_info.locked) {
                // forced elements are always kept
                result[i] = id;
                ++i;
                continue;
            }
            bool needed = false;
            // Tentatively deselect the element: every set that was fully
            // selected (open goes from 0 to 1) stops satisfying its option.
            // The element is needed if some option thereby loses its last
            // fully selected candidate set.
            for (const auto& set_idx : element_info.member_of) {
                auto& set_info = get_set_info(set_idx);
                if (++set_info.open == 1) {
                    const auto opt_idx = set_info.option;
                    POLICE_ASSERT(opt_idx < sat_counter.size());
                    needed = --sat_counter[opt_idx] == 0 || needed;
                }
            }
            if (needed) {
                // undo the tentative deselection and keep the element
                for (const auto& set_idx : element_info.member_of) {
                    auto& set_info = get_set_info(set_idx);
                    if (--set_info.open == 0) {
                        const auto opt_idx = set_info.option;
                        ++sat_counter[opt_idx];
                    }
                }
                result[i] = id;
                ++i;
            }
        }
        result.erase(result.begin() + i, result.end());
    }

    DBG(std::cout << "=> " << print_sequence(result) << std::endl;)

    // map ids back to T values
    vector<T> result_t;
    result_t.reserve(result.size());
    std::transform(
        result.begin(),
        result.end(),
        std::back_inserter(result_t),
        [&](size_t id) { return *get_element_info(id).value; });

    DBG(std::cout << "=*> " << print_sequence(result_t) << std::endl;)

    return result_t;
}

// Orders the nodes of a directed graph such that successors come first.
//
// `graph[n]` lists the successors of node n; every successor must be a valid
// node index (< graph.size()). The result contains every node exactly once.
// For each edge n -> s whose endpoints lie in different strongly connected
// components, s appears before n. Nodes of the same strongly connected
// component (i.e., nodes on a common cycle) appear contiguously.
//
// Implemented as an iterative variant of Tarjan's strongly-connected-
// components algorithm: a depth-first search with an explicit expansion
// trace, in which a component is emitted when the search leaves its root.
inline vector<size_t> topological_sort(const vector<vector<size_t>>& graph)
{
    struct NodeInfo {
        explicit NodeInfo(size_t i)
            : backlink(i)
        {
        }
        // Smallest component-stack position known to be reachable from the
        // node; graph.size() marks a node that has not been visited yet.
        size_t backlink;
        // whether the node currently resides on the component stack
        bool onstack = false;
    };
    struct ExpansionInfo {
        explicit ExpansionInfo(size_t node)
            : node(node)
        {
        }
        size_t node;
        // index of the next successor of `node` to look at
        size_t successor = 0;
    };

    const size_t num_nodes = graph.size();
    vector<size_t> result;
    result.reserve(num_nodes);
    vector<NodeInfo> infos(num_nodes, NodeInfo(num_nodes));
    vector<ExpansionInfo> trace; // explicit DFS recursion stack
    vector<size_t> stack;        // nodes of not yet completed components

    for (size_t root = 0; root < num_nodes; ++root) {
        if (infos[root].backlink != num_nodes) {
            // already reached from an earlier root
            continue;
        }
        assert(stack.empty() && trace.empty());
        infos[root].backlink = stack.size();
        infos[root].onstack = true;
        stack.push_back(root);
        trace.emplace_back(root);

        while (!trace.empty()) {
            const size_t current = trace.back().node;
            const vector<size_t>& successors = graph[current];
            NodeInfo& info = infos[current];

            bool descended = false;
            while (trace.back().successor < successors.size()) {
                const size_t succ = successors[trace.back().successor];
                assert(succ < num_nodes);
                NodeInfo& succ_info = infos[succ];
                if (succ_info.backlink == num_nodes) {
                    // Unvisited: expand it first. The successor index of the
                    // current frame is deliberately left untouched, so the
                    // same successor is inspected once more after returning,
                    // which is when its backlink gets pulled up.
                    succ_info.backlink = stack.size();
                    succ_info.onstack = true;
                    stack.push_back(succ);
                    trace.emplace_back(succ);
                    descended = true;
                    break;
                }
                if (succ_info.onstack) {
                    // successor belongs to a still open component
                    info.backlink = std::min(succ_info.backlink, info.backlink);
                }
                ++trace.back().successor;
            }
            if (descended) {
                continue;
            }

            // All successors handled. If the node is the root of its
            // component, emit the component (top of the stack first).
            const size_t root_pos = info.backlink;
            if (stack[root_pos] == current) {
                for (size_t k = stack.size(); k > root_pos; --k) {
                    const size_t member = stack[k - 1];
                    infos[member].onstack = false;
                    result.push_back(member);
                }
                stack.erase(stack.begin() + root_pos, stack.end());
            }
            trace.pop_back();
        }
    }

    assert(result.size() == num_nodes);
    return result;
}

} // namespace police

#undef DBG
