/**
File:		MachineLearning/Util/FgKMeans.cpp

Author:		
Email:		
Site:       

Copyright (c) 2017 . All rights reserved.
*/

#include <NeMachineLearningPCH.h>
#include <MachineLearning/FgKMeans.h>

#include <ctime>
#include <iostream>
#include <stdexcept>

namespace NeuralEngine::MachineLearning
{
	/// @brief Runs k-means clustering on the rows of `in`.
	///
	/// Each column of `in` is first rescaled to [0, 1] with its own min/max (constant columns become 0).
	/// The initial means are `k` rows of the rescaled data drawn at random (with replacement). Lloyd
	/// iterations then run until fewer than `n / 1000 + 1` assignments change, or until `iter` iterations
	/// have been performed.
	///
	/// @param[out] means    k x d array of cluster centres, mapped back to the original feature range.
	/// @param[out] clusters Cluster index of every row of `in`, computed against the returned means.
	/// @param      in       Input data: one point per row, one feature per column.
	/// @param      k        Number of clusters (must be > 0).
	/// @param      iter     Maximum number of Lloyd iterations.
	/// @throws std::invalid_argument if `in` has no rows or `k <= 0`.
	template<typename Scalar>
	void KMeans<Scalar>::Compute(af::array& means, af::array & clusters, const af::array& in, int k, int iter)
	{
		std::cout << "KMeans running on " << in.dims(0) << " points with " << k << " cluster." << std::endl;

		const af::dtype dType = CommonUtil<Scalar>::CheckDType();

		const unsigned n = static_cast<unsigned>(in.dims(0)); // number of points
		const unsigned d = static_cast<unsigned>(in.dims(1)); // feature length

		// Without a point to sample, `n - 1` below underflows; without a cluster nothing can be computed.
		if (n == 0 || k <= 0)
			throw std::invalid_argument("KMeans::Compute requires at least one input row and k > 0");

		// Per-column extrema (1 x d), reused to rescale the data now and to map the means back at the end.
		const af::array colMin = af::min(in, 0);
		const af::array colRange = af::max(in, 0) - colMin;

		// Re-center and scale every column to [0, 1]. A constant column yields 0/0 = NaN, which is reset to 0.
		af::array data = (in - af::tile(colMin, n)) / af::tile(colRange, n);
		data(af::isNaN(data)) = 0.0;

		// Initial guess of means: k rows of the scaled data chosen at random (repeats are possible; an empty
		// cluster gets re-seeded by NewMeans). Stored as 1 x d x k, the layout Distance/NewMeans work with.
		std::cout << "\t--> Initial guess..." << std::endl;
		af::setSeed(time(nullptr));
		const af::array seedRows = af::round(af::randu(1, k, dType) * (n - 1));
		means = af::moddims(data(seedRows, af::span).T(), 1, d, k);

		// -1 is never a valid cluster id, so the first iteration always sees every point as changed.
		af::array currClusters = af::constant(-1, n, dType);
		bool clustersMatchMeans = false;

		// Stop updating after specified number of iterations
		std::cout << "\t--> Running iterations..." << std::endl;
		for (int i = 0; i < iter; i++)
		{
			std::cout << "\t\t--> iter " << i << std::endl;
			const af::array prevClusters = currClusters;
			currClusters = Clusterize(data, means);

			// Converged: (almost) no point changed cluster, so `means` and `currClusters` belong together.
			const unsigned numChanged = af::count<unsigned>(prevClusters != currClusters);
			if (numChanged < (n / 1000) + 1)
			{
				clustersMatchMeans = true;
				break;
			}

			// Update current means for new clusters
			means = NewMeans(data, currClusters, k);
		}

		// Hand back the assignment that corresponds to the returned means. If the loop ran out of iterations
		// (or never ran), the last assignment predates the final NewMeans update, so recompute it.
		clusters = clustersMatchMeans ? currClusters : Clusterize(data, means);

		// means is 1 x d x k: reorder to k x d (row c = centre of cluster c) and undo the [0, 1] scaling.
		const unsigned kRows = static_cast<unsigned>(k);
		const af::array unitMeans = af::reorder(means, 2, 1, 0);
		means = (af::tile(colRange, kRows) * unitMeans + af::tile(colMin, kRows)).as(dType);

		std::cout << "Done.\n" << std::endl;
	}

