#pragma once

#include "police/expressions/variable.hpp"
#include "police/smt.hpp"
#include "police/storage/flat_state.hpp"
#include "police/storage/path.hpp"
#include "police/utils/rng.hpp"
#include "police/verifiers/search/concepts.hpp"

#include <cstddef>
#include <iostream>
#include <limits>
#include <memory>
#include <optional>
#include <ostream>
#include <utility>

namespace police::sampler {

/// Counters collected by one invocation of `Sampler::operator()`.
///
/// All counters start at zero. A trace ends in exactly one of the outcomes
/// avoid / goal / terminal / limit; avoid is checked before goal, so a state
/// satisfying both is counted as avoid.
///
/// NOTE(review): `length` sums transitions per trace, whereas the per-outcome
/// `length_*` sums count visited states (transitions + 1). The per-outcome
/// sums therefore do not add up to `length` — confirm this is intended.
struct Statistics {
    /// Number of traces sampled.
    unsigned long long total{};
    /// Traces that stopped in a state accepted by the avoid checker.
    unsigned long long avoid{};
    /// Traces that stopped in a state accepted by the goal checker.
    unsigned long long goal{};
    /// Traces that stopped in a state without successors.
    unsigned long long terminal{};
    /// Traces cut off by the depth limit.
    unsigned long long limit{};
    /// Sum over all traces of the number of transitions taken.
    unsigned long long length{};
    /// Sum of visited-state counts over traces ending in an avoid state.
    unsigned long long length_avoid{};
    /// Sum of visited-state counts over traces ending in a goal state.
    unsigned long long length_goal{};
    /// Sum of visited-state counts over traces ending in a terminal state.
    unsigned long long length_terminal{};
    /// How often the initial-state enumeration was exhausted and restarted.
    unsigned long long initial_state_resets{};
};

/// Random-walk trace sampler.
///
/// Each call to operator() draws up to `num` initial states from an SMT
/// encoding of the initial-state set. From each initial state it follows a
/// uniformly chosen successor at every step. A trace stops when the current
/// state satisfies the avoid checker, satisfies the goal checker, has no
/// successors, or the depth limit is reached.
///
/// Initial states are enumerated without repetition: after each draw a
/// blocking clause excluding that exact assignment is added to the solver.
/// When no further initial state exists, the solver is rolled back to its
/// state at construction time and enumeration starts over.
///
/// Progress is printed to std::cout for every sample.
///
/// @tparam SuccessorsGenerator callable `flat_state -> range`, where every
///         element exposes `state` and `label` members.
/// @tparam AvoidChecker predicate on `flat_state`; checked before GoalChecker.
/// @tparam GoalChecker predicate on `flat_state`.
template <
    search::successor_generator<flat_state> SuccessorsGenerator,
    typename AvoidChecker,
    typename GoalChecker>
class Sampler {
public:
    /// @param initial_states solver whose current constraints describe the
    ///        initial states. Must be non-null. A snapshot is pushed here so
    ///        that reset_initial_states() can discard the blocking clauses
    ///        added later (presumably push/pop scoping — confirm against SMT).
    /// @param successor_generator produces the successors of a state.
    /// @param avoid_checker stops a trace (outcome "avoid") when it accepts.
    /// @param goal_checker stops a trace (outcome "goal") when it accepts.
    /// @param rng used to pick one successor per step.
    /// @param num number of traces sampled per call to operator().
    /// @param depth_limit maximum number of transitions per trace. The state
    ///        reached by the last allowed transition is not checked.
    Sampler(
        std::shared_ptr<SMT> initial_states,
        SuccessorsGenerator successor_generator,
        AvoidChecker avoid_checker,
        GoalChecker goal_checker,
        RNG rng,
        size_t num,
        size_t depth_limit = std::numeric_limits<size_t>::max())
        : initial_states_(std::move(initial_states))
        , successor_generator_(std::move(successor_generator))
        , avoid_checker_(std::move(avoid_checker))
        , goal_checker_(std::move(goal_checker))
        , rng_(std::move(rng))
        , num_(num)
        , depth_limit_(depth_limit)
    {
        initial_states_->push_snapshot();
    }

