/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.cluster;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import edu.cmu.tetrad.cluster.ClusterUtils;
import edu.cmu.tetrad.cluster.ClusteringAlgorithm;
import edu.cmu.tetrad.cluster.metrics.Dissimilarity;
import edu.cmu.tetrad.cluster.metrics.SquaredErrorLoss;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.TreeSet;

/**
 * K-medoids clustering of the rows of a data matrix.
 *
 * <p>Initial centers are distinct data rows chosen at random. The algorithm then alternates
 * between two steps:
 * <ul>
 *   <li>assigning every row to its least-dissimilar center;</li>
 *   <li>moving each non-empty cluster's center to that cluster's medoid, i.e. the member whose
 *       summed dissimilarity to the other members is smallest.</li>
 * </ul>
 * Iteration stops when a pass changes no assignment, or when {@code maxIterations} passes have
 * run. A {@code maxIterations} of -1 means there is no cap. Dissimilarity is measured with
 * {@link SquaredErrorLoss}.
 */
public class KMedoids
implements ClusteringAlgorithm {
    /** The data being clustered, one case per row. */
    private DoubleMatrix2D data;
    /** Current centers, one per row; each is a copy of some data row. */
    private DoubleMatrix2D centers;
    /** Maximum number of passes in {@link #cluster}; -1 means unlimited. */
    private int maxIterations = 50;
    /** Cluster index of each data row; -1 until the first assignment pass. */
    private List<Integer> clusters;
    /** Number of passes performed by the most recent call to {@link #cluster}. */
    private int iterations;
    /** Number of clusters requested. */
    private int numCenters;
    /** Dissimilarity used for assignment, medoid selection and error reporting. */
    private Dissimilarity metric = new SquaredErrorLoss();
    /** Whether progress messages are printed to standard output while clustering. */
    private boolean verbose = true;

    private KMedoids() {
    }

    /**
     * Creates an instance that will partition data into {@code numCenters} clusters, starting
     * from randomly chosen rows.
     *
     * @param numCenters the number of clusters. When {@link #cluster} is called it must be
     *                   between 1 and the number of data rows.
     * @return a new instance; call {@link #cluster} to run it
     */
    public static KMedoids randomClusters(int numCenters) {
        KMedoids algorithm = new KMedoids();
        algorithm.numCenters = numCenters;
        return algorithm;
    }

    /**
     * Clusters the rows of {@code data}, replacing any previous result.
     *
     * @param data the data to cluster, one case per row
     * @throws IllegalArgumentException if the requested number of clusters is less than 1 or
     *                                  greater than the number of rows (the random
     *                                  initialization could never finish in that case)
     */
    @Override
    public void cluster(DoubleMatrix2D data) {
        if (this.numCenters < 1 || this.numCenters > data.rows()) {
            throw new IllegalArgumentException("Number of clusters must be between 1 and the number of rows ("
                    + data.rows() + "), but was " + this.numCenters);
        }
        this.data = data;
        this.centers = this.pickCenters(this.numCenters, data);
        this.clusters = new ArrayList<Integer>(data.rows());
        for (int i = 0; i < data.rows(); ++i) {
            this.clusters.add(-1);
        }
        boolean changed = true;
        this.iterations = 0;
        while (changed && (this.maxIterations == -1 || this.iterations < this.maxIterations)) {
            ++this.iterations;
            if (this.verbose) {
                System.out.println("Iteration = " + this.iterations);
            }
            int numChanged = this.reassignPoints();
            changed = numChanged > 0;
            this.moveCentersToMedoids();
            if (this.verbose) {
                System.out.println("Cluster counts: " + this.countClusterSizes());
            }
        }
    }

    /**
     * Returns the clustering as lists of row indices, one list per cluster.
     */
    @Override
    public List<List<Integer>> getClusters() {
        return ClusterUtils.convertClusterIndicesToLists(this.clusters);
    }

    /**
     * Returns a copy of the current centers, one per row.
     */
    @Override
    public DoubleMatrix2D getPrototypes() {
        return this.centers.copy();
    }

    /** Returns the pass limit; -1 means unlimited. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /** Sets the pass limit; -1 means unlimited. */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** Returns the number of centers produced by the last call to {@link #cluster}. */
    public int getNumClusters() {
        return this.centers.rows();
    }

    /**
     * Returns the indices of the data rows currently assigned to cluster {@code k}, in
     * ascending order.
     *
     * @param k the cluster index
     * @return a new mutable list, which is empty if no row is in cluster {@code k}
     */
    public List<Integer> getCluster(int k) {
        ArrayList<Integer> cluster = new ArrayList<Integer>();
        for (int i = 0; i < this.clusters.size(); ++i) {
            if (this.clusters.get(i) != k) continue;
            cluster.add(i);
        }
        return cluster;
    }

    /** Returns the number of passes performed by the last call to {@link #cluster}. */
    public int iterations() {
        return this.iterations;
    }

    /**
     * Returns the summed dissimilarity between each member of cluster {@code k} and that
     * cluster's center.
     *
     * @param k the cluster index
     * @return the total; 0.0 if the cluster has no members
     */
    public double squaredError(int k) {
        double squaredError = 0.0;
        for (int i = 0; i < this.data.rows(); ++i) {
            if (this.clusters.get(i) != k) continue;
            DoubleMatrix1D datum = this.data.viewRow(i);
            DoubleMatrix1D center = this.centers.viewRow(k);
            squaredError += this.dissimilarity(datum, center);
        }
        return squaredError;
    }

    /** Returns the sum of {@link #squaredError(int)} over all clusters. */
    public double totalSquaredError() {
        double totalSquaredError = 0.0;
        for (int k = 0; k < this.centers.rows(); ++k) {
            totalSquaredError += this.squaredError(k);
        }
        return totalSquaredError;
    }

    /** Summarizes the clustering: case and feature counts, then per-cluster size and error. */
    @Override
    public String toString() {
        NumberFormat n1 = NumberFormatUtil.getInstance().getNumberFormat();
        DoubleMatrix1D counts = this.countClusterSizes();
        double totalSquaredError = this.totalSquaredError();
        StringBuilder buf = new StringBuilder();
        buf.append("Cluster Result (").append(this.clusters.size()).append(" cases, ").append(this.centers.columns()).append(" feature(s), ").append(this.centers.rows()).append(" clusters)");
        for (int k = 0; k < this.centers.rows(); ++k) {
            buf.append("\n\tCluster #").append(k + 1).append(": n = ").append(counts.get(k));
            buf.append(" Squared Error = ").append(n1.format(this.squaredError(k)));
        }
        buf.append("\n\tTotal Squared Error = ").append(n1.format(totalSquaredError));
        return buf.toString();
    }

    /**
     * Assigns every data row to the center with the smallest dissimilarity. Ties go to the
     * lowest center index.
     *
     * @return the number of rows whose assignment changed
     */
    private int reassignPoints() {
        int numChanged = 0;
        for (int i = 0; i < this.data.rows(); ++i) {
            DoubleMatrix1D datum = this.data.viewRow(i);
            double minDissimilarity = Double.POSITIVE_INFINITY;
            int cluster = -1;
            for (int k = 0; k < this.centers.rows(); ++k) {
                double dissimilarity = this.dissimilarity(datum, this.centers.viewRow(k));
                if (dissimilarity < minDissimilarity) {
                    minDissimilarity = dissimilarity;
                    cluster = k;
                }
            }
            if (cluster != this.clusters.get(i)) {
                this.clusters.set(i, cluster);
                ++numChanged;
            }
        }
        if (this.verbose) {
            System.out.println("Moved " + numChanged + " points.");
        }
        return numChanged;
    }

    /**
     * Moves each non-empty cluster's center onto its medoid: the member with the smallest
     * summed dissimilarity to the other members. Ties go to the lowest row index. Empty
     * clusters keep their current center.
     */
    private void moveCentersToMedoids() {
        for (int k = 0; k < this.centers.rows(); ++k) {
            // getCluster returns an ArrayList; totalDistance indexes it repeatedly, so random
            // access must be O(1).
            List<Integer> cluster = this.getCluster(k);
            if (cluster.isEmpty()) continue;
            double min = Double.POSITIVE_INFINITY;
            // Default to the first member, so that NaN distances cannot leave the index at -1.
            int medoid = 0;
            for (int j = 0; j < cluster.size(); ++j) {
                double totalDistance = this.totalDistance(j, cluster);
                if (totalDistance < min) {
                    min = totalDistance;
                    medoid = j;
                }
            }
            this.centers.viewRow(k).assign(this.data.viewRow(cluster.get(medoid)));
        }
    }

    /**
     * Sums the dissimilarity from every other member of {@code cluster} to member {@code j}.
     *
     * @param j       position (within {@code cluster}) of the candidate medoid
     * @param cluster data row indices of the cluster's members
     */
    private double totalDistance(int j, List<Integer> cluster) {
        DoubleMatrix1D candidate = this.data.viewRow(cluster.get(j));
        double sum = 0.0;
        for (int i = 0; i < cluster.size(); ++i) {
            if (i == j) continue;
            sum += this.dissimilarity(this.data.viewRow(cluster.get(i)), candidate);
        }
        return sum;
    }

    /**
     * Chooses {@code numCenters} distinct rows of {@code data} at random and returns a copy of
     * them, ordered by ascending row index. The caller guarantees that
     * {@code 1 <= numCenters <= data.rows()}.
     */
    private DoubleMatrix2D pickCenters(int numCenters, DoubleMatrix2D data) {
        TreeSet<Integer> indexSet = new TreeSet<Integer>();
        while (indexSet.size() < numCenters) {
            // A TreeSet ignores duplicates, so draws repeat until there are enough distinct rows.
            indexSet.add(RandomUtil.getInstance().nextInt(data.rows()));
        }
        int[] rows = new int[numCenters];
        int next = 0;
        for (int row : indexSet) {
            rows[next++] = row;
        }
        int[] cols = new int[data.columns()];
        for (int j = 0; j < cols.length; ++j) {
            cols[j] = j;
        }
        return data.viewSelection(rows, cols).copy();
    }

    /** Delegates to the configured {@link Dissimilarity}. */
    private double dissimilarity(DoubleMatrix1D d1, DoubleMatrix1D d2) {
        return this.metric.dissimilarity(d1, d2);
    }

    /** Returns a vector holding the number of rows assigned to each cluster. */
    private DoubleMatrix1D countClusterSizes() {
        DenseDoubleMatrix1D counts = new DenseDoubleMatrix1D(this.centers.rows());
        for (int cluster : this.clusters) {
            counts.set(cluster, counts.get(cluster) + 1.0);
        }
        return counts;
    }

    /** Returns whether progress messages are printed while clustering. */
    public boolean isVerbose() {
        return this.verbose;
    }

    /** Enables or disables progress messages on standard output during {@link #cluster}. */
    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

