#! /usr/bin/env python2
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import print_function
import numpy as np
import time
import faiss
import sys


# Get command-line arguments: number of centroids, then number of GPUs.
# Fail with a usage message instead of a bare IndexError traceback when
# arguments are missing; extra arguments are still ignored.
if len(sys.argv) < 3:
    sys.exit("usage: %s k ngpu" % sys.argv[0])

k = int(sys.argv[1])      # number of k-means centroids
ngpu = int(sys.argv[2])   # number of GPUs to run on

# Load Leon's file format

def load_mnist(fname):
    """Load a stack of images from an IDX-style binary file.

    The file starts with a 16-byte header made of four big-endian 32-bit
    integers. The first one is skipped; the last three are read as the
    number of images and the two image dimensions. The pixel data follows
    as unsigned bytes.

    Args:
        fname: path of the file to read.

    Returns:
        A uint8 array of shape (nim, xd, yd).

    Raises:
        ValueError: if the file is too short for the header or holds fewer
            pixel bytes than the header announces (raised by the reshapes).
    """
    print("load", fname)

    # Binary mode, and a context manager so the handle is closed even if
    # parsing fails.
    with open(fname, 'rb') as f:
        # Decode the header with an explicit big-endian dtype so the result
        # does not depend on the host byte order. It is kept as a (4, 1)
        # native int32 column so the printed header looks the same as before.
        header = np.fromfile(f, dtype='>i4', count=4)
        header = header.astype('int32').reshape(4, 1)
        print(header)
        # Index the column so int() receives scalars, not 1-element arrays.
        nim, xd, yd = [int(v) for v in header[1:, 0]]

        data = np.fromfile(f, count=nim * xd * yd, dtype='uint8')

    print(data.shape, nim, xd, yd)
    data = data.reshape(nim, xd, yd)
    return data

# Root directory of the MNIST data; the mnist8m file is looked up in a
# "mnist8m" sub-directory below it.
basedir = "/path/to/mnist/data"

# The explicit "/" keeps the path valid whether or not basedir ends with a
# separator (plain concatenation produced ".../datamnist8m/...").
x = load_mnist(basedir + '/mnist8m/mnist8m-patterns-idx3-ubyte')

print("reshape")

# Flatten every image into a single row and convert the pixels to float32.
x = x.reshape(x.shape[0], -1).astype('float32')


def train_kmeans(x, k, ngpu):
    """Run k-means on one or several GPUs.

    Args:
        x: 2-D training matrix with one vector per row; it is handed to
            faiss.Clustering.train as-is (the caller in this file passes
            float32 data).
        k: number of centroids.
        ngpu: number of GPUs to use; devices 0 .. ngpu - 1 are selected.

    Returns:
        The centroids as an array of shape (k, d), where d = x.shape[1].

    Raises:
        ValueError: if ngpu is smaller than 1.
    """
    # Reject this early: with no GPU the replica index below would be empty
    # and the failure would only surface inside the training call.
    if ngpu < 1:
        raise ValueError("ngpu must be at least 1, got %r" % (ngpu,))

    d = x.shape[1]
    clus = faiss.Clustering(d, k)
    clus.verbose = True
    clus.niter = 20

    # otherwise the kmeans implementation sub-samples the training set
    clus.max_points_per_centroid = 10000000

    # One resources object and one flat-index configuration per device.
    # Both lists stay referenced by locals until the function returns.
    res = [faiss.StandardGpuResources() for _ in range(ngpu)]

    flat_config = []
    for device in range(ngpu):
        cfg = faiss.GpuIndexFlatConfig()
        cfg.useFloat16 = False
        cfg.device = device
        flat_config.append(cfg)

    indexes = [faiss.GpuIndexFlatL2(gpu_res, d, cfg)
               for gpu_res, cfg in zip(res, flat_config)]

    if ngpu == 1:
        index = indexes[0]
    else:
        # Several GPUs: wrap the per-device indexes in a replica index.
        index = faiss.IndexReplicas()
        for sub_index in indexes:
            index.addIndex(sub_index)

    # perform the training
    clus.train(x, index)
    centroids = faiss.vector_float_to_array(clus.centroids)

    # Report the last entry of the objective values recorded by the
    # clustering object.
    obj = faiss.vector_float_to_array(clus.obj)
    print("final objective: %.4g" % obj[-1])

    return centroids.reshape(k, d)

print("run")

# Time the complete clustering call with the wall clock; the centroids it
# returns are not used here.
t0 = time.time()
train_kmeans(x, k, ngpu)
t1 = time.time()
runtime = t1 - t0

print("total runtime: %.3f s" % runtime)