    /// Samples `num` traces and returns the collected statistics.
    ///
    /// Returns immediately, with every counter zero, if no initial state
    /// exists.
    Statistics operator()()
    {
        Statistics result;
        reset_initial_states();
        // Reused across steps and traces; cleared after every selection.
        vector<StateLabelPair> successors;
        while (result.total < num_) {
            auto istate = get_initial_state();
            if (!istate.has_value()) {
                if (result.total == 0) {
                    // The initial-state constraints have no model at all.
                    return result;
                }
                // Every initial state has been drawn once: drop the blocking
                // clauses and enumerate again. No trace was produced.
                ++result.initial_state_resets;
                reset_initial_states();
                continue;
            }
            std::cout << "sample #" << result.total << std::endl;
            flat_state state(std::move(istate.value()));
            // d = number of transitions taken so far in this trace.
            size_t d = 0;
            for (; d < depth_limit_; ++d) {
                if (avoid_checker_(state)) {
                    ++result.avoid;
                    result.length_avoid += d + 1;
                    std::cout << "=> sample #" << result.total << ": "
                              << " avoid satisfied" << std::endl;
                    break;
                }
                if (goal_checker_(state)) {
                    ++result.goal;
                    result.length_goal += d + 1;
                    std::cout << "=> sample #" << result.total << ": "
                              << " goal satisfied" << std::endl;
                    break;
                }
                auto srange = successor_generator_(state);
                for (const auto& succ : srange) {
                    successors.emplace_back(succ.state, succ.label);
                }
                if (successors.empty()) {
                    ++result.terminal;
                    result.length_terminal += d + 1;
                    std::cout << "=> sample #" << result.total << ": "
                              << " terminal" << std::endl;
                    break;
                }
                // NOTE(review): bounds passed as [0, size - 1], so rng_ is
                // presumably inclusive on both ends — confirm against RNG.
                const size_t selected = rng_(0, successors.size() - 1);
                state = std::move(successors[selected].state);
                successors.clear();
            }
            if (d >= depth_limit_) {
                ++result.limit;
            }
            result.length += d;
            ++result.total;
        }

        return result;
    }

private:
    /// Rolls the solver back to the constraints present at construction time,
    /// discarding all blocking clauses, and opens a fresh snapshot.
    void reset_initial_states()
    {
        initial_states_->pop_snapshot();
        initial_states_->push_snapshot();
    }

    /// Draws an initial state that has not been returned since the last
    /// reset, and blocks it from being drawn again.
    ///
    /// Not `const`: it adds a blocking constraint to the shared solver.
    ///
    /// @return the state, or std::nullopt if the solver reports UNSAT.
    ///         NOTE(review): any status other than UNSAT is treated as
    ///         satisfiable — confirm get_model() is valid for those.
    std::optional<flat_state> get_initial_state()
    {
        const auto status = initial_states_->solve();
        if (status == SMT::Status::UNSAT) {
            return std::nullopt;
        }

        auto model = initial_states_->get_model();
        const size_t num_vars = model.size();
        flat_state state(num_vars);
        if (num_vars == 0) {
            // With no variables there is only the empty assignment, and no
            // disequality can exclude it. Return it without a blocking clause
            // (the original indexed model[0] here, which is out of bounds).
            return state;
        }

        // Blocking clause: at least one variable must differ from this model.
        state[0] = model[0];
        expressions::Expression blocker = expressions::Variable(0) != model[0];
        for (size_t i = 1; i < num_vars; ++i) {
            state[i] = model[i];
            blocker = blocker || expressions::Variable(i) != model[i];
        }
        initial_states_->add_constraint(std::move(blocker));

        return state;
    }

    /// Solver encoding the initial states plus the current blocking clauses.
    std::shared_ptr<SMT> initial_states_;
    SuccessorsGenerator successor_generator_;
    AvoidChecker avoid_checker_;
    GoalChecker goal_checker_;
    RNG rng_;
    /// Number of traces per call to operator().
    unsigned long long num_;
    /// Maximum number of transitions per trace.
    size_t depth_limit_;
};

} // namespace police::sampler

namespace police {
/// Stream-insertion operator for sampler::Statistics (defined out of line;
/// the definition is not in this header).
///
/// NOTE(review): this overload is declared in `police`, but `Statistics` is
/// declared in `police::sampler`. Argument-dependent lookup only searches the
/// class's innermost enclosing namespace (`police::sampler`), so
/// `out << stats` written outside `namespace police` will not find this
/// overload without `using police::operator<<;`. Moving the declaration
/// requires moving the out-of-line definition as well.
std::ostream& operator<<(std::ostream& out, const sampler::Statistics& stats);
}
