/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include <faiss/IndexBinaryFromFloat.h>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/utils.h>
#include <algorithm>
#include <cmath>
#include <memory>

namespace faiss {

// Default constructor, compiler-generated. It does not attach a float index;
// member initial values come from the class declaration (not visible in this
// file), so the wrapped `index` must be set before add/search/train/reset are
// called, since those dereference it unconditionally.
IndexBinaryFromFloat::IndexBinaryFromFloat() = default;

// Wraps an existing float index so it can be used through the binary-index
// interface.
//
// @param index  float index to wrap; it is dereferenced immediately, so it
//               must be non-null. The wrapper's dimension is taken from
//               index->d. own_fields is set to false, so the destructor will
//               not delete the wrapped index.
IndexBinaryFromFloat::IndexBinaryFromFloat(Index* index)
        : IndexBinary(index->d), index(index), own_fields(false) {
    // Start out mirroring the state of the wrapped index.
    const Index& wrapped = *index;
    ntotal = wrapped.ntotal;
    is_trained = wrapped.is_trained;
}

// Releases the wrapped float index, but only when this object owns it
// (own_fields == true). A borrowed index is left untouched.
IndexBinaryFromFloat::~IndexBinaryFromFloat() {
    if (!own_fields) {
        return;
    }
    delete index;
}

// Adds n binary vectors to the wrapped float index.
//
// The input is processed in batches of at most `bs` vectors: each batch is
// expanded to floats with binary_to_real() and forwarded to index->add().
//
// @param n  number of vectors to add
// @param x  packed binary vectors; consecutive vectors are code_size bytes
//           apart
void IndexBinaryFromFloat::add(idx_t n, const uint8_t* x) {
    constexpr idx_t bs = 32768;

    // Size the scratch buffer for the largest batch that will actually be
    // processed rather than always for a full batch of `bs` vectors; small
    // adds on high-dimensional indexes would otherwise allocate far more
    // memory than they use. Clamp at 0 so that n <= 0 stays a no-op.
    const idx_t max_batch = std::max<idx_t>(0, std::min(bs, n));
    std::unique_ptr<float[]> xf(new float[max_batch * d]);

    for (idx_t b = 0; b < n; b += bs) {
        const idx_t bn = std::min(bs, n - b);
        binary_to_real(bn * d, x + b * code_size, xf.get());

        index->add(bn, xf.get());
    }

    // Keep the wrapper's vector count in sync with the wrapped index.
    ntotal = index->ntotal;
}

void IndexBinaryFromFloat::reset() {
    index->reset();
    ntotal = index->ntotal;
}

// Searches the wrapped float index with binary query vectors.
//
// Queries are processed in batches of at most `bs` vectors: each batch is
// expanded to floats with binary_to_real(), searched with index->search(),
// and the float distances are converted to int32 output distances.
//
// @param n          number of query vectors
// @param x          packed binary queries; consecutive vectors are code_size
//                   bytes apart
// @param k          number of results per query; must be > 0
// @param distances  output, n * k entries, row-major per query
// @param labels     output, n * k entries, written directly by the wrapped
//                   index
// @param params     must be null; search parameters are not supported
// @throws via FAISS_THROW_IF_NOT* when params is non-null or k <= 0
void IndexBinaryFromFloat::search(
        idx_t n,
        const uint8_t* x,
        idx_t k,
        int32_t* distances,
        idx_t* labels,
        const SearchParameters* params) const {
    FAISS_THROW_IF_NOT_MSG(
            !params, "search params not supported for this index");
    FAISS_THROW_IF_NOT(k > 0);

    constexpr idx_t bs = 32768;

    // Size the scratch buffers for the largest batch actually processed
    // instead of always a full batch; clamp at 0 so n <= 0 stays a no-op.
    const idx_t max_batch = std::max<idx_t>(0, std::min(bs, n));
    std::unique_ptr<float[]> xf(new float[max_batch * d]);
    std::unique_ptr<float[]> df(new float[max_batch * k]);

    for (idx_t b = 0; b < n; b += bs) {
        const idx_t bn = std::min(bs, n - b);
        binary_to_real(bn * d, x + b * code_size, xf.get());

        index->search(bn, xf.get(), k, df.get(), labels + b * k);

        // Use idx_t for the counter: bn * k can exceed INT_MAX (a full batch
        // with k >= 65536), which would overflow an int index.
        int32_t* const batch_distances = distances + b * k;
        const idx_t n_results = bn * k;
        for (idx_t i = 0; i < n_results; ++i) {
            // NOTE(review): the division by 4 presumably maps the float
            // index's squared-L2 distance between +/-1 vectors back to a
            // Hamming distance -- confirm against binary_to_real().
            batch_distances[i] = int32_t(std::round(df[i] / 4.0));
        }
    }
}

// Trains the wrapped float index on n binary vectors.
//
// Unlike add() and search(), the whole training set is expanded to floats in
// one buffer and passed to index->train() in a single call.
//
// @param n  number of training vectors
// @param x  packed binary training vectors
void IndexBinaryFromFloat::train(idx_t n, const uint8_t* x) {
    const idx_t n_components = n * d;
    std::unique_ptr<float[]> as_float(new float[n_components]);
    binary_to_real(n_components, x, as_float.get());

    index->train(n, as_float.get());

    // Publish the post-training state on the wrapper.
    ntotal = index->ntotal;
    is_trained = true;
}

} // namespace faiss
