#pragma once

#include "police/expressions/expression.hpp"
#include "police/expressions/expression_evaluator.hpp"
#include "police/macros.hpp"
#include "police/storage/flat_state.hpp"
#include "police/storage/path.hpp"
#include "police/storage/variable_space.hpp"
#include "police/verifiers/ic3/concepts.hpp"
#include "police/verifiers/ic3/statistics.hpp"
#include "police/verifiers/search/concepts.hpp"

#include <algorithm>
#include <optional>

#if POLICE_VERBOSITY
#include "police/utils/value_seq_printer.hpp"
#endif

namespace police::ic3 {

namespace _detail {
/// A pending unit of work for the search: a state together with the frame
/// index at which it has to be expanded.
struct Obligation {
    /// Takes ownership of `state` and records the frame it is scheduled at.
    Obligation(flat_state state, size_t frame)
        : state(std::move(state))
        , frame(frame)
    {
    }

    /// State to expand.
    flat_state state;
    /// Frame index the state is currently scheduled at.
    size_t frame;
    /// Label of the successor most recently tried from this state. The
    /// constructor does not receive a label, so it is given a defined value
    /// here instead of being left indeterminate until the first assignment.
    size_t label = 0;
};

/// Default path checker: never reports an index, i.e. it always yields an
/// empty optional regardless of the path it is given.
struct AllPathsAcceptor {
    std::optional<size_t> operator()(const Path& /*path*/) const
    {
        return {};
    }
};

} // namespace _detail

template <
    concepts::StartStateGenerator _StartGenerator,
    search::successor_generator<flat_state> _SuccessorGenerator,
    concepts::FramesInitializer _FramesInitializer,
    concepts::FrameAdder _FrameAdder,
    concepts::FramesChecker _FramesChecker,
    concepts::FramesRefiner _FramesRefiner,
    concepts::PathChecker _PathChecker = _detail::AllPathsAcceptor>
class ExplicitStateIC3 {
    using obligation_type = _detail::Obligation;

public:
    /// Convenience constructor: default-constructs the path checker
    /// (`_PathChecker()`) and forwards every other argument unchanged to the
    /// constructor that takes an explicit path checker.
    ExplicitStateIC3(
        const VariableSpace* vspace,
        _StartGenerator start_generator,
        _SuccessorGenerator successor_generator,
        _FramesInitializer init_frames,
        _FrameAdder add_frame,
        _FramesChecker frames,
        _FramesRefiner frame_refiner,
        expressions::Expression avoid,
        bool rescheduling)
        : ExplicitStateIC3(
              vspace,
              std::move(start_generator),
              std::move(successor_generator),
              std::move(init_frames),
              std::move(add_frame),
              std::move(frames),
              std::move(frame_refiner),
              _PathChecker(),
              std::move(avoid),
              rescheduling)
    {
    }

    /// Stores all collaborators by value (moved in) for later use by
    /// operator().
    ///
    /// - vspace: kept as a raw pointer; dereferenced when printing states in
    ///   debug messages.
    /// - start_generator: called at the start of every trial to obtain a
    ///   start state; an empty result makes the search advance to the next
    ///   frame.
    /// - successor_generator: called on a state to enumerate its successors.
    /// - init_frames: called once at the beginning of each run.
    /// - add_frame: called when advancing to the next frame; a true result
    ///   ends the run with "no path".
    /// - frames: membership test for frame indices >= 1.
    /// - frame_refiner: called to block a state at a frame; its result is
    ///   the frame index at which the state ended up blocked.
    /// - accept_path: called on every candidate path; an empty result
    ///   accepts the path, otherwise the value is an index into the
    ///   obligation queue from which the queue is truncated.
    /// - avoid: expression evaluated on a state to decide membership in
    ///   frame 0.
    /// - rescheduling: enables re-expanding an obligation at a higher frame
    ///   instead of dropping it.
    ExplicitStateIC3(
        const VariableSpace* vspace,
        _StartGenerator start_generator,
        _SuccessorGenerator successor_generator,
        _FramesInitializer init_frames,
        _FrameAdder add_frame,
        _FramesChecker frames,
        _FramesRefiner frame_refiner,
        _PathChecker accept_path,
        expressions::Expression avoid,
        bool rescheduling)
        : vspace_(vspace)
        , generate_start_(std::move(start_generator))
        , generate_successors_(std::move(successor_generator))
        , init_frames_(std::move(init_frames))
        , add_frame_(std::move(add_frame))
        , is_in_frame_(std::move(frames))
        , refine_frames_(std::move(frame_refiner))
        , accept_path_(std::move(accept_path))
        , avoid_(std::move(avoid))
        , rescheduling_enabled_(rescheduling)
    {
    }

