#pragma once

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <memory>
#include <utility>

#include "police/expressions/expression.hpp"
#include "police/storage/flat_state.hpp"
#include "police/storage/segmented_vector.hpp"
#include "police/verifiers/ic3/concepts.hpp"
#include "police/verifiers/ic3/epic3.hpp"
#include "police/verifiers/ic3/frame.hpp"
#include "police/verifiers/ic3/start_generator.hpp"

namespace police::ic3 {

namespace _detail::smt {

/// Cube inserter that, on top of the generic ic3::CubeInserter, mirrors every
/// blocked cube into a SAT interface and, for cubes placed in the last frame,
/// into a start generator.
template <concepts::SatInterface _Sat>
class CubeInserter : private ic3::CubeInserter {
    using base_inserter = police::ic3::CubeInserter;

public:
    /// @param frames          frame sequence, forwarded to the generic inserter
    /// @param cube_db         cube storage, forwarded to the generic inserter
    /// @param sat_interface   SAT interface that is told about every blocked cube
    /// @param start_generator start generator that is told about cubes blocked
    ///                        in the last frame
    CubeInserter(
        segmented_vector<Frame>* frames,
        CubeDatabase* cube_db,
        _Sat sat_interface,
        StartGenerator start_generator)
        : base_inserter(cube_db, frames)
        , sat_(std::move(sat_interface))
        , start_generator_(std::move(start_generator))
    {
    }

    /// Blocks `cube` at frame `frame_id`: records it via the generic inserter,
    /// then informs the SAT interface and, when `frame_id` denotes the last
    /// frame, the start generator as well.
    void operator()(Cube cube, size_t frame_id)
    {
        base_inserter::operator()(cube, frame_id);
        sat_.set_blocked(cube, frame_id);
        // Frame ids are 1-based, so the id of the last frame equals the count.
        const bool targets_last_frame = frame_id == frames_->size();
        if (targets_last_frame) {
            start_generator_.set_blocked(cube);
        }
    }

private:
    _Sat sat_;
    StartGenerator start_generator_;
};

/// Opens a new, empty frame and pushes blocked cubes forward.
///
/// Holds raw, non-owning pointers to the frame sequence and cube database.
/// The SAT interface and start generator are copied both into the embedded
/// CubeInserter and into this object; NOTE(review): this presumably relies on
/// both types being handles to shared state — confirm.
template <concepts::SatInterface _SAT>
class FramesAdder {
public:
    using sat_interface_type = _SAT;

    /// @param frames          frame sequence (not owned)
    /// @param cube_db         cube storage (not owned)
    /// @param sat_interface   SAT interface used for the blocking checks
    /// @param start_generator start generator, cleared whenever a frame is added
    explicit FramesAdder(
        segmented_vector<Frame>* frames,
        CubeDatabase* cube_db,
        sat_interface_type sat_interface,
        StartGenerator start_generator)
        // block_cube_ is declared first, so it copies the SAT interface and
        // start generator before they are moved into the members below.
        : block_cube_(frames, cube_db, sat_interface, start_generator)
        , sat_interface_(std::move(sat_interface))
        , frames_(frames)
        , cube_db_(cube_db)
        , start_generator_(std::move(start_generator))
    {
    }

