#pragma once

#include "police/base_types.hpp"
#include "police/storage/variable_space.hpp"
#include "police/verifiers/ic3/cube.hpp"
#include "police/verifiers/ic3/sat_based/generalizer/utils.hpp"

#include <cassert>
#include <concepts>
#include <cstdint>
#include <numeric>
#include <type_traits>

namespace police::ic3 {

namespace details {
template <typename T>
concept unsat_core_extractable = requires(T t) {
    { t.get_unsat_core() } -> std::convertible_to<Cube>;
};

template <typename T>
constexpr bool unsat_core_extractable_v = false;

template <unsat_core_extractable T>
constexpr bool unsat_core_extractable_v<T> = true;

template <typename T>
using unsat_core_extractable_t =
    std::integral_constant<bool, unsat_core_extractable_v<T>>;
} // namespace details

/**
 * Greedy cube generalizer.
 *
 * Visits variables in a fixed order and tries to drop each one from a blocked
 * cube. A variable is dropped only if, with its interval replaced by the
 * result of `get_domain` for its type, the goal checker still returns false
 * for the cube and the solver still reports the cube as blocked.
 *
 * @tparam GoalChecker callable invoked as `goal_(cube)`; its result is used
 *                     in a boolean context.
 */
template <typename GoalChecker>
class GreedyGeneralizer {
public:
    /// Not referenced anywhere inside this class.
    constexpr static std::uint32_t CATEGORIAL_MAX_WIDTH = 100;

    /**
     * Uses the identity variable order `0, 1, ..., variables->size() - 1`.
     *
     * @param variables variable space; dereferenced here and on every
     *                  generalization call, so it must be non-null and
     *                  outlive this object.
     * @param goal      goal checker, stored by value.
     */
    GreedyGeneralizer(const VariableSpace* variables, GoalChecker goal)
        : goal_(std::move(goal))
        , variables_(variables)
        , order_(variables->size())
    {
        // Seed with a size_t so the running counter has the element type of
        // order_ instead of being an int that is converted on every store.
        std::iota(order_.begin(), order_.end(), size_t{0});
    }

    /**
     * Uses a caller-supplied variable order.
     *
     * @param variables      variable space; must outlive this object.
     * @param goal           goal checker, stored by value.
     * @param variable_order ids of the variables to try, in the order in
     *                       which they are tried.
     */
    GreedyGeneralizer(
        const VariableSpace* variables,
        GoalChecker goal,
        vector<size_t> variable_order)
        : goal_(std::move(goal))
        , variables_(variables)
        , order_(std::move(variable_order))
    {
    }

    /**
     * Same as the three-argument constructor. The trailing `vector<bool>`
     * argument is accepted but ignored (it is not stored in
     * `is_categorial_`).
     * NOTE(review): presumably kept for signature compatibility with other
     * generalizers -- confirm.
     */
    GreedyGeneralizer(
        const VariableSpace* variables,
        GoalChecker goal,
        vector<size_t> variable_order,
        vector<bool>)
        : goal_(std::move(goal))
        , variables_(variables)
        , order_(std::move(variable_order))
    {
    }

    /**
     * Generalizes `reason` in place.
     *
     * Preconditions and postconditions are checked by assertions only:
     * `sat.is_blocked(reason, frame_id).first` must hold on entry and on
     * exit, and `goal_(reason)` must be false on exit.
     *
     * @param sat      solver providing `is_blocked(cube, frame_id)`, whose
     *                 result exposes a boolean-testable `.first`.
     * @param reason   cube to generalize; entries may be erased.
     * @param frame_id frame index forwarded to `sat.is_blocked`.
     */
    template <typename Sat>
    void operator()(Sat&& sat, Cube& reason, size_t frame_id)
    {
        using police::operator<<;
        assert(sat.is_blocked(reason, frame_id).first);
        // The solver is passed as an lvalue: it is used again below, and the
        // private overloads only ever bind it by reference.
        //
        // NOTE(review): the dispatch tag is hard-coded to std::false_type, so
        // the unsat-core overload is never selected from here even when
        // details::unsat_core_extractable_t<Sat> would be std::true_type.
        // Confirm whether that is intentional.
        generalize(std::false_type(), sat, reason, frame_id);
        assert(sat.is_blocked(reason, frame_id).first);
        assert(!goal_(reason));
    }

private:
    /**
     * Unsat-core-aware variant: whenever a candidate cube is blocked,
     * `reason` is additionally passed to `update_from_unsat_core` together
     * with `sat.get_unsat_core()`.
     */
    template <typename Sat>
    void generalize(std::true_type, Sat&& sat, Cube& reason, size_t frame_id)
    {
        auto check_sat = [&](Cube& cube) {
            if (goal_(cube)) {
                return false;
            }
            const bool blocked = sat.is_blocked(cube, frame_id).first;
            if (blocked) {
                update_from_unsat_core(
                    *variables_,
                    goal_,
                    reason,
                    sat.get_unsat_core());
            }
            return blocked;
        };
        generalize<true>(check_sat, reason);
    }

