# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import numpy as np
import faiss
import argparse
import datasets
from datasets import sanitize

######################################################
# Command-line parsing
######################################################

parser = argparse.ArgumentParser()


def aa(*args, **kwargs):
    group.add_argument(*args, **kwargs)


group = parser.add_argument_group('dataset options')

aa('--db', default='deep1M', help='dataset')
aa('--nt', default=65536, type=int)
aa('--nb', default=100000, type=int)
aa('--nt_sample', default=0, type=int)

group = parser.add_argument_group('kmeans options')
aa('--k', default=256, type=int)
aa('--seed', default=12345, type=int)
aa('--pcadim', default=-1, type=int, help='PCA to this dimension')
aa('--niter', default=25, type=int)
aa('--eval_freq', default=100, type=int)


args = parser.parse_args()

print("args:", args)

os.system('echo -n "nb processors "; '
          'cat /proc/cpuinfo | grep ^processor | wc -l; '
          'cat /proc/cpuinfo | grep ^"model name" | tail -1')

ngpu = faiss.get_num_gpus()
print("nb GPUs:", ngpu)

######################################################
# Load dataset
######################################################

xt, xb, xq, gt = datasets.load_data(dataset=args.db)


if args.nt_sample == 0:
    xt_pca = xt[args.nt:args.nt + 10000]
    xt = xt[:args.nt]
else:
    xt_pca = xt[args.nt_sample:args.nt_sample + 10000]
    rs = np.random.RandomState(args.seed)
    idx = rs.choice(args.nt_sample, size=args.nt, replace=False)
    xt = xt[idx]

xb = xb[:args.nb]

d = xb.shape[1]

if args.pcadim != -1:
    print("training PCA: %d -> %d" % (d, args.pcadim))
    pca = faiss.PCAMatrix(d, args.pcadim)
    pca.train(sanitize(xt_pca))
    xt = pca.apply_py(sanitize(xt))
    xb = pca.apply_py(sanitize(xb))
    d = xb.shape[1]


######################################################
# Run clustering
######################################################


index = faiss.IndexFlatL2(d)

if ngpu > 0:
    print("moving index to GPU")
    index = faiss.index_cpu_to_all_gpus(index)


clustering = faiss.Clustering(d, args.k)

clustering.verbose = True
clustering.seed = args.seed
clustering.max_points_per_centroid = 10**6
clustering.min_points_per_centroid = 1

centroids = None

for iter0 in range(0, args.niter, args.eval_freq):
    iter1 = min(args.niter, iter0 + args.eval_freq)
    clustering.niter = iter1 - iter0

    if iter0 > 0:
        faiss.copy_array_to_vector(centroids.ravel(), clustering.centroids)

    clustering.train(sanitize(xt), index)
    index.reset()
    centroids = faiss.vector_to_array(clustering.centroids).reshape(args.k, d)
    index.add(centroids)

    _, I = index.search(sanitize(xb), 1)

    error = ((xb - centroids[I.ravel()]) ** 2).sum()

    print("iter1=%d quantization error on test: %.4f" % (iter1, error))