    /// Runs the search from scratch.
    ///
    /// Statistics are reset, the frames are initialised, and trials are
    /// started repeatedly from freshly generated start states. Returns the
    /// first path that the path checker accepts, or an empty optional once
    /// adding a frame reports that an invariant has been found.
    [[nodiscard]]
    std::optional<Path> operator()()
    {
        stats_ = Statistics();
        stats_.frames = 1;
        stats_.total_time.resume();
        cur_frame_ = 1;
        init_frames_();
        while (true) {
            POLICE_DEBUG_MSG("new trial\n")
            std::optional<flat_state> start =
                stats_.start_generation_time.function_call(generate_start_);
            if (!start.has_value()) {
                // No start state is available for the current frame: move on
                // to the next frame, or stop if adding it reports an
                // invariant.
#if POLICE_VERBOSITY
                if (last_printed_ != stats_.clauses) {
                    stats_.print_status_line(std::cout);
                }
                std::cout << "ic3: proceeding to frame " << (cur_frame_ + 1)
                          << std::endl;
                last_printed_ = stats_.clauses;
#endif
                ++cur_frame_;
                ++stats_.frames;
                const bool invariant_found =
                    stats_.frame_construction_time.function_call(add_frame_);
                if (!invariant_found) {
                    continue;
                }
                std::cout
                    << "ic3: found an invariant. Property not reachable."
                    << std::endl;
                stats_.total_time.stop();
                return std::nullopt;
            }

            // Seed the obligation queue with the start state and expand.
            assert(is_in_frame(start.value(), cur_frame_));
            assert(queue_.empty());
            ++stats_.obligations;
            queue_.emplace_back(std::move(start.value()), cur_frame_);
            while (true) {
                assert(!queue_.empty());
                auto path = expansion_loop();
                if (!path.has_value()) {
                    // Trial ended without producing a path: start a new one.
                    break;
                }
                ++stats_.paths;
                auto cut = accept_path_(path.value());
                if (!cut.has_value()) {
                    std::cout << "ic3: found a property-satisfying path."
                              << std::endl;
                    stats_.total_time.stop();
                    return path;
                }
                // The checker rejected the path: discard all obligations
                // from the reported index onwards and resume expansion with
                // whatever is left (if anything).
                queue_.erase(queue_.begin() + cut.value(), queue_.end());
                if (cut.value() == 0u) {
                    break;
                }
            }
        }
    }

    /// Read-only access to the statistics object; operator() resets it at
    /// the start of every run and updates it while searching.
    [[nodiscard]]
    const Statistics& get_statistics() const
    {
        return stats_;
    }

