/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.cluster;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.cluster.ClusteringAlgorithm;
import edu.cmu.tetrad.cluster.metrics.Dissimilarity;
import edu.cmu.tetrad.cluster.metrics.SquaredErrorLoss;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;

/**
 * Growing Neural Gas (GNG) clustering, after B. Fritzke, "A Growing Neural Gas Network Learns
 * Topologies" (NIPS 1995).
 *
 * <p>Units (prototype vectors) are the nodes of an undirected graph. Training works as follows:
 * <ul>
 *   <li>Signals are drawn at random from the rows of the data matrix.</li>
 *   <li>The nearest unit and its graph neighbors are pulled toward each signal.</li>
 *   <li>Edges age and are pruned, and units left without edges are deleted.</li>
 *   <li>Every {@code lambda} signals, a new unit is inserted between the unit with the largest
 *       accumulated error and its worst neighbor.</li>
 * </ul>
 * Training stops once more than {@code maxUnits} units have been created. Each data row is then
 * assigned to the connected component of the graph that contains its nearest unit, and those
 * components are the clusters.
 *
 * <p>Instances hold per-run mutable state with no synchronization.
 */
public class Gng implements ClusteringAlgorithm {
    /** The data being clustered; each row is one point. */
    private DoubleMatrix2D data;
    /** Live units; re-sorted by distance to the current signal on every step. */
    private List<Unit> units;
    /** Topology over the live units. */
    private EdgeListGraph graph;
    private int ageMax = 80;
    private int lambda = 2000;
    private double alpha = 0.5;
    private double beta = 5.0E-4;
    private double epsilonB = 0.1;
    private double epsilonN = 0.01;
    private int maxUnits = 50;
    private boolean verbose = true;
    /** Scratch buffer for unit displacements, sized to the data's column count. */
    private DenseDoubleMatrix1D delta;
    /** Number of units created in the current run (never decremented); also used for node names. */
    int unitIndex = 0;
    private long numSignalsGenerated = 0L;
    private HashMap<Unit, Node> unitsToNodes;
    private HashMap<Unit, Double> unitsToErrors;
    private HashMap<Node, Unit> nodesToUnits;
    private HashMap<Edge, Integer> edgesToAges;
    private Dissimilarity metric = new SquaredErrorLoss();
    private List<List<Integer>> clusters;

    private Gng() {
    }

    /**
     * Creates a new GNG clusterer with default parameters.
     *
     * @return a fresh instance
     */
    public static Gng init() {
        return new Gng();
    }

    /**
     * Trains a fresh GNG network on {@code data} and assigns every row to a cluster.
     *
     * <p>State from any previous call is discarded. Afterwards, {@link #getClusters()} returns one
     * list of row indices per connected component of the learned graph.
     *
     * @param data the points to cluster, one per row; its contents are not modified
     * @throws IllegalArgumentException if {@code data} has no rows
     */
    @Override
    public void cluster(DoubleMatrix2D data) {
        initialize(data);
        while (!shouldStop()) {
            adapt(nextSignal());
        }
        this.clusters = assignRowsToComponents(data);
    }

    /** Resets all per-run state and seeds the network with two units joined by an edge. */
    private void initialize(DoubleMatrix2D data) {
        if (data.rows() == 0) {
            throw new IllegalArgumentException("Cannot cluster a data set with no rows.");
        }
        this.data = data;
        this.delta = new DenseDoubleMatrix1D(data.columns());
        this.units = new ArrayList<Unit>();
        this.graph = new EdgeListGraph();
        this.unitsToNodes = new HashMap<Unit, Node>();
        this.nodesToUnits = new HashMap<Node, Unit>();
        this.edgesToAges = new HashMap<Edge, Integer>();
        this.unitsToErrors = new HashMap<Unit, Double>();
        // Reset the counters so a repeated call trains a new network. Otherwise it would stop
        // immediately on the previous run's unit count.
        this.unitIndex = 0;
        this.numSignalsGenerated = 0L;

        Node node1 = addUnit(nextSignal());
        Node node2 = addUnit(nextSignal());
        Edge initialEdge = Edges.undirectedEdge(node1, node2);
        addEdge(initialEdge);
        print("Adding initial edge: " + initialEdge + " age " + this.edgesToAges.get(initialEdge));
    }