    /**
     * Plain variant: a candidate cube is accepted when the goal checker
     * returns false for it and the solver reports it as blocked.
     */
    template <typename Sat>
    void generalize(std::false_type, Sat&& sat, Cube& reason, size_t frame_id)
    {
        auto check_sat = [&](Cube& cube) {
            return !goal_(cube) && sat.is_blocked(cube, frame_id).first;
        };
        generalize<false>(check_sat, reason);
    }

    /**
     * Core greedy loop shared by both variants.
     *
     * @tparam UC    true when `check` may modify `reason` on success (the
     *               unsat-core variant); iterators are then re-acquired
     *               before erasing.
     * @param check  predicate invoked as `check(reason)`; a true result
     *               means the relaxed cube is acceptable.
     * @param reason cube being generalized in place.
     */
    template <bool UC, typename CheckSat>
    void generalize(CheckSat&& check, Cube& reason)
    {
        for (const auto var_id : order_) {
            auto it = reason.find(var_id);
            if (it == reason.end()) {
                // variable is not constrained by the cube
                continue;
            }

            // Tentatively relax the variable: park its current interval in
            // `domain` and put the type's domain in its place.
            const auto& t = variables_->get_type(var_id);
            auto domain = get_domain(t);
            std::swap(const_cast<Interval&>(it->second), domain);

            if (!check(reason)) {
                // A failed check does not modify the cube, so the iterator
                // is still valid; put the original interval back.
                std::swap(const_cast<Interval&>(it->second), domain);
                continue;
            }

            if constexpr (UC) {
                // iterator might have gotten invalidated due to unsat core
                // update; that update might have even removed the var_id
                // from reason already
                it = reason.find(var_id);
                if (it == reason.end()) {
                    continue;
                }
            }
            reason.erase(it);
        }
    }

    // Disabled (not compiled) helpers that widen a single interval bound of
    // one variable instead of dropping the variable altogether.
    // NOTE(review): before re-enabling, check the assertions in the two
    // *_remove helpers -- they look inverted relative to the direction in
    // which the midpoint moves -- and check that their loops terminate.
#if 0
    template <typename CheckSat>
    void generalize_interval(CheckSat check, Cube& reason, size_t var_id)
    {
        const auto& t = variables_->get_type(var_id);
        assert(is_bounded_int_type(t));
        assert(var_id < reason.size());
        bring_down_lb_remove(check, reason, var_id, get_lb(t));
        bring_up_ub_remove(check, reason, var_id, get_ub(t));
    }

    template <typename CheckSat>
    void
    bring_down_lb_add(CheckSat check, Cube& reason, size_t var_id, int_t lb)
    {
        auto& iset = reason[var_id];
        Interval backup = iset;
        const int_t center = iset.lb;
        for (int_t step = 1;; step <<= 1) {
            const int_t new_lb = std::max(lb, center - step);
            iset.lb = Value(new_lb);
            if (!check(reason)) {
                iset.swap(backup);
                break;
            }
            if (new_lb == lb) {
                break;
            }
            backup = iset;
        }
    }

    template <typename CheckSat>
    void
    bring_down_lb_remove(CheckSat check, Cube& reason, size_t var_id, int_t lb)
    {
        auto& iset = reason[var_id];
        Interval backup = iset;
        iset.lb = lb;
        for (; !check(reason);) {
            lb = (static_cast<int_t>(backup.lb) + lb) / 2;
            assert(lb >= static_cast<int_t>(backup.lb));
            iset.lb = lb;
        }
    }

    template <typename CheckSat>
    void bring_up_ub(CheckSat check, Cube& reason, size_t var_id, int_t ub)
    {
        auto& iset = reason[var_id];
        Interval backup = iset;
        const int_t center = iset.ub;
        for (int_t step = 1;; step <<= 1) {
            const int_t new_ub = std::min(ub, center + step);
            iset.ub = Value(new_ub);
            if (!check(reason)) {
                iset.swap(backup);
                break;
            }
            if (new_ub == ub) {
                break;
            }
            backup = iset;
        }
    }

    template <typename CheckSat>
    void
    bring_up_ub_remove(CheckSat check, Cube& reason, size_t var_id, int_t ub)
    {
        auto& iset = reason[var_id];
        Interval backup = iset;
        iset.ub = ub;
        for (; !check(reason);) {
            ub = (static_cast<int_t>(backup.ub) + ub) / 2;
            assert(ub <= static_cast<int_t>(backup.ub));
            iset.ub = ub;
        }
    }
#endif

    /// Goal predicate; cubes for which it returns true are never accepted.
    GoalChecker goal_;
    /// Non-owning pointer to the variable space used to look up types.
    const VariableSpace* variables_;
    /// Variable ids in the order in which removal is attempted.
    vector<size_t> order_;
    /// Never written or read in the visible code.
    vector<bool> is_categorial_;
};

} // namespace police::ic3
