/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.cluster;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.cluster.ClusterUtils;
import edu.cmu.tetrad.cluster.ClusteringAlgorithm;
import edu.cmu.tetrad.cluster.metrics.Dissimilarity;
import edu.cmu.tetrad.cluster.metrics.SquaredErrorLoss;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;

/**
 * K-means clustering of the rows of a data matrix.
 *
 * <p>Each row of the matrix passed to {@link #cluster(DoubleMatrix2D)} is treated as one point.
 * The algorithm alternates between two steps:
 * <ol>
 *   <li>assign every point to the center with the smallest dissimilarity, as measured by
 *       {@link #getMetric()} ({@link SquaredErrorLoss});</li>
 *   <li>move every non-empty cluster's center to the mean of its points.</li>
 * </ol>
 * It stops when no assignment changes or the iteration limit is reached.
 *
 * <p>Instances are created through static factories that choose the initialization:
 * <ul>
 *   <li>{@link #randomPoints(int)}: distinct, randomly chosen data rows become the initial
 *       centers;</li>
 *   <li>{@link #randomClusters(int)}: each point is put in a random cluster and the initial
 *       centers are the resulting cluster means;</li>
 *   <li>{@link #explicitPoints(DoubleMatrix2D)}: the caller supplies the initial centers.</li>
 * </ul>
 */
public class KMeans implements ClusteringAlgorithm {
    /** Initialization: use randomly chosen distinct data rows as the initial centers. */
    public static final int RANDOM_POINTS = 0;
    /** Initialization: assign points to random clusters and start from the cluster means. */
    public static final int RANDOM_CLUSTERS = 1;
    /** Initialization: start from centers supplied by the caller. */
    public static final int EXPLICIT_POINTS = 2;

    private static final String NOT_CLUSTERED = "cluster(data) must be called first.";

    /** The data most recently clustered; one point per row. */
    private DoubleMatrix2D data;
    /** Current cluster centers; one center per row. */
    private DoubleMatrix2D centers;
    /** Private copy of caller-supplied starting centers (EXPLICIT_POINTS only); never mutated. */
    private DoubleMatrix2D initialCenters;
    /** Maximum number of assign/update iterations; -1 means no limit. */
    private int maxIterations = 50;
    /** Cluster index of each data row, or -1 if the row is not assigned to any cluster. */
    private List<Integer> clusters;
    /** Number of iterations performed by the most recent call to cluster(). */
    private int iterations;
    private Dissimilarity metric = new SquaredErrorLoss();
    private int numCenters;
    private int initializationType = RANDOM_POINTS;
    private boolean verbose = true;

    private KMeans() {
    }

    /**
     * Creates an instance whose initial centers are {@code numCenters} distinct data rows chosen
     * at random.
     *
     * @param numCenters the number of clusters; at least 1, and no more than the number of rows of
     *                   the data later passed to {@link #cluster(DoubleMatrix2D)}.
     * @return a new, not yet run, k-means instance.
     * @throws IllegalArgumentException if {@code numCenters < 1}.
     */
    public static KMeans randomPoints(int numCenters) {
        return withRandomInitialization(numCenters, RANDOM_POINTS);
    }

    /**
     * Creates an instance that starts by assigning every point to one of {@code numCenters}
     * clusters at random and using the cluster means as initial centers. A cluster that receives
     * no points keeps a zero center until points are assigned to it.
     *
     * @param numCenters the number of clusters; at least 1.
     * @return a new, not yet run, k-means instance.
     * @throws IllegalArgumentException if {@code numCenters < 1}.
     */
    public static KMeans randomClusters(int numCenters) {
        return withRandomInitialization(numCenters, RANDOM_CLUSTERS);
    }

    /**
     * Creates an instance that starts from the given centers, one per row. The matrix is copied,
     * so later clustering never modifies it. Each call to {@link #cluster(DoubleMatrix2D)} restarts
     * from these centers.
     *
     * @param centers the initial centers; must have at least one row, and as many columns as the
     *                data that will be clustered.
     * @return a new, not yet run, k-means instance.
     * @throws NullPointerException     if {@code centers} is null.
     * @throws IllegalArgumentException if {@code centers} has no rows.
     */
    public static KMeans explicitPoints(DoubleMatrix2D centers) {
        if (centers == null) {
            throw new NullPointerException("centers");
        }
        if (centers.rows() < 1) {
            throw new IllegalArgumentException("At least one explicit center is required.");
        }
        KMeans algorithm = new KMeans();
        algorithm.initializationType = EXPLICIT_POINTS;
        algorithm.numCenters = centers.rows();
        algorithm.initialCenters = centers.copy();
        algorithm.centers = centers.copy();
        return algorithm;
    }

    /** Shared factory body for the two random initialization strategies. */
    private static KMeans withRandomInitialization(int numCenters, int initializationType) {
        if (numCenters < 1) {
            throw new IllegalArgumentException("numCenters must be at least 1: " + numCenters);
        }
        KMeans algorithm = new KMeans();
        algorithm.numCenters = numCenters;
        algorithm.initializationType = initializationType;
        return algorithm;
    }