    /** Performs one GNG adaptation step for a single input signal. */
    private void adapt(DoubleMatrix1D signal) {
        // Find the nearest (s1) and second-nearest (s2) units.
        calculateDistancesToSignal(signal);
        Collections.sort(this.units);
        Unit s1 = this.units.get(0);
        Unit s2 = this.units.get(1);
        Node s1Node = this.unitsToNodes.get(s1);
        Node s2Node = this.unitsToNodes.get(s2);

        // Age every edge emanating from the winner.
        for (Edge edge : this.graph.getEdges(s1Node)) {
            this.edgesToAges.put(edge, this.edgesToAges.get(edge) + 1);
        }

        // Accumulate the winner's error before it moves, then pull the winner and its
        // neighbors toward the signal.
        incrementError(s1, signal);
        moveUnit(s1, signal, getEpsilonB());
        for (Node node : neighbors(s1)) {
            moveUnit(this.nodesToUnits.get(node), signal, getEpsilonN());
        }

        // Connect s1 and s2, or refresh the age of their existing edge.
        Edge s1S2Edge = this.graph.getEdge(s1Node, s2Node);
        if (s1S2Edge == null) {
            s1S2Edge = Edges.undirectedEdge(s1Node, s2Node);
            addEdge(s1S2Edge);
            print("Adding edge between s1 and s2: " + s1S2Edge + " age " + this.edgesToAges.get(s1S2Edge));
        } else {
            this.edgesToAges.put(s1S2Edge, 0);
        }

        removeOldEdges();
        removeIsolatedUnits();

        if (this.numSignalsGenerated % (long) getLambda() == 0L) {
            insertUnit();
        }

        // Global error decay after every signal.
        for (Unit unit : this.unitsToErrors.keySet()) {
            decreaseError(unit, getBeta());
        }
    }

    /** Deletes every edge whose age exceeds {@link #getAgeMax()}. */
    private void removeOldEdges() {
        List<Edge> edgesToRemove = new ArrayList<Edge>();
        for (Edge edge : this.edgesToAges.keySet()) {
            if (this.edgesToAges.get(edge) > getAgeMax()) {
                edgesToRemove.add(edge);
            }
        }
        for (Edge edge : edgesToRemove) {
            print("Removing old edge: " + edge + " num nodes = " + this.graph.getNumNodes()
                    + " num edges = " + this.graph.getNumEdges() + " age " + this.edgesToAges.get(edge));
            removeEdge(edge);
        }
    }

    /** Removes every unit whose node has no incident edges. */
    private void removeIsolatedUnits() {
        // Iterate over a snapshot, because removeUnit() deletes nodes from the graph.
        for (Node node : new ArrayList<Node>(this.graph.getNodes())) {
            if (this.graph.getEdges(node).isEmpty()) {
                removeUnit(this.nodesToUnits.get(node));
            }
        }
    }

    /**
     * Inserts a new unit r halfway between two units:
     * <ul>
     *   <li>q, the unit with the largest accumulated error;</li>
     *   <li>f, q's neighbor with the largest error.</li>
     * </ul>
     * Edge q-f is replaced by edges q-r and r-f.
     */
    private void insertUnit() {
        Unit q = unitWithMaximumError();
        Unit f = neighborWithMaximumError(q);
        DoubleMatrix1D qVector = q.getVector();
        DoubleMatrix1D fVector = f.getVector();
        DenseDoubleMatrix1D rVector = new DenseDoubleMatrix1D(qVector.size());
        for (int j = 0; j < rVector.size(); ++j) {
            rVector.set(j, (qVector.get(j) + fVector.get(j)) / 2.0);
        }
        Node rNode = addUnit(rVector);
        Node qNode = this.unitsToNodes.get(q);
        Node fNode = this.unitsToNodes.get(f);
        addEdge(Edges.undirectedEdge(qNode, rNode));
        addEdge(Edges.undirectedEdge(fNode, rNode));
        removeEdge(this.graph.getEdge(qNode, fNode));
        print("Adding new node, " + qNode + "---" + rNode + "---" + fNode);

        // Following Fritzke, only q and f lose error on insertion; r inherits q's reduced error.
        decreaseError(q, getAlpha());
        decreaseError(f, getAlpha());
        Unit r = this.nodesToUnits.get(rNode);
        this.unitsToErrors.put(r, this.unitsToErrors.get(q));
    }

