#! /usr/bin/env python2

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import print_function

import sys
import time

import faiss
import numpy as np
import argparse
from sklearn.preprocessing import normalize


def read_feat(fname):
    """Load feature matrices from .npy file(s) and normalize them.

    Args:
        fname: path to a ``.npy`` file, or several paths separated by commas.
            With several paths, the loaded arrays are stacked along axis 0,
            so they must agree on all other dimensions.

    Returns:
        The result of ``sklearn.preprocessing.normalize`` on the loaded
        (and, for several files, concatenated) array. With sklearn's default
        arguments this scales each row to unit L2 norm.
    """
    paths = fname.split(",")
    if len(paths) == 1:
        # Single file: skip np.concatenate, which would copy the whole array.
        return normalize(np.load(fname))
    arrays = [np.load(path) for path in paths]
    return normalize(np.concatenate(arrays, axis=0))


def train_kmeans(x, k, ngpu, niter=20, verbose=True):
    """Run k-means on one or several GPUs and return the centroids.

    Args:
        x: training matrix of shape (n, d); ``d`` is read from ``x.shape[1]``.
            Presumably float32 and C-contiguous, as faiss expects (the
            script below converts it that way before calling).
        k: number of clusters.
        ngpu: number of GPUs to use, devices ``0 .. ngpu-1``; must be >= 1.
        niter: number of k-means iterations.
        verbose: forwarded to ``faiss.Clustering.verbose``.

    Returns:
        Array of shape (k, d) holding the trained centroids.

    Raises:
        ValueError: if ``ngpu`` is smaller than 1.
    """
    if ngpu < 1:
        raise ValueError("train_kmeans needs at least one GPU, got ngpu=%d" % ngpu)

    d = x.shape[1]
    clus = faiss.Clustering(d, k)
    clus.verbose = verbose
    clus.niter = niter

    # otherwise the kmeans implementation sub-samples the training set
    clus.max_points_per_centroid = 10000000

    # One GPU resource object and one flat-index config per device.
    # NOTE(review): keep `res` (and `indexes` below) referenced until training
    # finishes; the wrapper index is not assumed to own these objects.
    res = [faiss.StandardGpuResources() for _ in range(ngpu)]
    flat_config = []
    for device in range(ngpu):
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False  # keep float32 storage for the centroids index
        cfg.device = device
        flat_config.append(cfg)

    if ngpu == 1:
        index = faiss.GpuIndexFlatL2(res[0], d, flat_config[0])
    else:
        # Replicate the assignment index on every GPU.
        indexes = [faiss.GpuIndexFlatL2(r, d, cfg)
                   for r, cfg in zip(res, flat_config)]
        index = faiss.IndexReplicas()
        for sub_index in indexes:
            index.addIndex(sub_index)

    # perform the training
    clus.train(x, index)
    centroids = faiss.vector_float_to_array(clus.centroids)
    return centroids.reshape(k, d)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("input", help="Numpy, Input Features")
    parser.add_argument("k", type=int, help="Int, Number of classes")
    parser.add_argument("output", help="class center")
    parser.add_argument("--ngpu", type=int, default=None,
                        help="Int, number of GPUs to use "
                             "(default: all GPUs visible to faiss)")
    parser.add_argument("--niter", type=int, default=20,
                        help="Int, number of k-means iterations")
    args = parser.parse_args()

    # Previously hard-coded to 8, which crashed on smaller machines.
    ngpu = args.ngpu if args.ngpu is not None else faiss.get_num_gpus()
    if ngpu < 1:
        parser.error("no GPU available for faiss (ngpu=%d)" % ngpu)

    x = read_feat(args.input)
    # Flatten everything after the first axis and convert to float32 for faiss.
    x = x.reshape(x.shape[0], -1).astype('float32')

    print("run")
    t0 = time.time()
    centroids = train_kmeans(x, args.k, ngpu, niter=args.niter)
    t1 = time.time()

    print("total runtime: %.3f s" % (t1 - t0))

    np.save(args.output, centroids)
