/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

// -*- c++ -*-

#include "Clustering_c.h"
#include <faiss/Clustering.h>
#include <faiss/Index.h>
#include <vector>
#include "macros_impl.h"

extern "C" {

using faiss::Clustering;
using faiss::ClusteringIterationStats;
using faiss::ClusteringParameters;
using faiss::Index;

DEFINE_GETTER(Clustering, int, niter)
DEFINE_GETTER(Clustering, int, nredo)
DEFINE_GETTER(Clustering, int, verbose)
DEFINE_GETTER(Clustering, int, spherical)
DEFINE_GETTER(Clustering, int, int_centroids)
DEFINE_GETTER(Clustering, int, update_index)
DEFINE_GETTER(Clustering, int, frozen_centroids)

DEFINE_GETTER(Clustering, int, min_points_per_centroid)
DEFINE_GETTER(Clustering, int, max_points_per_centroid)

DEFINE_GETTER(Clustering, int, seed)
DEFINE_GETTER(Clustering, size_t, decode_block_size)

/// getter for d
DEFINE_GETTER(Clustering, size_t, d)

/// getter for k
DEFINE_GETTER(Clustering, size_t, k)

DEFINE_GETTER(ClusteringIterationStats, float, obj)
DEFINE_GETTER(ClusteringIterationStats, double, time)
DEFINE_GETTER(ClusteringIterationStats, double, time_search)
DEFINE_GETTER(ClusteringIterationStats, double, imbalance_factor)
DEFINE_GETTER(ClusteringIterationStats, int, nsplit)

void faiss_ClusteringParameters_init(FaissClusteringParameters* params) {
    ClusteringParameters d;
    params->frozen_centroids = d.frozen_centroids;
    params->max_points_per_centroid = d.max_points_per_centroid;
    params->min_points_per_centroid = d.min_points_per_centroid;
    params->niter = d.niter;
    params->nredo = d.nredo;
    params->seed = d.seed;
    params->spherical = d.spherical;
    params->int_centroids = d.int_centroids;
    params->update_index = d.update_index;
    params->verbose = d.verbose;
    params->decode_block_size = d.decode_block_size;
}

// This conversion is required because the two types are not memory-compatible
inline ClusteringParameters from_faiss_c(
        const FaissClusteringParameters* params) {
    ClusteringParameters o;
    o.frozen_centroids = params->frozen_centroids;
    o.max_points_per_centroid = params->max_points_per_centroid;
    o.min_points_per_centroid = params->min_points_per_centroid;
    o.niter = params->niter;
    o.nredo = params->nredo;
    o.seed = params->seed;
    o.spherical = params->spherical;
    o.update_index = params->update_index;
    o.int_centroids = params->int_centroids;
    o.verbose = params->verbose;
    o.decode_block_size = params->decode_block_size;
    return o;
}

/// getter for centroids (size = k * d)
void faiss_Clustering_centroids(
        FaissClustering* clustering,
        float** centroids,
        size_t* size) {
    std::vector<float>& v =
            reinterpret_cast<Clustering*>(clustering)->centroids;
    if (centroids) {
        *centroids = v.data();
    }
    if (size) {
        *size = v.size();
    }
}

/// getter for iteration stats
void faiss_Clustering_iteration_stats(
        FaissClustering* clustering,
        FaissClusteringIterationStats** iteration_stats,
        size_t* size) {
    std::vector<ClusteringIterationStats>& v =
            reinterpret_cast<Clustering*>(clustering)->iteration_stats;
    if (iteration_stats) {
        *iteration_stats =
                reinterpret_cast<FaissClusteringIterationStats*>(v.data());
    }
    if (size) {
        *size = v.size();
    }
}

/// the only mandatory parameters are k and d
int faiss_Clustering_new(FaissClustering** p_clustering, int d, int k) {
    try {
        Clustering* c = new Clustering(d, k);
        *p_clustering = reinterpret_cast<FaissClustering*>(c);
        return 0;
    }
    CATCH_AND_HANDLE
}

int faiss_Clustering_new_with_params(
        FaissClustering** p_clustering,
        int d,
        int k,
        const FaissClusteringParameters* cp) {
    try {
        Clustering* c = new Clustering(d, k, from_faiss_c(cp));
        *p_clustering = reinterpret_cast<FaissClustering*>(c);
        return 0;
    }
    CATCH_AND_HANDLE
}

/// Index is used during the assignment stage
int faiss_Clustering_train(
        FaissClustering* clustering,
        idx_t n,
        const float* x,
        FaissIndex* index) {
    try {
        reinterpret_cast<Clustering*>(clustering)
                ->train(n, x, *reinterpret_cast<Index*>(index));
        return 0;
    }
    CATCH_AND_HANDLE
}

void faiss_Clustering_free(FaissClustering* clustering) {
    delete reinterpret_cast<Clustering*>(clustering);
}

int faiss_kmeans_clustering(
        size_t d,
        size_t n,
        size_t k,
        const float* x,
        float* centroids,
        float* q_error) {
    try {
        float out = faiss::kmeans_clustering(d, n, k, x, centroids);
        if (q_error) {
            *q_error = out;
        }
        return 0;
    }
    CATCH_AND_HANDLE
}
}
