/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include <faiss/IndexLattice.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/distances.h>
#include <faiss/utils/hamming.h> // for the bitstring routines

namespace faiss {

IndexLattice::IndexLattice(idx_t d, int nsq, int scale_nbit, int r2)
        : IndexFlatCodes(0, d, METRIC_L2),
          nsq(nsq),
          dsq(d / nsq),
          zn_sphere_codec(dsq, r2),
          scale_nbit(scale_nbit) {
    FAISS_THROW_IF_NOT(d % nsq == 0);

    lattice_nbit = 0;
    while (!(((uint64_t)1 << lattice_nbit) >= zn_sphere_codec.nv)) {
        lattice_nbit++;
    }

    int total_nbit = (lattice_nbit + scale_nbit) * nsq;

    code_size = (total_nbit + 7) / 8;

    is_trained = false;
}

void IndexLattice::train(idx_t n, const float* x) {
    // compute ranges per sub-block
    trained.resize(nsq * 2);
    float* mins = trained.data();
    float* maxs = trained.data() + nsq;
    for (int sq = 0; sq < nsq; sq++) {
        mins[sq] = HUGE_VAL;
        maxs[sq] = -1;
    }

    for (idx_t i = 0; i < n; i++) {
        for (int sq = 0; sq < nsq; sq++) {
            float norm2 = fvec_norm_L2sqr(x + i * d + sq * dsq, dsq);
            if (norm2 > maxs[sq])
                maxs[sq] = norm2;
            if (norm2 < mins[sq])
                mins[sq] = norm2;
        }
    }

    for (int sq = 0; sq < nsq; sq++) {
        mins[sq] = sqrtf(mins[sq]);
        maxs[sq] = sqrtf(maxs[sq]);
    }

    is_trained = true;
}

/* The standalone codec interface */
size_t IndexLattice::sa_code_size() const {
    return code_size;
}

void IndexLattice::sa_encode(idx_t n, const float* x, uint8_t* codes) const {
    const float* mins = trained.data();
    const float* maxs = mins + nsq;
    int64_t sc = int64_t(1) << scale_nbit;

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringWriter wr(codes + i * code_size, code_size);
        const float* xi = x + i * d;
        for (int j = 0; j < nsq; j++) {
            float nj = (sqrtf(fvec_norm_L2sqr(xi, dsq)) - mins[j]) * sc /
                    (maxs[j] - mins[j]);
            if (nj < 0)
                nj = 0;
            if (nj >= sc)
                nj = sc - 1;
            wr.write((int64_t)nj, scale_nbit);
            wr.write(zn_sphere_codec.encode(xi), lattice_nbit);
            xi += dsq;
        }
    }
}

void IndexLattice::sa_decode(idx_t n, const uint8_t* codes, float* x) const {
    const float* mins = trained.data();
    const float* maxs = mins + nsq;
    float sc = int64_t(1) << scale_nbit;
    float r = sqrtf(zn_sphere_codec.r2);

#pragma omp parallel for
    for (idx_t i = 0; i < n; i++) {
        BitstringReader rd(codes + i * code_size, code_size);
        float* xi = x + i * d;
        for (int j = 0; j < nsq; j++) {
            float norm =
                    (rd.read(scale_nbit) + 0.5) * (maxs[j] - mins[j]) / sc +
                    mins[j];
            norm /= r;
            zn_sphere_codec.decode(rd.read(lattice_nbit), xi);
            for (int l = 0; l < dsq; l++) {
                xi[l] *= norm;
            }
            xi += dsq;
        }
    }
}

} // namespace faiss