    /**
     * Clusters the rows of {@code data}: initializes the centers, then alternates point
     * reassignment and center update until no assignment changes or {@link #getMaxIterations()}
     * iterations have run (-1 means no limit).
     *
     * @param data the points to cluster, one per row.
     * @throws NullPointerException     if {@code data} is null.
     * @throws IllegalArgumentException if (RANDOM_POINTS) there are fewer rows than centers, or
     *                                  (EXPLICIT_POINTS) the column counts of the centers and the
     *                                  data differ.
     */
    @Override
    public void cluster(DoubleMatrix2D data) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        this.data = data;

        switch (this.initializationType) {
            case RANDOM_POINTS:
                this.centers = this.pickCenters(this.numCenters, data);
                this.clusters = unassigned(data.rows());
                break;
            case RANDOM_CLUSTERS:
                this.centers = new DenseDoubleMatrix2D(this.numCenters, data.columns());
                this.clusters = new ArrayList<Integer>(data.rows());
                for (int i = 0; i < data.rows(); ++i) {
                    this.clusters.add(RandomUtil.getInstance().nextInt(this.numCenters));
                }
                this.moveCentersToMeans();
                break;
            case EXPLICIT_POINTS:
                if (this.initialCenters.columns() != data.columns()) {
                    throw new IllegalArgumentException("Centers have " + this.initialCenters.columns()
                            + " column(s) but data has " + data.columns() + ".");
                }
                // Restart from a fresh copy so repeated runs begin at the supplied centers.
                this.centers = this.initialCenters.copy();
                this.clusters = unassigned(data.rows());
                break;
            default:
                throw new IllegalStateException("Unknown initialization type: " + this.initializationType);
        }