    /// Returns a copy of the frame refiner held by this object.
    ///
    /// The return type carries no top-level `const`: on a by-value return it
    /// gives callers no protection and only prevents them from moving out of
    /// the returned temporary.
    [[nodiscard]]
    _FramesRefiner frames_refiner() const
    {
        return refine_frames_;
    }

private:
    /// Expands the obligation on top of the queue until either a path has
    /// been found or the queue has run empty.
    ///
    /// Returns the obligation states followed by the final successor when a
    /// successor of a frame-1 obligation passes the frame-0 test; returns an
    /// empty optional when every obligation has been popped.
    /// Precondition: the queue is not empty.
    [[nodiscard]]
    std::optional<Path> expansion_loop()
    {
        assert(!queue_.empty());
        // `obligation` always points at queue_.back(); it is re-fetched after
        // every push/pop because those may invalidate the pointer.
        obligation_type* obligation = &queue_.back();
        for (;;) {
            assert(obligation->frame >= 1u);

            // run trial down until hitting only blocked states or finding an
            // avoid state
        expansion_start:
#if POLICE_VERBOSITY
            if (should_print_stats()) {
                last_printed_ = stats_.clauses;
                stats_.print_status_line(std::cout);
            }
#endif
            POLICE_DEBUG_MSG(
                "expanding "
                << value_vector_printer(*vspace_, obligation->state) << " (obl="
                << obligation << ", frame=" << obligation->frame << ")" << "\n")

            {
                // `successors` lives only inside this scope; the gotos below
                // that leave the scope end its lifetime.
                auto successors = generate_successors_(obligation->state);
                for (const auto& succ : successors) {
                    // remember which label is being tried; trace_path() reads
                    // it when a path is assembled
                    obligation->label = succ.label;
                    POLICE_DEBUG_MSG(
                        " - successor via action "
                        << succ.label << ": "
                        << value_vector_printer(*vspace_, succ.state));
                    if (!is_in_frame(succ.state, obligation->frame - 1)) {
                        POLICE_DEBUG_MSG(
                            " -> rejected from frame "
                            << (obligation->frame - 1) << "\n");
                        continue;
                    } else if (obligation->frame == 1u) {
                        // successor passed the frame-0 test, i.e. the avoid
                        // expression evaluated to true on it
                        POLICE_DEBUG_MSG(" -> *satisfies avoid condition*\n");
                        auto path = trace_path();
                        // NOTE(review): `succ` is bound by const reference, so
                        // this std::move presumably ends up copying the state
                        // (unless `state` is declared mutable) -- confirm
                        // whether a move was intended. Same applies to the
                        // emplace_back into queue_ below.
                        path.emplace_back(std::move(succ.state));
                        return path;
                    } else {
                        // successor is in the next lower frame: descend
                        POLICE_DEBUG_MSG(" -> new obligation\n");
                        ++stats_.obligations;
                        queue_.emplace_back(
                            std::move(succ.state),
                            obligation->frame - 1);
                        obligation = &queue_.back();
                        goto expansion_start;
                    }
                }

                // reaching this point means no successor was accepted
                assert(std::ranges::all_of(successors, [&](const auto& succ) {
                    return !is_in_frame(succ.state, obligation->frame - 1);
                }));
            }

            POLICE_DEBUG_MSG(
                "backtracking from "
                << value_vector_printer(*vspace_, obligation->state) << " (obl="
                << obligation << ", frame=" << obligation->frame << ")" << "\n")

            POLICE_DEBUG_IFMSG(
                !is_in_frame_(obligation->state, obligation->frame),
                "  skipping refinement: state already blocked from frame\n")

            // backtrack: all successors blocked -> refine frame
            // check if state is rejected already
            if (is_in_frame(obligation->state, obligation->frame)) {
                POLICE_DEBUG_MSG("starting frame refinement\n")

                const size_t blocked_at_frame =
                    block_state_at_frame(obligation->state, obligation->frame);

                assert(blocked_at_frame >= obligation->frame);
                assert(!is_in_frame(obligation->state, blocked_at_frame));
                // check and repush state for expansion
                // (only if the frame after the blocking one lies strictly
                // below the current frame)
                if (rescheduling_enabled_ &&
                    blocked_at_frame + 1 < cur_frame_) {
                    ++stats_.rescheduled;
                    obligation->frame = blocked_at_frame + 1;
                    goto expansion_start;
                }
            } else if (rescheduling_enabled_) {
                // state is rejected -> find frame where it can be rescheduled
                // NOTE(review): the loop keeps advancing while the state *is*
                // in the frame and stops at the first frame that rejects it
                // (or once past cur_frame_). Confirm this matches the intent
                // stated in the comment above.
                for (++obligation->frame;
                     obligation->frame <= cur_frame_ &&
                     is_in_frame(obligation->state, obligation->frame);
                     ++obligation->frame) {
                }
                if (obligation->frame <= cur_frame_) {
                    ++stats_.rescheduled;
                    goto expansion_start;
                }
            }

            // pop obligation if not reexpanding
        backtrack_start:
            queue_.pop_back();
            if (queue_.empty()) {
                return std::nullopt;
            }
            obligation = &queue_.back();
            // check if state is now blocked due to earlier frame refinement
            if (!is_in_frame(obligation->state, obligation->frame)) {
                goto backtrack_start;
            }
        }
        POLICE_UNREACHABLE();
    }

