/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.cluster;

import edu.cmu.tetrad.cluster.ClusteringAlgorithm;
import edu.cmu.tetrad.cluster.metrics.Dissimilarity;
import edu.cmu.tetrad.cluster.metrics.SquaredErrorLoss;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.Vector;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.TreeSet;

/**
 * K-means clustering over the rows of a {@link Matrix}.
 *
 * <p>Each row of the data matrix is one case and each column is one feature. The algorithm
 * alternates two steps until no case changes cluster or the iteration limit is reached:
 * <ol>
 *   <li>assign every row to the center with the smallest dissimilarity (as measured by
 *       {@link SquaredErrorLoss});</li>
 *   <li>move every non-empty cluster's center to the mean of its assigned rows.</li>
 * </ol>
 *
 * <p>Instances are obtained from the static factories {@link #randomPoints(int)} and
 * {@link #randomClusters(int)}, which choose how the initial centers are produced.
 */
public class KMeans
implements ClusteringAlgorithm {
    /** Initialization: use randomly chosen, distinct data rows as the initial centers. */
    private static final int RANDOM_POINTS = 0;
    /** Initialization: assign every row to a random cluster, then use the cluster means. */
    private static final int RANDOM_CLUSTERS = 1;
    /**
     * Initialization: use centers that are already stored in {@link #centers}. No public factory
     * in this class selects this mode.
     */
    private static final int EXPLICIT_POINTS = 2;

    /** Dissimilarity used both for assignment and for the reported squared errors. */
    private final Dissimilarity metric = new SquaredErrorLoss();
    /** The data most recently passed to {@link #cluster(Matrix)}; rows are cases. */
    private Matrix data;
    /** One row per cluster center, with the same number of columns as {@link #data}. */
    private Matrix centers;
    /** Upper bound on assign/update passes; {@code -1} means "no limit". */
    private int maxIterations = 50;
    /** Cluster index for each data row, or {@code -1} if the row is not assigned. */
    private List<Integer> clusters;
    /** Number of passes performed by the most recent call to {@link #cluster(Matrix)}. */
    private int iterations;
    /** Number of clusters requested through the factory method. */
    private int numCenters;
    /** One of {@link #RANDOM_POINTS}, {@link #RANDOM_CLUSTERS}, {@link #EXPLICIT_POINTS}. */
    private int initializationType = RANDOM_POINTS;
    /** Stored and exposed through the accessors; not read elsewhere in this class. */
    private boolean verbose;

    private KMeans() {
    }

    /**
     * Creates a k-means instance whose initial centers are {@code numCenters} distinct data rows
     * chosen at random.
     *
     * @param numCenters the number of clusters; must not exceed the number of rows later passed
     *                   to {@link #cluster(Matrix)}.
     * @return a new, unclustered instance.
     */
    public static KMeans randomPoints(int numCenters) {
        KMeans algorithm = new KMeans();
        algorithm.numCenters = numCenters;
        algorithm.initializationType = RANDOM_POINTS;
        return algorithm;
    }

    /**
     * Creates a k-means instance that starts by assigning every row to a uniformly random cluster
     * and using the resulting cluster means as the initial centers.
     *
     * @param numCenters the number of clusters.
     * @return a new, unclustered instance.
     */
    public static KMeans randomClusters(int numCenters) {
        KMeans algorithm = new KMeans();
        algorithm.numCenters = numCenters;
        algorithm.initializationType = RANDOM_CLUSTERS;
        return algorithm;
    }

    /**
     * Groups row indices by cluster.
     *
     * <p>The result has one list per index from 0 up to the largest cluster index present, so
     * trailing empty clusters are not represented. Rows whose index is {@code -1} are omitted.
     *
     * @param clusterIndices the cluster index of every row.
     * @return list {@code k} contains, in ascending order, the rows assigned to cluster {@code k}.
     */
    private static List<List<Integer>> convertClusterIndicesToLists(List<Integer> clusterIndices) {
        int max = 0;
        for (int index : clusterIndices) {
            max = Math.max(max, index);
        }

        List<List<Integer>> clusters = new ArrayList<>(max + 1);
        for (int k = 0; k <= max; ++k) {
            clusters.add(new LinkedList<Integer>());
        }

        for (int i = 0; i < clusterIndices.size(); ++i) {
            int index = clusterIndices.get(i);
            if (index != -1) {
                clusters.get(index).add(i);
            }
        }
        return clusters;
    }

    /** Returns a mutable list of {@code size} entries, all {@code -1} ("unassigned"). */
    private static List<Integer> unassigned(int size) {
        List<Integer> result = new ArrayList<>(size);
        for (int i = 0; i < size; ++i) {
            result.add(-1);
        }
        return result;
    }

    /**
     * Runs k-means on the rows of {@code data}, replacing any previous result.
     *
     * @param data the cases to cluster; rows are cases, columns are features.
     * @throws IllegalArgumentException if random-point initialization is requested with more
     *                                  centers than there are rows (previously an infinite loop).
     * @throws IllegalStateException    if explicit-point initialization is selected but no centers
     *                                  are set, or the initialization type is unknown.
     */
    @Override
    public void cluster(Matrix data) {
        this.data = data;
        int numRows = data.getNumRows();

        switch (this.initializationType) {
            case RANDOM_POINTS:
                this.centers = this.pickCenters(this.numCenters, data);
                this.clusters = unassigned(numRows);
                break;
            case RANDOM_CLUSTERS:
                // Presumably zero-initialized; centers of clusters that receive no rows keep
                // whatever values this constructor gives them.
                this.centers = new Matrix(this.numCenters, data.getNumColumns());
                this.clusters = new ArrayList<>(numRows);
                for (int i = 0; i < numRows; ++i) {
                    this.clusters.add(RandomUtil.getInstance().nextInt(this.centers.getNumRows()));
                }
                this.moveCentersToMeans();
                break;
            case EXPLICIT_POINTS:
                if (this.centers == null) {
                    throw new IllegalStateException(
                            "Explicit-point initialization requires centers to be set before clustering.");
                }
                this.clusters = unassigned(numRows);
                break;
            default:
                throw new IllegalStateException("Unknown initialization type: " + this.initializationType);
        }

        // Lloyd iterations: stop once a pass reassigns nothing, or at the limit (-1 = unbounded).
        boolean changed = true;
        this.iterations = 0;
        while (changed && (this.maxIterations == -1 || this.iterations < this.maxIterations)) {
            ++this.iterations;
            changed = this.reassignPoints() > 0;
            this.moveCentersToMeans();
        }
    }

    /**
     * Returns the clustering as lists of row indices. Valid only after {@link #cluster(Matrix)}.
     *
     * @return list {@code k} holds the rows in cluster {@code k}; trailing empty clusters are
     *         omitted.
     */
    @Override
    public List<List<Integer>> getClusters() {
        return KMeans.convertClusterIndicesToLists(this.clusters);
    }

    /** @return the maximum number of passes; {@code -1} means unlimited. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** @param maxIterations the maximum number of passes; {@code -1} means unlimited. */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** @return the number of centers (valid after {@link #cluster(Matrix)}). */
    public int getNumClusters() {
        return this.centers.getNumRows();
    }

    /**
     * @param k a cluster index.
     * @return the indices, in ascending order, of the rows currently assigned to cluster {@code k}.
     */
    public List<Integer> getCluster(int k) {
        List<Integer> cluster = new ArrayList<>();
        for (int i = 0; i < this.clusters.size(); ++i) {
            if (this.clusters.get(i) == k) {
                cluster.add(i);
            }
        }
        return cluster;
    }

    private Dissimilarity getMetric() {
        return this.metric;
    }

    /** @return the number of passes performed by the most recent {@link #cluster(Matrix)} call. */
    public int iterations() {
        return this.iterations;
    }

    /** Sums the dissimilarity between cluster {@code k}'s center and each row assigned to it. */
    private double squaredError(int k) {
        Vector center = this.centers.getRow(k);
        double squaredError = 0.0;
        for (int i = 0; i < this.data.getNumRows(); ++i) {
            if (this.clusters.get(i) == k) {
                squaredError += this.metric.dissimilarity(this.data.getRow(i), center);
            }
        }
        return squaredError;
    }

    /** Sums {@link #squaredError(int)} over all clusters. */
    private double totalSquaredError() {
        double totalSquaredError = 0.0;
        for (int k = 0; k < this.centers.getNumRows(); ++k) {
            totalSquaredError += this.squaredError(k);
        }
        return totalSquaredError;
    }

    /** Summarizes the result: per-cluster sizes and squared errors, plus the total. */
    public String toString() {
        NumberFormat n1 = NumberFormatUtil.getInstance().getNumberFormat();
        Vector counts = this.countClusterSizes();
        double totalSquaredError = this.totalSquaredError();

        StringBuilder buf = new StringBuilder();
        buf.append("Cluster Result (").append(this.clusters.size()).append(" cases, ")
                .append(this.centers.getNumColumns()).append(" feature(s), ")
                .append(this.centers.getNumRows()).append(" clusters)");
        for (int k = 0; k < this.centers.getNumRows(); ++k) {
            buf.append("\n\tCluster #").append(k + 1).append(": n = ").append(counts.get(k));
            buf.append(" Squared Error = ").append(n1.format(this.squaredError(k)));
        }
        buf.append("\n\tTotal Squared Error = ").append(n1.format(totalSquaredError));
        return buf.toString();
    }

    /**
     * Assigns every row to its nearest center.
     *
     * <p>Ties go to the lowest cluster index because only a strictly smaller dissimilarity
     * replaces the current best. A row whose dissimilarities are all NaN gets {@code -1}.
     *
     * @return the number of rows whose cluster index changed.
     */
    private int reassignPoints() {
        // Centers do not move during this pass, so fetch each center row once.
        int numClusters = this.centers.getNumRows();
        Vector[] centerRows = new Vector[numClusters];
        for (int k = 0; k < numClusters; ++k) {
            centerRows[k] = this.centers.getRow(k);
        }

        int numChanged = 0;
        for (int i = 0; i < this.data.getNumRows(); ++i) {
            Vector datum = this.data.getRow(i);
            double minDissimilarity = Double.POSITIVE_INFINITY;
            int cluster = -1;
            for (int k = 0; k < numClusters; ++k) {
                double dissimilarity = this.getMetric().dissimilarity(datum, centerRows[k]);
                if (dissimilarity < minDissimilarity) {
                    minDissimilarity = dissimilarity;
                    cluster = k;
                }
            }
            if (cluster != this.clusters.get(i)) {
                this.clusters.set(i, cluster);
                ++numChanged;
            }
        }
        return numChanged;
    }

    /**
     * Moves each center to the column-wise mean of its assigned rows. A center with no assigned
     * rows is left where it is.
     */
    private void moveCentersToMeans() {
        int numColumns = this.centers.getNumColumns();
        for (int k = 0; k < this.centers.getNumRows(); ++k) {
            double[] sums = new double[numColumns];
            int count = 0;
            for (int i = 0; i < this.data.getNumRows(); ++i) {
                if (this.clusters.get(i) != k) {
                    continue;
                }
                for (int j = 0; j < this.data.getNumColumns(); ++j) {
                    sums[j] += this.data.get(i, j);
                }
                ++count;
            }
            if (count == 0) {
                continue;
            }
            for (int j = 0; j < numColumns; ++j) {
                this.centers.set(k, j, sums[j] / (double) count);
            }
        }
    }

    /**
     * Copies {@code numCenters} distinct, randomly chosen rows of {@code data}, in ascending row
     * order, into a new matrix.
     *
     * @throws IllegalArgumentException if {@code numCenters} is negative or exceeds the number of
     *                                  rows; otherwise the rejection sampling below would never
     *                                  terminate.
     */
    private Matrix pickCenters(int numCenters, Matrix data) {
        int numRows = data.getNumRows();
        if (numCenters < 0 || numCenters > numRows) {
            throw new IllegalArgumentException(
                    "Cannot pick " + numCenters + " distinct centers from " + numRows + " rows.");
        }

        // Draw until enough distinct indices are collected; the TreeSet also sorts them.
        TreeSet<Integer> indexSet = new TreeSet<>();
        while (indexSet.size() < numCenters) {
            indexSet.add(RandomUtil.getInstance().nextInt(numRows));
        }

        int[] rows = new int[numCenters];
        int next = 0;
        for (int row : indexSet) {
            rows[next++] = row;
        }

        int[] cols = new int[data.getNumColumns()];
        for (int j = 0; j < cols.length; ++j) {
            cols[j] = j;
        }
        return data.getSelection(rows, cols).copy();
    }

    /** Counts rows per cluster, ignoring unassigned ({@code -1}) rows. */
    private Vector countClusterSizes() {
        Vector counts = new Vector(this.centers.getNumRows());
        for (int cluster : this.clusters) {
            if (cluster != -1) {
                counts.set(cluster, counts.get(cluster) + 1.0);
            }
        }
        return counts;
    }

    /** @return the stored verbose flag. */
    public boolean isVerbose() {
        return this.verbose;
    }

    /** @param verbose the flag to store; this class does not currently read it. */
    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