    /// Appends an empty frame, clears the start generator, and tries to push
    /// every cube of each frame k (all frames but the new last one) to k + 1.
    ///
    /// @return true as soon as a frame becomes empty while pushing (invariant
    ///         found); false otherwise, after garbage-collecting cubes left
    ///         dangling by subsumption pruning.
    bool operator()()
    {
        // add empty frame (frame ids are 1-based: frames[i] has id i + 1)
        frames_->emplace_back(cube_db_, frames_->size() + 1);
        start_generator_.clear();

        // push cubes from lowest to highest frame
        auto& frames = *frames_;
        assert(frames.size() > 0);
        // Reused across iterations so the cube buffer is not reallocated.
        Cube reason;
        for (size_t k = 1; k < frames.size(); ++k) {
            auto& frame = frames[k - 1];
            for (auto it = frame.begin(); it != frame.end();) {
                const auto cube_id = *it;
                reason = cube_db_->get_cube(cube_id);
                const auto blocked = sat_interface_.is_blocked(reason, k + 1);
                if (blocked.first) {
                    assert(blocked.second > k);
                    assert(cube_db_->get_cube(cube_id) <= reason);
                    // block the returned cube at target frame
                    block_cube_(std::move(reason), blocked.second);
                    // The insertion may prune subsumed cubes from this frame
                    // and invalidate `it`. Frames are sorted by cube id, so
                    // resume at the first id strictly greater than the one
                    // just processed. upper_bound (rather than lower_bound)
                    // guarantees progress even if cube_id was not pruned;
                    // otherwise the same cube would be re-checked forever.
                    it = std::upper_bound(frame.begin(), frame.end(), cube_id);
                } else {
                    ++it;
                }
            }
            if (frame.empty()) {
                // found invariant
                return true;
            }
        }

        // cleanup if some cubes were pruned, being subsumed after cube pushing
        if (cube_db_->num_dangling_cubes() > 0u) {
            collect_garbage();
        }

        return false;
    }

private:
    /// Removes dangling cubes from the database, then rebuilds every frame's
    /// cube list and the SAT interface's blocked cubes from the remaining
    /// database entries. Precondition: at least one dangling cube exists.
    void collect_garbage()
    {
        assert(cube_db_->num_dangling_cubes() != 0u);
        // prune dangling cubes (cubes without frame association)
        cube_db_->collect_garbage();
        // clear frame references
        auto& frames = *frames_;
        for (size_t frame = 0; frame < frames.size(); ++frame) {
            frames[frame].clear();
        }
        // clear sat interface
        sat_interface_.clear_frames();
        // Re-insert non-dangling cubes at their new (old if unchanged by
        // pushing) frame. Visiting ids in increasing order keeps every frame
        // sorted by cube id, which the binary search in operator() relies on.
        for (auto cube_id = 0u; cube_id < cube_db_->size(); ++cube_id) {
            const size_t frame_id = cube_db_->get_frame(cube_id);
            const auto& cube = cube_db_->get_cube(cube_id);
            assert(frame_id > 0u);
            assert(frame_id <= frames.size());
            auto& frame = frames[frame_id - 1];
            frame.data()->push_back(cube_id);
            sat_interface_.set_blocked(cube, frame_id);
        }
    }

    CubeInserter<_SAT> block_cube_;
    sat_interface_type sat_interface_;
    segmented_vector<Frame>* frames_;
    CubeDatabase* cube_db_;
    StartGenerator start_generator_;
};

/// Blocks a single state: generalizes it into a cube, moves the cube forward
/// for as long as the SAT interface reports it blocked at the next frame, and
/// inserts it at the frame reached.
template <
    concepts::SatInterface _SatInterface,
    concepts::ReasonGeneralizer<_SatInterface> _ReasonGeneralizer>
class FrameRefiner {
public:
    /// @param frames          frame sequence (not owned)
    /// @param cube_db         cube storage (not owned)
    /// @param start_generator start generator, handed to the cube inserter
    /// @param sat             SAT interface used for generalization and checks
    /// @param generalizer     reason generalizer applied to each blocked state
    FrameRefiner(
        segmented_vector<Frame>* frames,
        CubeDatabase* cube_db,
        StartGenerator start_generator,
        _SatInterface sat,
        _ReasonGeneralizer generalizer)
        : frames_(frames)
        , cube_db_(cube_db)
        // Declared before start_generator_/sat_: receives copies before the
        // parameters are moved into those members.
        , block_cube_(frames, cube_db, sat, start_generator)
        , start_generator_(std::move(start_generator))
        , sat_(std::move(sat))
        , generalizer_(std::move(generalizer))
    {
    }

