/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.cluster;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.cluster.ClusterUtils;
import edu.cmu.tetrad.cluster.ClusteringAlgorithm;
import edu.cmu.tetrad.cluster.metrics.AbsoluteErrorLoss;
import edu.cmu.tetrad.cluster.metrics.Dissimilarity;
import edu.cmu.tetrad.cluster.metrics.SquaredErrorLoss;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.TreeSet;

/**
 * Clusters the rows of a data matrix using a set of prototype "units" trained by rank-based
 * competitive learning (a neural-gas style update).
 *
 * <p>{@link #cluster} proceeds as follows:
 * <ol>
 *   <li>{@code numUnits} units are initialized as copies of distinct, randomly chosen data
 *       rows.</li>
 *   <li>At each step {@code t} a random data row (the signal) is drawn. The units are sorted by
 *       their dissimilarity to it under {@link #getMetric()}. The unit at 0-based position
 *       {@code i} is then moved toward the signal by
 *       {@code epsilon(t) * exp(-(i + 1) / lambda(t))} times the difference.</li>
 *   <li>{@code epsilon(t)} and {@code lambda(t)} interpolate geometrically from their initial to
 *       their final values over {@code tMax} steps, and keep following the same curve after
 *       that.</li>
 *   <li>Training stops once the Euclidean length of the summed unit displacements of one step,
 *       multiplied by the number of columns, is at most {@code 1e-8}.</li>
 * </ol>
 * Finally, every data row is assigned to its nearest unit under the same metric.
 *
 * <p>Termination relies on the step sizes shrinking. If {@code epsilonF >= epsilonI} and
 * {@code lambdaF >= lambdaI}, they do not shrink, and training may not end.
 *
 * <p>Instances hold mutable training state and are not synchronized.
 */
public class Ng implements ClusteringAlgorithm {

    /** Training stops once the scaled length of one step's summed displacement is at most this. */
    private static final double CONVERGENCE_THRESHOLD = 1.0E-8;

    /** When verbose, a progress line is printed every this many training steps. */
    private static final int PROGRESS_INTERVAL = 1000;

    /** Number of units to train; also the number of clusters. */
    private int numUnits;
    /** Data passed to the most recent {@link #cluster} call; signals are drawn from its rows. */
    private DoubleMatrix2D data;
    /** Trained units, or {@code null} before {@link #cluster} has been called. */
    private List<Unit> units;
    /** Step-size factor at step 0. */
    private double epsilonI = 0.05;
    /** Step-size factor reached at step {@code tMax}. */
    private double epsilonF = 6.0E-4;
    /** Rank decay scale at step 0 (larger values move lower-ranked units more). */
    private double lambdaI = 0.5;
    /** Rank decay scale reached at step {@code tMax}. */
    private double lambdaF = 0.1;
    /** Number of steps over which epsilon and lambda move from initial to final values. */
    private int tMax = 10000;
    /** Dissimilarity used both to rank units during training and to assign clusters. */
    private Dissimilarity metric = new SquaredErrorLoss();
    /** Whether to print progress to standard output. */
    private boolean verbose = true;
    /** For each data row, the index of its nearest unit (from the last {@link #cluster} call). */
    private List<Integer> clusters;

    private Ng() {
    }

    /**
     * Creates an untrained instance whose units will be initialized from randomly chosen rows
     * of the data passed to {@link #cluster}.
     *
     * @param numUnits the number of units, and hence clusters; must be positive
     * @return a new, untrained instance
     * @throws IllegalArgumentException if {@code numUnits < 1}
     */
    public static Ng randomUnitsFromData(int numUnits) {
        if (numUnits < 1) {
            throw new IllegalArgumentException("numUnits must be positive: " + numUnits);
        }
        Ng algorithm = new Ng();
        algorithm.numUnits = numUnits;
        return algorithm;
    }

    /**
     * Trains the units on {@code data} and assigns every row to its nearest unit.
     *
     * @param data the data to cluster, one point per row
     * @throws IllegalArgumentException if {@code data} has fewer rows than there are units
     *         (distinct initial rows could not be chosen)
     */
    @Override
    public void cluster(DoubleMatrix2D data) {
        if (data.rows() < this.numUnits) {
            throw new IllegalArgumentException("Need at least " + this.numUnits
                    + " data rows to initialize the units; got " + data.rows());
        }
        this.data = data;
        this.units = this.pickUnits(this.numUnits, data);
        this.train();
        this.clusters = this.assignClusters(data);
    }