    /**
     * Assigns each data row to the connected component that contains its nearest unit.
     *
     * @return one list of row indices per connected component of the graph
     */
    private List<List<Integer>> assignRowsToComponents(DoubleMatrix2D data) {
        List<List<Node>> components = GraphUtils.connectedComponents(this.graph);
        List<List<Integer>> result = new ArrayList<List<Integer>>(components.size());
        HashMap<Node, Integer> componentOf = new HashMap<Node, Integer>();
        for (int c = 0; c < components.size(); ++c) {
            result.add(new ArrayList<Integer>());
            for (Node node : components.get(c)) {
                // The first component listing a node wins.
                if (!componentOf.containsKey(node)) {
                    componentOf.put(node, c);
                }
            }
        }
        List<Node> nodes = this.graph.getNodes();
        for (int i = 0; i < data.rows(); ++i) {
            Node nearest = nearestNode(nodes, data.viewRow(i));
            print("Node for " + i + " is " + nearest);
            Integer component = componentOf.get(nearest);
            if (component != null) {
                result.get(component).add(i);
                print("Component " + component + " contains that node: " + components.get(component));
            }
        }
        return result;
    }

    /** Returns the node whose unit is least dissimilar to {@code point}, or null if there is none. */
    private Node nearestNode(List<Node> nodes, DoubleMatrix1D point) {
        double min = Double.POSITIVE_INFINITY;
        Node nearest = null;
        for (Node node : nodes) {
            double d = this.metric.dissimilarity(this.nodesToUnits.get(node).getVector(), point);
            if (d < min) {
                min = d;
                nearest = node;
            }
        }
        return nearest;
    }

    /**
     * Returns the clusters from the most recent call to {@link #cluster(DoubleMatrix2D)}.
     *
     * @return one list of row indices per cluster, or null if {@code cluster} has not completed
     */
    @Override
    public List<List<Integer>> getClusters() {
        return this.clusters;
    }

    /**
     * Always returns null. The learned unit vectors are available from {@link #getUnitsAsMatrix()}.
     */
    @Override
    public DoubleMatrix2D getPrototypes() {
        return null;
    }

    /** Prints {@code s} to standard output when verbose mode is on. */
    private void print(String s) {
        if (this.isVerbose()) {
            System.out.println(s);
        }
    }

    /** Removes an isolated unit from the graph and from every bookkeeping map. */
    private void removeUnit(Unit unit) {
        Node node = this.unitsToNodes.get(unit);
        if (!this.graph.getAdjacentNodes(node).isEmpty()) {
            throw new IllegalArgumentException("Cannot remove unit " + node + ": it still has neighbors.");
        }
        this.graph.removeNode(node);
        this.units.remove(unit);
        this.unitsToNodes.remove(unit);
        this.nodesToUnits.remove(node);
        // Without this, a removed unit could still be selected as the maximum-error unit.
        this.unitsToErrors.remove(unit);
    }

    /** Adds {@code edge} to the graph with age 0. */
    private void addEdge(Edge edge) {
        this.graph.addEdge(edge);
        this.edgesToAges.put(edge, 0);
    }

    /** Removes {@code edge} from the graph and forgets its age. */
    private void removeEdge(Edge edge) {
        this.graph.removeEdge(edge);
        this.edgesToAges.remove(edge);
    }