    /// Blocks `state`, starting the search at frame `frame`.
    /// @return the id of the frame the generalized cube was inserted at.
    size_t operator()(const flat_state& state, size_t frame)
    {
        Cube reason(state);
        generalizer_(sat_, reason, frame);
        const size_t target_frame = find_insertion_frame(reason, frame);
        block_cube_(std::move(reason), target_frame);
        return target_frame;
    }

private:
    /// Starting at `first_frame`, repeatedly asks the SAT interface whether
    /// `cube` is blocked at the next frame and jumps to the frame it reports,
    /// stopping when it is not blocked or the last frame has been reached.
    /// `cube` is passed by non-const reference and may be updated by the SAT
    /// interface's is_blocked (NOTE(review): confirm its contract).
    size_t find_insertion_frame(Cube& cube, size_t first_frame)
    {
        size_t current = first_frame;
        while (current < frames_->size()) {
            const auto answer = sat_.is_blocked(cube, current + 1);
            if (!answer.first) {
                return current;
            }
            assert(answer.second > current);
            current = answer.second;
        }
        return current;
    }

    segmented_vector<Frame>* frames_;
    CubeDatabase* cube_db_;
    CubeInserter<_SatInterface> block_cube_;
    StartGenerator start_generator_;
    _SatInterface sat_;
    _ReasonGeneralizer generalizer_;
};
} // namespace _detail::smt

template <
    search::successor_generator<flat_state> _SuccessorGenerator,
    concepts::SatInterface _SatInterface,
    concepts::ReasonGeneralizer<_SatInterface> _ReasonGeneralizer>
class ExplicitStateIC3smt
    : public ExplicitStateIC3<
          StartGenerator,
          _SuccessorGenerator,
          FramesInitializer,
          _detail::smt::FramesAdder<_SatInterface>,
          FramesChecker,
          _detail::smt::FrameRefiner<_SatInterface, _ReasonGeneralizer>> {
public:
    using sat_interface_type = _SatInterface;
    using generalizer_type = _ReasonGeneralizer;

    ExplicitStateIC3smt(
        const VariableSpace* vspace,
        expressions::Expression goal,
        StartGenerator start_generator,
        _SuccessorGenerator successor_generator,
        _SatInterface sat,
        _ReasonGeneralizer generalize,
        bool obligation_rescheduling)
        : ExplicitStateIC3smt(
              vspace,
              std::move(goal),
              std::move(start_generator),
              std::move(successor_generator),
              std::move(sat),
              std::move(generalize),
              obligation_rescheduling,
              new segmented_vector<Frame>(),
              new CubeDatabase())
    {
    }

private:
    ExplicitStateIC3smt(
        const VariableSpace* vspace,
        expressions::Expression goal,
        StartGenerator start_generator,
        _SuccessorGenerator successor_generator,
        _SatInterface sat,
        _ReasonGeneralizer generalize,
        bool obligation_rescheduling,
        segmented_vector<Frame>* frames,
        CubeDatabase* cubes)
        : ExplicitStateIC3<
              StartGenerator,
              _SuccessorGenerator,
              FramesInitializer,
              _detail::smt::FramesAdder<_SatInterface>,
              FramesChecker,
              _detail::smt::FrameRefiner<_SatInterface, _ReasonGeneralizer>>(
              vspace,
              start_generator,
              successor_generator,
              FramesInitializer(cubes, frames),
              _detail::smt::FramesAdder<_SatInterface>(
                  frames,
                  cubes,
                  sat,
                  start_generator),
              FramesChecker(cubes, frames),
              _detail::smt::FrameRefiner(
                  frames,
                  cubes,
                  start_generator,
                  sat,
                  generalize),
              goal,
              obligation_rescheduling)
        , frames_(frames)
        , cubes_(cubes)
    {
    }

    std::unique_ptr<segmented_vector<Frame>> frames_;
    std::unique_ptr<CubeDatabase> cubes_;
};

} // namespace police::ic3
