#include "police/verifiers/search/state_novelty.hpp"

#include "police/storage/variable_space.hpp"
#include "police/verifiers/search/novelty_queue.hpp"

#include <type_traits>

namespace police::search {

StateFeatureIterator::StateFeatureIterator(
    const vector<search::feature_t>* features)
    : features_(features)
    , indices_({0})
{
}

// Advances to the next feature subset. Subsets are represented by strictly
// increasing index tuples into *features_. They are enumerated by size (all
// 1-subsets, then all 2-subsets, ...), and within one size in lexicographic
// order. After the last subset of size k, the iterator moves to
// {0, 1, ..., k}, the first subset of size k + 1.
// NOTE(review): there is no upper limit here; incrementing past the range's
// end sentinel produces indices that may exceed features_->size().
StateFeatureIterator& StateFeatureIterator::operator++()
{
    // j is the distance of position i from the rightmost position. Position
    // i can hold at most features_->size() - j - 1, so reaching
    // features_->size() - j means it overflowed and the carry moves left.
    size_t j = 0;
    int i = indices_.size() - 1;
    for (; i >= 0 && ++indices_[i] == features_->size() - j; --i, ++j) {
    }
    if (i < 0) {
        // Every position overflowed: all subsets of the current size are
        // exhausted. Grow by one and restart at {0, 1, ..., size}.
        indices_.push_back(0);
        for (j = 0; j < indices_.size(); ++j) {
            indices_[j] = j;
        }
    } else {
        // Position i advanced without overflowing. Reset every position to
        // its right to the smallest strictly increasing continuation.
        for (; i + 1 < static_cast<int>(indices_.size()); ++i) {
            indices_[i + 1] = indices_[i] + 1;
        }
    }
    return *this;
}

// Post-increment: advances this iterator and returns a copy of its previous
// position.
StateFeatureIterator StateFeatureIterator::operator++(int)
{
    StateFeatureIterator before = *this;
    operator++();
    return before;
}

// Two iterators are equal when they select the same index tuple. The feature
// vector they point into is deliberately not compared.
bool StateFeatureIterator::operator==(const StateFeatureIterator& other) const
{
    const auto& lhs = indices_;
    const auto& rhs = other.indices_;
    return lhs == rhs;
}

// Materializes the current subset: the features at each selected index, in
// index order. Lookups use at(), so an out-of-range index throws
// std::out_of_range.
StateFeatureIterator::value_type StateFeatureIterator::operator*() const
{
    value_type subset;
    subset.reserve(indices_.size());
    for (const auto index : indices_) {
        subset.push_back(features_->at(index));
    }
    return subset;
}

// Builds the range of all non-empty subsets of `features` whose size is at
// most `max_width` (capped at features.size()).
//
// The end sentinel is the position that operator++ reaches after the last
// subset of the largest enumerated width k: the index tuple {0, 1, ..., k}.
// When k == 0 (no features, or max_width == 0), the sentinel is {0}. That
// equals begin(), so the range is empty.
StateFeatureRange::StateFeatureRange(
    vector<search::feature_t> features,
    size_t max_width)
    : features_(std::move(features))
    , end_(&features_)
{
    const size_t width =
        features_.size() < max_width ? features_.size() : max_width;
    // width + 1 entries. The previous loop started at index size() (one past
    // the end) and produced the last *valid* subset instead of the sentinel.
    end_.indices_.resize(width + 1);
    for (size_t i = 0; i <= width; ++i) {
        end_.indices_[i] = i;
    }
}

// Returns an iterator at the first subset, which refers to this range's own
// feature vector.
StateFeatureIterator StateFeatureRange::begin() const
{
    StateFeatureIterator first(&features_);
    return first;
}

// Returns the past-the-end sentinel built in the constructor.
const StateFeatureIterator& StateFeatureRange::end() const
{
    return this->end_;
}

namespace {
// Returns the smallest value a variable of the given type can take. This is
// the declared lower bound for bounded integers and 0 for every other type,
// including bools.
int_t get_lower_bound(const VariableType& var_type)
{
    return std::visit(
        [](auto&& t) -> int_t {
            using T = std::decay_t<decltype(t)>;
            if constexpr (std::is_same_v<T, BoundedIntType>) {
                return t.lower_bound;
            } else {
                return 0;
            }
        },
        var_type);
}

// Returns how many consecutive feature ids a variable of the given type
// occupies: one per value in [lower_bound, upper_bound] for bounded integers,
// and 0 for other types. An inverted (empty) range also yields 0.
//
// The count must include both endpoints. Otherwise the top value of one
// variable gets the same feature id as the bottom value of the next one.
// NOTE(review): assumes upper_bound is inclusive (as in JANI bounded types).
// If it is exclusive, the +1 merely leaves one unused id and never causes a
// collision.
size_t get_interval_width(const VariableType& var_type)
{
    return std::visit(
        [](auto&& t) -> size_t {
            using T = std::decay_t<decltype(t)>;
            if constexpr (std::is_same_v<T, BoundedIntType>) {
                if (t.upper_bound < t.lower_bound) {
                    return 0;
                }
                return static_cast<size_t>(t.upper_bound - t.lower_bound) + 1;
            } else {
                return 0;
            }
        },
        var_type);
}
} // namespace

// Records, for every bool and bounded-int variable of `vspace`, its index
// and a base offset such that feature = value + base. Variables are laid out
// consecutively:
// - A bool reserves 2 ids.
// - A bounded int reserves get_interval_width() ids, and its base is shifted
//   down by the lower bound so the lowest value maps to the running offset.
// Variables of any other type are skipped and produce no features.
// Asserts that at least one variable was recorded.
StateNovelty::StateNovelty(const VariableSpace& vspace, size_t max_width)
    : max_width_(max_width)
{
    size_t next_offset = 0;
    size_t var_idx = 0;
    for (auto var = vspace.begin(); var != vspace.end(); ++var) {
        if (var->type.is_bool()) {
            vars_.emplace_back(var_idx, next_offset);
            next_offset += 2;
        } else if (var->type.is_bounded_int()) {
            vars_.emplace_back(
                var_idx,
                next_offset - get_lower_bound(var->type));
            next_offset += get_interval_width(var->type);
        }
        ++var_idx;
    }
    assert(!vars_.empty());
}

// Converts `state` into one feature per tracked variable
// (its value cast to int, plus the variable's base offset). Returns the range
// of all feature subsets of size at most max_width_.
StateFeatureRange StateNovelty::operator()(const CompressedState& state) const
{
    vector<search::feature_t> state_features;
    state_features.reserve(vars_.size());
    for (const auto& [var_index, base] : vars_) {
        const int value = static_cast<int>(state[var_index]);
        state_features.push_back(value + base);
    }
    return StateFeatureRange(std::move(state_features), max_width_);
}

} // namespace police::search
