#pragma once

#include "police/action.hpp"
#include "police/model.hpp"
#include "police/storage/path.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasoner.hpp"
#include "police/verifiers/ic3/syntactic/policy_reasons.hpp"

#include <cassert>
#include <cstddef>
#include <memory>
#include <optional>
#include <utility>

#define ENABLE_DEBUG_PRINTS 0

#if ENABLE_DEBUG_PRINTS
#include "police/utils/io.hpp"
#endif

namespace police::ic3::syntactic {

/// Order in which a path's steps are validated against a policy.
/// Enumerators are spelled out with their (default) underlying values.
enum class PolicyPathValidationDirection : int {
    FORWARD = 0,  ///< check from the first step towards the last
    BACKWARD = 1, ///< check from the last step back towards the first
};

template <typename Policy>
class PolicyPathChecker {
public:
    PolicyPathChecker(
        PolicyPathValidationDirection dir,
        const Policy policy,
        PolicyReasons* reasons,
        const Model* model,
        std::shared_ptr<syntactic::PolicyReasoner> reasoner)
        : reasoner_(std::move(reasoner))
        , model_(model)
        , policy_(std::move(policy))
        , reasons_(reasons)
        , direction_(dir)
    {
    }

    [[nodiscard]]
    std::optional<size_t> operator()(const Path& path) const
    {
        if (path.size() <= 1u) {
            return std::nullopt;
        }
#if ENABLE_DEBUG_PRINTS
        std::cout << "checking the following path (length=" << path.size()
                  << ")\n";
        for (size_t i = 0; i < path.size() - 1; ++i) {
            std::cout << print_sequence(path[i].state) << "\n";
            std::cout << "action " << path[i].label << "\n";
        }
        std::cout << print_sequence(path.back().state) << std::endl;
#endif
        int i = 0;
        int n = path.size() - 1;
        int d = 1;
        if (direction_ == PolicyPathValidationDirection::BACKWARD) {
            i = n - 1;
            n = -1;
            d = -1;
        }
        for (; i != n; i += d) {
            if (path[i].label == SILENT_ACTION) {
                continue;
            }
            const size_t chosen = policy_(path[i].state);
#if ENABLE_DEBUG_PRINTS
            std::cout << "Step " << i << " => policy selection: " << chosen
                      << " | path label: " << path[i].label << std::endl;
#endif
            if (chosen != path[i].label) {
                reasoner_->prepare(path[i].state);
                const auto options =
                    reasoner_->get_reason(path[i].state, {}, path[i].label);

#if ENABLE_DEBUG_PRINTS
                // std::cout << "options:\n";
                // for (const auto& x : options) {
                //     std::cout << print_sequence(x) << std::endl;
                // }
#endif

                assert(!options.empty());
                Cube cube;
                for (const auto& hs : options.front()) {
                    cube.shrink(
                        hs.variable_id,
                        Interval(
                            hs.type == VariableCondition::EQUALITY ||
                                    hs.type == VariableCondition::LOWER_BOUND
                                ? path[i].state[hs.variable_id]
                                : model_->variables[hs.variable_id]
                                      .type.get_lower_bound(),
                            hs.type == VariableCondition::EQUALITY ||
                                    hs.type == VariableCondition::UPPER_BOUND
                                ? path[i].state[hs.variable_id]
                                : model_->variables[hs.variable_id]
                                      .type.get_upper_bound()));
                }

#if ENABLE_DEBUG_PRINTS
                std::cout << "blocking label " << path[i].label << " for "
                          << cube.dump(model_->variables) << std::endl;
#endif

                reasons_->block(std::move(cube), path[i].label);
                assert(reasons_->is_blocked(path[i].state, path[i].label));
                return i;
            }
        }
        return std::nullopt;
    }

private:
    std::shared_ptr<syntactic::PolicyReasoner> reasoner_;
    const Model* model_;
    const Policy policy_;
    PolicyReasons* reasons_;
    PolicyPathValidationDirection direction_;
};

} // namespace police::ic3::syntactic

#undef ENABLE_DEBUG_PRINTS