    /** Runs the rank-based update loop on {@link #data} until the step displacement vanishes. */
    private void train() {
        int columns = this.data.columns();
        DoubleMatrix1D delta = new DenseDoubleMatrix1D(columns);

        // Seed the accumulator with a non-zero vector so that at least one step runs
        // whenever the data has columns.
        DenseDoubleMatrix1D ones = new DenseDoubleMatrix1D(columns);
        ones.assign(1.0);
        Unit stepTotal = new Unit(ones);

        int t = 0;
        while (this.length(stepTotal) * (double) columns > CONVERGENCE_THRESHOLD) {
            stepTotal = new Unit(new DenseDoubleMatrix1D(columns));
            if (this.isVerbose() && (t + 1) % PROGRESS_INTERVAL == 0) {
                System.out.println("Iteration " + (t + 1));
            }

            DoubleMatrix1D signal = this.randomPoint();
            for (Unit unit : this.units) {
                unit.calculateDistanceToSignal(signal);
            }
            // Rank units by distance to the signal, nearest first.
            Collections.sort(this.units);

            // Both schedules depend only on t, so compute them once per step.
            double epsilon = this.epsilon(t);
            double lambda = this.lambda(t);
            for (int rank = 0; rank < this.units.size(); ++rank) {
                Unit unit = this.units.get(rank);
                double step = epsilon * this.h(rank + 1, lambda);
                DoubleMatrix1D vector = unit.getVector();
                for (int j = 0; j < columns; ++j) {
                    delta.set(j, step * (signal.get(j) - vector.get(j)));
                }
                unit.moveVector(delta);
                stepTotal.moveVector(delta);
            }
            ++t;
        }
    }

    /**
     * Assigns each row of {@code data} to the index of its nearest unit under
     * {@link #getMetric()}, the same metric used to rank units during training.
     * A row whose distances are all NaN gets index {@code -1}.
     */
    private List<Integer> assignClusters(DoubleMatrix2D data) {
        Dissimilarity dissimilarity = this.getMetric();
        List<Integer> assignments = new ArrayList<Integer>(data.rows());
        for (int row = 0; row < data.rows(); ++row) {
            DoubleMatrix1D point = data.viewRow(row);
            double best = Double.POSITIVE_INFINITY;
            int nearest = -1;
            for (int k = 0; k < this.units.size(); ++k) {
                double d = dissimilarity.dissimilarity(this.units.get(k).getVector(), point);
                if (d < best) {
                    best = d;
                    nearest = k;
                }
            }
            assignments.add(nearest);
        }
        return assignments;
    }

    /** Returns the Euclidean length of {@code unit}'s vector. */
    private double length(Unit unit) {
        DoubleMatrix1D vector = unit.getVector();
        double sum = 0.0;
        for (int i = 0; i < vector.size(); ++i) {
            double v = vector.get(i);
            sum += v * v;
        }
        return Math.sqrt(sum);
    }

    /** Returns the clusters from the last {@link #cluster} call as lists of row indices. */
    @Override
    public List<List<Integer>> getClusters() {
        return ClusterUtils.convertClusterIndicesToLists(this.clusters);
    }

    /**
     * Returns the trained units, one per row, in the same order as the cluster indices, or
     * {@code null} if {@link #cluster} has not been called yet.
     */
    @Override
    public DoubleMatrix2D getPrototypes() {
        return this.units == null ? null : this.getUnitMatrix();
    }

    public double getEpsilonI() {
        return this.epsilonI;
    }

    public void setEpsilonI(double epsilonI) {
        this.epsilonI = epsilonI;
    }

    public double getEpsilonF() {
        return this.epsilonF;
    }

    public void setEpsilonF(double epsilonF) {
        this.epsilonF = epsilonF;
    }

    public double getLambdaI() {
        return this.lambdaI;
    }

    public void setLambdaI(double lambdaI) {
        this.lambdaI = lambdaI;
    }

    public double getLambdaF() {
        return this.lambdaF;
    }

    public void setLambdaF(double lambdaF) {
        this.lambdaF = lambdaF;
    }

    public int getTMax() {
        return this.tMax;
    }

    public void setTMax(int tMax) {
        this.tMax = tMax;
    }