    /**
     * Returns the neighbor of {@code q} with the largest accumulated error.
     *
     * @throws IllegalArgumentException if {@code q} has no neighbor with a finite error
     */
    private Unit neighborWithMaximumError(Unit q) {
        Unit best = null;
        double max = Double.NEGATIVE_INFINITY;
        for (Node node : neighbors(q)) {
            Unit candidate = this.nodesToUnits.get(node);
            double error = this.unitsToErrors.get(candidate);
            if (error > max) {
                max = error;
                best = candidate;
            }
        }
        if (best == null) {
            throw new IllegalArgumentException("Unit " + this.unitsToNodes.get(q) + " has no neighbors.");
        }
        return best;
    }

    /** Returns the graph neighbors of {@code q}'s node. */
    private List<Node> neighbors(Unit q) {
        return this.graph.getAdjacentNodes(this.unitsToNodes.get(q));
    }

    /**
     * Returns the live unit with the largest accumulated error.
     *
     * @throws IllegalArgumentException if no unit has a finite error
     */
    private Unit unitWithMaximumError() {
        double max = Double.NEGATIVE_INFINITY;
        Unit best = null;
        for (Unit unit : this.unitsToErrors.keySet()) {
            double error = this.unitsToErrors.get(unit);
            if (error > max) {
                max = error;
                best = unit;
            }
        }
        if (best == null) {
            throw new IllegalArgumentException("No unit with a finite error.");
        }
        return best;
    }

    /** Training stops once more than {@link #getMaxUnits()} units have been created. */
    private boolean shouldStop() {
        return this.unitIndex > this.getMaxUnits();
    }

    /** Moves {@code unit} toward {@code signal} by the fraction {@code epsilon} of the gap. */
    private void moveUnit(Unit unit, DoubleMatrix1D signal, double epsilon) {
        DoubleMatrix1D vector = unit.getVector();
        for (int j = 0; j < signal.size(); ++j) {
            this.delta.set(j, epsilon * (signal.get(j) - vector.get(j)));
        }
        unit.moveVector(this.delta);
    }

    /** Adds the dissimilarity between {@code s1} and {@code signal} to {@code s1}'s error. */
    private void incrementError(Unit s1, DoubleMatrix1D signal) {
        double error = this.unitsToErrors.get(s1);
        this.unitsToErrors.put(s1, error + this.getMetric().dissimilarity(s1.getVector(), signal));
    }

    /** Reduces {@code unit}'s accumulated error by the fraction {@code rate}. */
    private void decreaseError(Unit unit, double rate) {
        double error = this.unitsToErrors.get(unit);
        this.unitsToErrors.put(unit, error - rate * error);
    }

    /** Records each live unit's dissimilarity to {@code signal}, ready for sorting. */
    private void calculateDistancesToSignal(DoubleMatrix1D signal) {
        for (Unit unit : this.units) {
            unit.calculateDistanceToSignal(signal);
        }
    }

    /**
     * Creates a unit at a copy of {@code vector} and adds it to the graph with error 0.
     *
     * @return the new unit's graph node
     */
    private Node addUnit(DoubleMatrix1D vector) {
        // Copy so the unit never aliases a row view of the data matrix. Otherwise moving the
        // unit would overwrite the caller's data.
        Unit unit = new Unit(vector.copy());
        GraphNode node = new GraphNode("X" + ++this.unitIndex);
        this.units.add(unit);
        this.graph.addNode(node);
        this.unitsToNodes.put(unit, node);
        this.nodesToUnits.put(node, unit);
        this.unitsToErrors.put(unit, 0.0);
        return node;
    }

    /**
     * Returns the live unit vectors as the rows of a new matrix, in the current internal order.
     *
     * @throws IllegalStateException if {@link #cluster(DoubleMatrix2D)} has not been run
     */
    public DoubleMatrix2D getUnitsAsMatrix() {
        if (this.units == null || this.units.isEmpty()) {
            throw new IllegalStateException("No units available; call cluster() first.");
        }
        int numColumns = this.units.get(0).getVector().size();
        DenseDoubleMatrix2D matrix = new DenseDoubleMatrix2D(this.units.size(), numColumns);
        for (int i = 0; i < matrix.rows(); ++i) {
            DoubleMatrix1D vector = this.units.get(i).getVector();
            for (int j = 0; j < numColumns; ++j) {
                matrix.set(i, j, vector.get(j));
            }
        }
        return matrix;
    }