	/// @brief Assigns every point to its nearest mean.
	/// @param data  Scaled data, one point per row (n rows).
	/// @param means Cluster centres laid out as 1 x d x k.
	/// @return n x 1 array holding, for every row of `data`, the index along dim 2 of the mean with the
	///         smallest value returned by Distance.
	template<typename Scalar>
	af::array KMeans<Scalar>::Clusterize(const af::array& data, const af::array& means)
	{
		// n x 1 x k distances of every point to every mean.
		const af::array dists = Distance(data, means);

		// Arg-min along the cluster dimension. The outputs are default-constructed and filled by af::min.
		af::array minDist;
		af::array idx;
		af::min(minDist, idx, dists, 2);

		// Return cluster IDs
		return idx;
	}

	/// @brief Recomputes cluster centres from an assignment.
	/// @param data     Scaled data, one point per row.
	/// @param clusters Cluster index of every row of `data` (as produced by Clusterize).
	/// @param k        Number of clusters.
	/// @return 1 x d x k array. Slice c is the mean of the rows assigned to cluster c, or a randomly
	///         chosen row of `data` when no row is assigned to c.
	template<typename Scalar>
	af::array KMeans<Scalar>::NewMeans(af::array data, af::array clusters, int k)
	{
		const af::dtype dType = CommonUtil<Scalar>::CheckDType();

		const int d = static_cast<int>(data.dims(1));
		const int n = static_cast<int>(data.dims(0));
		af::array means = af::constant(0, 1, d, k, dType);

		for (int ii = 0; ii < k; ii++)
		{
			const af::array members = (clusters == ii);

			// An explicit member count detects empty clusters reliably and avoids dividing by zero.
			const unsigned memberCount = af::count<unsigned>(members);
			if (memberCount == 0)
			{
				// Re-seed an empty cluster with a random data row so it can attract points next iteration.
				means(af::span, af::span, ii) = data(af::round(af::randu(1, 1) * (n - 1)), af::span);
			}
			else
			{
				means(af::span, af::span, ii) =
					af::sum(data(af::where(members), af::span), 0) / static_cast<Scalar>(memberCount);
			}
		}
		return means;
	}

	/// @brief Distance of every point to every mean.
	/// @param data  Scaled data, one point per row (n rows, d columns).
	/// @param means Cluster centres laid out as 1 x d x k.
	/// @return n x 1 x k array whose slice ii holds CommonUtil<Scalar>::Euclidean applied to the data and
	///         mean ii (presumably a per-row Euclidean distance — NOTE(review): confirm in CommonUtil).
	template<typename Scalar>
	af::array KMeans<Scalar>::Distance(const af::array& data, const af::array& means)
	{
		const af::dtype dType = CommonUtil<Scalar>::CheckDType();

		const int n = static_cast<int>(data.dims(0));  // number of points
		const int k = static_cast<int>(means.dims(2)); // number of means

		// Tile both operands to n x d x k so that slice ii pairs every point with mean ii.
		af::array data2 = af::tile(data, 1, 1, k);
		af::array means2 = af::tile(means, n, 1, 1);
		af::array dist = af::constant(0, n, 1, k, dType);

		// Currently using euclidean distance
		// Can be replaced with other distance measures
		gfor(af::seq ii, k) {
			dist(af::span, af::span, ii) = CommonUtil<Scalar>::Euclidean(data2(af::span, af::span, ii), means2(af::span, af::span, ii));
		}
		return dist;
	}

	// Explicit instantiations come after every member definition: an explicit instantiation definition
	// only instantiates the members that are already defined at that point.
	template class KMeans<float>;
	template class KMeans<double>;
}