    /**
     * Copies the unit vectors into a new matrix, one unit per row. Row {@code k} is the unit
     * that cluster index {@code k} refers to.
     *
     * @throws IllegalStateException if {@link #cluster} has not been called yet
     */
    public DoubleMatrix2D getUnitMatrix() {
        if (this.units == null) {
            throw new IllegalStateException("cluster(data) has not been called yet.");
        }
        int rows = this.units.size();
        int columns = this.units.get(0).getVector().size();
        DenseDoubleMatrix2D matrix = new DenseDoubleMatrix2D(rows, columns);
        for (int i = 0; i < rows; ++i) {
            DoubleMatrix1D vector = this.units.get(i).getVector();
            for (int j = 0; j < columns; ++j) {
                matrix.set(i, j, vector.get(j));
            }
        }
        return matrix;
    }

    /** Lists each unit on its own line; empty before {@link #cluster} has been called. */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        if (this.units == null) {
            return buf.toString();
        }
        for (int i = 0; i < this.units.size(); ++i) {
            buf.append("\nUnit " + i + ": " + this.units.get(i));
        }
        return buf.toString();
    }

    /** Rank weight {@code exp(-k / lambda)}; callers pass {@code k = rank + 1}. */
    private double h(int k, double lambda) {
        return Math.exp((double) (-k) / lambda);
    }

    /** Rank decay scale at step {@code t}: geometric from lambdaI (t=0) to lambdaF (t=tMax). */
    private double lambda(int t) {
        return this.lambdaI * Math.pow(this.lambdaF / this.lambdaI, (double) t / (double) this.tMax);
    }

    /** Step-size factor at step {@code t}: geometric from epsilonI (t=0) to epsilonF (t=tMax). */
    private double epsilon(int t) {
        return this.epsilonI * Math.pow(this.epsilonF / this.epsilonI, (double) t / (double) this.tMax);
    }

    /** Returns a view (not a copy) of a uniformly random row of {@link #data}. */
    private DoubleMatrix1D randomPoint() {
        int i = RandomUtil.getInstance().nextInt(this.data.rows());
        return this.data.viewRow(i);
    }

    /**
     * Creates {@code numUnits} units from copies of distinct, randomly chosen rows of
     * {@code data}, in ascending row order. The caller must ensure
     * {@code numUnits <= data.rows()}; otherwise this never terminates.
     */
    private List<Unit> pickUnits(int numUnits, DoubleMatrix2D data) {
        // TreeSet rejects repeated indices and iterates them in ascending order.
        TreeSet<Integer> indices = new TreeSet<Integer>();
        while (indices.size() < numUnits) {
            indices.add(RandomUtil.getInstance().nextInt(data.rows()));
        }
        List<Unit> picked = new ArrayList<Unit>(numUnits);
        for (int index : indices) {
            // Copy so that training never writes through to the caller's data.
            picked.add(new Unit(data.viewRow(index).copy()));
        }
        return picked;
    }

    public Dissimilarity getMetric() {
        return this.metric;
    }

    /** Sets the dissimilarity used both for ranking units in training and for cluster assignment. */
    public void setMetric(Dissimilarity metric) {
        this.metric = metric;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * A mutable vector together with its dissimilarity to the most recent signal. Units are
     * ordered by that dissimilarity, ascending. The class is non-static because it reads the
     * enclosing instance's metric.
     */
    private class Unit implements Comparable<Unit> {
        private final DoubleMatrix1D vector;
        private double distance = Double.NaN;

        public Unit(DoubleMatrix1D unitVector) {
            this.vector = unitVector;
        }

        /** Adds {@code delta} to this unit's vector, element-wise, in place. */
        public void moveVector(DoubleMatrix1D delta) {
            for (int i = 0; i < this.vector.size(); ++i) {
                this.vector.set(i, this.vector.get(i) + delta.get(i));
            }
        }

        /** Caches the enclosing instance's metric value between this unit and {@code signal}. */
        public void calculateDistanceToSignal(DoubleMatrix1D signal) {
            this.distance = Ng.this.getMetric().dissimilarity(this.vector, signal);
        }

        @Override
        public int compareTo(Unit other) {
            // Double.compare is a total order (NaN sorts last). A signum of the difference
            // would report NaN as equal to everything, which can break the sort's contract.
            return Double.compare(this.distance, other.distance);
        }

        public DoubleMatrix1D getVector() {
            return this.vector;
        }

        @Override
        public String toString() {
            return "" + this.vector;
        }
    }
}