    /** Edges whose age exceeds this value are removed. */
    public int getAgeMax() {
        return this.ageMax;
    }

    public void setAgeMax(int ageMax) {
        this.ageMax = ageMax;
    }

    /** A new unit is inserted every {@code lambda} signals. */
    public int getLambda() {
        return this.lambda;
    }

    /**
     * Sets the insertion period.
     *
     * @throws IllegalArgumentException if {@code lambda} is not positive (it is used as a modulus)
     */
    public void setLambda(int lambda) {
        if (lambda <= 0) {
            throw new IllegalArgumentException("lambda must be positive: " + lambda);
        }
        this.lambda = lambda;
    }

    /** Fraction of error removed from q and f when a unit is inserted between them. */
    public double getAlpha() {
        return this.alpha;
    }

    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** Fraction of every unit's error removed after each signal. */
    public double getBeta() {
        return this.beta;
    }

    public void setBeta(double beta) {
        this.beta = beta;
    }

    /** Step size for moving the winning unit toward a signal. */
    public double getEpsilonB() {
        return this.epsilonB;
    }

    public void setEpsilonB(double epsilonB) {
        this.epsilonB = epsilonB;
    }

    /** Step size for moving the winner's graph neighbors toward a signal. */
    public double getEpsilonN() {
        return this.epsilonN;
    }

    public void setEpsilonN(double epsilonN) {
        this.epsilonN = epsilonN;
    }

    /** Training stops once more than this many units have been created. */
    public int getMaxUnits() {
        return this.maxUnits;
    }

    public void setMaxUnits(int maxUnits) {
        this.maxUnits = maxUnits;
    }

    /** Describes the connected components and the graph. */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        buf.append("Connected components: ");
        List<List<Node>> components = GraphUtils.connectedComponents(this.graph);
        for (int i = 0; i < components.size(); ++i) {
            buf.append("\n").append(i).append(": ");
            buf.append(components.get(i));
        }
        buf.append("\nGraph = ");
        buf.append(this.graph);
        return buf.toString();
    }

    /** Draws a random data row and counts it. The returned row is a view backed by the data matrix. */
    private DoubleMatrix1D nextSignal() {
        int i = RandomUtil.getInstance().nextInt(this.data.rows());
        ++this.numSignalsGenerated;
        return this.data.viewRow(i);
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public Dissimilarity getMetric() {
        return this.metric;
    }

    public void setMetric(Dissimilarity metric) {
        this.metric = metric;
    }

    /** A GNG unit: a prototype vector plus its dissimilarity to the most recent signal. */
    private class Unit implements Comparable<Unit> {
        private final DoubleMatrix1D vector;
        /** Dissimilarity to the latest signal; NaN until first computed. */
        private double distance = Double.NaN;

        public Unit(DoubleMatrix1D unitVector) {
            this.vector = unitVector;
        }

        /** Adds {@code delta} element-wise to this unit's vector, in place. */
        public void moveVector(DoubleMatrix1D delta) {
            for (int i = 0; i < this.vector.size(); ++i) {
                this.vector.set(i, this.vector.get(i) + delta.get(i));
            }
        }

        public void calculateDistanceToSignal(DoubleMatrix1D signal) {
            this.distance = getMetric().dissimilarity(this.vector, signal);
        }

        /** Orders units by ascending distance; Double.compare gives a consistent total order. */
        @Override
        public int compareTo(Unit other) {
            return Double.compare(this.distance, other.distance);
        }

        public DoubleMatrix1D getVector() {
            return this.vector;
        }

        @Override
        public String toString() {
            return "" + this.vector;
        }
    }
}

