#pragma once

#include "police/storage/segmented_vector.hpp"
#include "police/storage/unordered_set.hpp"
#include "police/verifiers/search/concepts.hpp"
#include "police/verifiers/search/node_id.hpp"

#include <limits>
#include <memory>
#include <sstream>

namespace police {

/// Records the state graph reachable from the states passed to `operator()`
/// as Graphviz DOT text.
///
/// Each recorded state becomes a node `s<id>`. Each transition becomes a box
/// node `t<n>` labelled via `LabelToStr` and wired as
/// `s<src> -> t<n> -> s<dst>`. States satisfying `Predicate` are drawn with a
/// double outline and are never expanded.
///
/// The output buffer, the set of recorded states, the work stack and the
/// transition counter are all held through `std::shared_ptr`. Copies of a
/// `graphviz` object therefore write into the same dump and share the same
/// bookkeeping.
template <
    typename StateType,
    search::successor_generator<StateType> SuccessorGenerator,
    search::id_to_state_map<StateType> IdToStateMap,
    search::state_to_id_map<StateType> StateIdMap,
    search::state_predicate<StateType> Predicate,
    typename LabelToStr>
struct graphviz {
    /// @param succ_gen      yields, for a state, a range of (successor, label)
    ///                      pairs.
    /// @param lookup_state  maps a node id back to its state.
    /// @param get_state_id  maps a state to its node id.
    /// @param goal          states for which this returns true are drawn with
    ///                      a double outline and treated as leaves.
    /// @param label_to_str  converts a transition label into something that
    ///                      can be streamed into the DOT label.
    /// @param max_states    budget on the number of recorded states; see
    ///                      operator() for how it is enforced.
    graphviz(
        SuccessorGenerator succ_gen = SuccessorGenerator(),
        IdToStateMap lookup_state = IdToStateMap(),
        StateIdMap get_state_id = StateIdMap(),
        Predicate goal = Predicate(),
        LabelToStr label_to_str = LabelToStr(),
        size_t max_states = std::numeric_limits<size_t>::max())
        : out(std::make_shared<std::ostringstream>())
        , closed(std::make_shared<unordered_set<search::NodeId>>())
        , queue(std::make_shared<segmented_vector<search::NodeId>>())
        , num_transitions(std::make_shared<size_t>(0))
        , succ_gen_(std::move(succ_gen))
        , lookup_state_(std::move(lookup_state))
        , get_state_id_(std::move(get_state_id))
        , predicate_(std::move(goal))
        , label_to_str_(std::move(label_to_str))
        , max_states_(max_states)
    {
    }

    /// Records `state` and everything reachable from it that has not been
    /// recorded yet.
    ///
    /// Expansion uses the shared `queue` as a stack (LIFO). Successors that
    /// satisfy the predicate are emitted but not expanded. The budget is
    /// checked before each expansion. All successors of an expanded state are
    /// recorded, so `closed` may end up slightly larger than `max_states`.
    ///
    /// @return false if the budget was already exhausted on entry;
    ///         true  if `state` was already recorded;
    ///         true  if `state` satisfies the predicate (it is not expanded);
    ///         otherwise true iff exploration stopped because the budget ran
    ///         out while unexpanded states remained.
    /// NOTE(review): how callers interpret this value is not visible here.
    bool operator()(const StateType& state)
    {
        const auto root_id = get_state_id_(state);
        if (closed->size() >= max_states_) {
            return false;
        }
        if (!closed->insert(root_id).second) {
            return true;
        }
        if (new_node(state)) {
            return true;
        }

        // The counter is shared across calls (and copies). A per-call counter
        // would restart at t0 and make DOT merge unrelated transition nodes.
        auto& transition = *num_transitions;

        queue->push_back(root_id);
        while (closed->size() < max_states_ && !queue->empty()) {
            const auto state_id = queue->back();
            queue->pop_back();

            auto expanded = lookup_state_(state_id);
            auto successors = succ_gen_(expanded);
            for (auto&& [succ, label] : successors) {
                const auto succ_id = get_state_id_(succ);
                // Queue only successors that are newly recorded and are not
                // predicate (leaf) states.
                if (closed->insert(succ_id).second && !new_node(succ)) {
                    queue->push_back(succ_id);
                }
                *out << "t" << transition << " [label=\""
                     << label_to_str_(label) << "\",shape=box];\n";
                *out << "s" << state_id << " -> t" << transition << ";\n";
                *out << "t" << transition << " -> s" << succ_id << ";\n";
                ++transition;
            }
        }
        return !queue->empty();
    }

    /// Emits the DOT node for `state`. The node gets a double outline when
    /// the state satisfies the predicate.
    /// @return whether `state` satisfies the predicate.
    bool new_node(const StateType& state) const
    {
        const auto state_id = get_state_id_(state);
        const bool goal = predicate_(state);
        *out << "s" << state_id << (goal ? " [peripheries=2]" : "") << ";"
             << "\n";
        return goal;
    }

    /// Writes the complete graph, wrapped in `digraph G { ... }`, to `os`.
    friend std::ostream& operator<<(std::ostream& os, const graphviz& gv)
    {
        os << "digraph G {\n";
        os << gv.out->view();
        os << "}\n" << std::endl;
        return os;
    }

    // Shared between copies: DOT body text, recorded state ids, work stack,
    // and the next transition number.
    std::shared_ptr<std::ostringstream> out;
    std::shared_ptr<unordered_set<search::NodeId>> closed;
    std::shared_ptr<segmented_vector<search::NodeId>> queue;
    std::shared_ptr<size_t> num_transitions;

    SuccessorGenerator succ_gen_;
    IdToStateMap lookup_state_;
    StateIdMap get_state_id_;
    Predicate predicate_;
    LabelToStr label_to_str_;
    size_t max_states_;
};

template <
    typename StateType,
    search::successor_generator<StateType> SuccessorGenerator,
    search::id_to_state_map<StateType> IdToStateMap,
    search::state_to_id_map<StateType> StateIdMap,
    search::state_predicate<StateType> Predicate,
    typename LabelToStr>
graphviz<
    StateType,
    SuccessorGenerator,
    IdToStateMap,
    StateIdMap,
    Predicate,
    LabelToStr>
create_graphviz(
    SuccessorGenerator succ_gen = SuccessorGenerator(),
    IdToStateMap lookup_state = IdToStateMap(),
    StateIdMap get_state_id = StateIdMap(),
    Predicate predicate = Predicate(),
    LabelToStr label_to_str = LabelToStr(),
    size_t max_states = std::numeric_limits<size_t>::max())
{
    return {
        std::move(succ_gen),
        std::move(lookup_state),
        std::move(get_state_id),
        std::move(predicate),
        std::move(label_to_str),
        max_states};
}

} // namespace police