        this.iterations = 0;
        boolean changed = true;
        while (changed && (this.maxIterations == -1 || this.iterations < this.maxIterations)) {
            ++this.iterations;
            changed = this.reassignPoints() > 0;
            this.moveCentersToMeans();
        }
    }

    /**
     * Returns the clusters as lists of row indices, as built by
     * {@link ClusterUtils#convertClusterIndicesToLists(List)}.
     *
     * @throws IllegalStateException if {@link #cluster(DoubleMatrix2D)} has not been called.
     */
    @Override
    public List<List<Integer>> getClusters() {
        this.checkClustered();
        return ClusterUtils.convertClusterIndicesToLists(this.clusters);
    }

    /**
     * Returns a copy of the current centers, one per row.
     *
     * @throws IllegalStateException if no centers exist yet (a random initialization that has not
     *                               been run).
     */
    @Override
    public DoubleMatrix2D getPrototypes() {
        if (this.centers == null) {
            throw new IllegalStateException(NOT_CLUSTERED);
        }
        return this.centers.copy();
    }

    /** Returns the iteration limit; -1 means no limit. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** Sets the iteration limit; pass -1 for no limit. */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** Returns the number of clusters (also available before {@link #cluster(DoubleMatrix2D)}). */
    public int getNumClusters() {
        return this.centers == null ? this.numCenters : this.centers.rows();
    }

    /**
     * Returns the indices of the data rows assigned to cluster {@code k}.
     *
     * @throws IllegalStateException if {@link #cluster(DoubleMatrix2D)} has not been called.
     */
    public List<Integer> getCluster(int k) {
        this.checkClustered();
        List<Integer> members = new ArrayList<Integer>();
        for (int i = 0; i < this.clusters.size(); ++i) {
            if (this.clusters.get(i) == k) {
                members.add(i);
            }
        }
        return members;
    }

    /** Returns the dissimilarity used to compare points with centers. */
    public Dissimilarity getMetric() {
        return this.metric;
    }

    /** Returns the number of iterations performed by the last call to cluster(). */
    public int iterations() {
        return this.iterations;
    }

    /**
     * Returns the sum, over the points in cluster {@code k}, of the metric's dissimilarity between
     * each point and the cluster's center.
     *
     * @throws IllegalStateException if {@link #cluster(DoubleMatrix2D)} has not been called.
     */
    public double squaredError(int k) {
        this.checkClustered();
        double squaredError = 0.0;
        DoubleMatrix1D center = null;
        for (int i = 0; i < this.data.rows(); ++i) {
            if (this.clusters.get(i) != k) {
                continue;
            }
            if (center == null) {
                // Created lazily so an empty cluster never touches row k of the centers.
                center = this.centers.viewRow(k);
            }
            squaredError += this.metric.dissimilarity(this.data.viewRow(i), center);
        }
        return squaredError;
    }

    /**
     * Returns the sum of {@link #squaredError(int)} over all clusters.
     *
     * @throws IllegalStateException if {@link #cluster(DoubleMatrix2D)} has not been called.
     */
    public double totalSquaredError() {
        this.checkClustered();
        double totalSquaredError = 0.0;
        for (int k = 0; k < this.centers.rows(); ++k) {
            totalSquaredError += this.squaredError(k);
        }
        return totalSquaredError;
    }

    /** Summarizes the result: sizes and squared errors per cluster, plus the total. */
    @Override
    public String toString() {
        if (this.clusters == null || this.data == null) {
            return "Cluster Result (not yet computed)";
        }
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        DoubleMatrix1D counts = this.countClusterSizes();
        StringBuilder buf = new StringBuilder();
        buf.append("Cluster Result (").append(this.clusters.size()).append(" cases, ")
                .append(this.centers.columns()).append(" feature(s), ")
                .append(this.centers.rows()).append(" clusters)");
        double totalSquaredError = 0.0;
        for (int k = 0; k < this.centers.rows(); ++k) {
            // Compute each cluster's error once and reuse it for the total.
            double error = this.squaredError(k);
            totalSquaredError += error;
            buf.append("\n\tCluster #").append(k + 1).append(": n = ").append((int) counts.get(k));
            buf.append(" Squared Error = ").append(nf.format(error));
        }
        buf.append("\n\tTotal Squared Error = ").append(nf.format(totalSquaredError));
        return buf.toString();
    }

    /**
     * Assigns every point to the nearest center. Ties keep the lowest center index. A point whose
     * dissimilarity to every center is NaN gets -1.
     *
     * @return the number of points whose assignment changed.
     */
    private int reassignPoints() {
        int numChanged = 0;
        for (int i = 0; i < this.data.rows(); ++i) {
            DoubleMatrix1D datum = this.data.viewRow(i);
            double best = Double.POSITIVE_INFINITY;
            int nearest = -1;
            for (int k = 0; k < this.centers.rows(); ++k) {
                double dissimilarity = this.getMetric().dissimilarity(datum, this.centers.viewRow(k));
                if (dissimilarity < best) {
                    best = dissimilarity;
                    nearest = k;
                }
            }
            if (nearest != this.clusters.get(i)) {
                this.clusters.set(i, nearest);
                ++numChanged;
            }
        }
        return numChanged;
    }

    /**
     * Moves each non-empty cluster's center to the mean of its points, in a single pass over the
     * data. Centers of empty clusters are left unchanged.
     */
    private void moveCentersToMeans() {
        int numClusters = this.centers.rows();
        int dims = this.centers.columns();
        double[][] sums = new double[numClusters][dims];
        int[] counts = new int[numClusters];
        for (int i = 0; i < this.data.rows(); ++i) {
            int c = this.clusters.get(i);
            if (c < 0 || c >= numClusters) {
                continue; // unassigned point
            }
            for (int j = 0; j < dims; ++j) {
                sums[c][j] += this.data.get(i, j);
            }
            ++counts[c];
        }
        for (int k = 0; k < numClusters; ++k) {
            if (counts[k] == 0) {
                continue;
            }
            for (int j = 0; j < dims; ++j) {
                this.centers.set(k, j, sums[k][j] / (double) counts[k]);
            }
        }
    }

    /**
     * Returns a copy of {@code numCenters} distinct, randomly chosen rows of {@code data}, in
     * increasing row order.
     *
     * @throws IllegalArgumentException if {@code numCenters} exceeds the number of rows (the
     *                                  search for distinct rows could never finish).
     */
    private DoubleMatrix2D pickCenters(int numCenters, DoubleMatrix2D data) {
        if (numCenters > data.rows()) {
            throw new IllegalArgumentException("Cannot pick " + numCenters
                    + " distinct centers from " + data.rows() + " data row(s).");
        }
        TreeSet<Integer> indexSet = new TreeSet<Integer>();
        while (indexSet.size() < numCenters) {
            indexSet.add(RandomUtil.getInstance().nextInt(data.rows()));
        }
        int[] rows = new int[numCenters];
        int r = 0;
        for (int row : indexSet) {
            rows[r++] = row;
        }
        int[] cols = new int[data.columns()];
        for (int j = 0; j < cols.length; ++j) {
            cols[j] = j;
        }
        return data.viewSelection(rows, cols).copy();
    }

    /** Counts the assigned points in each cluster; unassigned (-1) points are skipped. */
    private DoubleMatrix1D countClusterSizes() {
        DenseDoubleMatrix1D counts = new DenseDoubleMatrix1D(this.centers.rows());
        for (int cluster : this.clusters) {
            if (cluster == -1) {
                continue;
            }
            counts.set(cluster, counts.get(cluster) + 1.0);
        }
        return counts;
    }

    /** Returns a list of {@code n} entries, all -1 ("not assigned"). */
    private static List<Integer> unassigned(int n) {
        List<Integer> result = new ArrayList<Integer>(n);
        for (int i = 0; i < n; ++i) {
            result.add(-1);
        }
        return result;
    }

    /** Throws IllegalStateException unless cluster() has run. */
    private void checkClustered() {
        if (this.clusters == null || this.data == null) {
            throw new IllegalStateException(NOT_CLUSTERED);
        }
    }

    /** Returns the verbose flag (stored only; not consulted elsewhere in this class). */
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