    /// Invokes the frame refiner on (`state`, `frame`) and forwards its
    /// result. The call is covered by the refinement-time scope object and
    /// the clause counter is incremented once per invocation.
    size_t block_state_at_frame(const flat_state& state, size_t frame)
    {
        // scope object stays alive until the function returns
        auto refinement_timer = stats_.refinement_time.scope();
        ++stats_.clauses;
        const size_t blocked_at = refine_frames_(state, frame);
        return blocked_at;
    }

    /// Builds a path from the obligation queue, front to back: one entry per
    /// obligation, made of its state and the label currently stored in it.
    [[nodiscard]]
    Path trace_path() const
    {
        Path result;
        // Range-for avoids comparing a narrower `unsigned int` index against
        // the queue's `size_t` size.
        for (const auto& obligation : queue_) {
            result.emplace_back(obligation.state, obligation.label);
        }
        return result;
    }

    /// Frame membership test. Indices >= 1 are delegated to the frames
    /// checker; index 0 is decided by evaluating the avoid expression on the
    /// state's variable values.
    [[nodiscard]]
    bool is_in_frame(const flat_state& state, size_t frame_idx) const
    {
        if (frame_idx != 0u) {
            return is_in_frame_(state, frame_idx);
        }
        // frame 0: the state is a member iff the avoid expression holds
        return static_cast<bool>(expressions::evaluate(
            avoid_,
            [&state](size_t var) { return state[var]; }));
    }

    /// Decides whether a status line is due: the clause count must have
    /// changed since the last print and must be a multiple of FREQUENCY.
    ///
    /// The periodicity test is applied to the live clause counter. Testing
    /// `last_printed_` instead (as before) made the check fail permanently
    /// as soon as a status line had been printed at a clause count that is
    /// not a multiple of FREQUENCY.
    [[nodiscard]]
    bool should_print_stats() const
    {
        return last_printed_ != stats_.clauses &&
               stats_.clauses % FREQUENCY == 0;
    }

    // Modulus used by should_print_stats() to throttle status-line output.
    static constexpr size_t FREQUENCY = 100;

    // Obligations of the current trial; used like a stack (push/pop at the
    // back), front-to-back order is the order of the traced path.
    vector<_detail::Obligation> queue_;
    // Counters and timers; reset at the start of every operator() call.
    Statistics stats_;
    // Index of the frame that new trials start from.
    size_t cur_frame_ = 0;
    // Value of stats_.clauses at the time the status line was last printed.
    mutable size_t last_printed_ = 0;

    // Dereferenced when printing states in debug messages.
    const VariableSpace* vspace_;

    // Collaborators supplied through the constructor.
    _StartGenerator generate_start_;
    _SuccessorGenerator generate_successors_;
    _FramesInitializer init_frames_;
    _FrameAdder add_frame_;
    _FramesChecker is_in_frame_;
    _FramesRefiner refine_frames_;
    _PathChecker accept_path_;
    // Evaluated on a state to decide membership in frame 0.
    expressions::Expression avoid_;

    // When true, an obligation may be re-expanded at a higher frame instead
    // of being popped.
    bool rescheduling_enabled_ = false;
};

} // namespace police::ic3
