/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.cluster;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.cluster.ClusteringAlgorithm;
import edu.cmu.tetrad.cluster.metrics.Dissimilarity;
import edu.cmu.tetrad.cluster.metrics.SquaredErrorLoss;
import edu.cmu.tetrad.util.Point;
import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Agglomerative hierarchical clustering using Ward's minimum-variance criterion.
 * <p>
 * Each row of the input matrix is one point. Starting from singleton clusters, the
 * two clusters with the least dissimilarity are merged repeatedly; after each merge
 * the dissimilarity from the new cluster to every other cluster is updated with the
 * Lance-Williams recurrence for Ward's method:
 * <pre>
 *   d(C, A u B) = [(wA + wC) d(A,C) + (wB + wC) d(B,C) - wC d(A,B)] / (wA + wB + wC)
 * </pre>
 * where a cluster's weight is its number of points. Initial pairwise dissimilarities
 * come from the configured metric (by default {@link SquaredErrorLoss}).
 * NOTE(review): the Ward recurrence assumes the initial dissimilarities are squared
 * Euclidean distances — confirm that is what SquaredErrorLoss computes.
 * <p>
 * Clusters are read off the resulting binary tree by descending a fixed number of
 * levels (see {@link #getNodesAtDepth(TreeNode, int)}), not by cutting at a
 * dissimilarity threshold.
 * <p>
 * Not synchronized: {@link #cluster(DoubleMatrix2D)} overwrites instance state.
 */
public class Wards implements ClusteringAlgorithm {
    /** The input rows wrapped as points, in row order; set by cluster(). */
    private List<Point> points;
    /** Root of the merge tree from the last call to cluster(); null before clustering. */
    private TreeNode resultTreeNode;
    /** Number of tree levels descended by {@link #getClusters()}. */
    private int depth = 3;
    /** Dissimilarity between two input rows, used to seed the matrix. */
    private Dissimilarity metric = new SquaredErrorLoss();
    /** Whether progress messages are printed to standard output. */
    private boolean verbose = true;

    private Wards() {
    }

    /**
     * Creates a new, unclustered instance with default settings
     * (depth 3, SquaredErrorLoss metric, verbose output on).
     *
     * @return a new instance
     */
    public static Wards initialize() {
        return new Wards();
    }

    /**
     * Builds the Ward's-method merge tree for the rows of {@code data} and stores its
     * root for later retrieval via {@link #getClusters()} / {@link #clusterResult()}.
     * <p>
     * Merging stops after n - 1 merges, or earlier if no pair with a dissimilarity
     * below +infinity (and not NaN) remains; in that case the stored root is the last
     * cluster formed. For a single row the root is that row's leaf; for zero rows it
     * is null.
     *
     * @param data one point per row
     */
    @Override
    public void cluster(DoubleMatrix2D data) {
        int numPoints = data.rows();
        this.points = new ArrayList<Point>(numPoints);
        for (int i = 0; i < numPoints; ++i) {
            this.points.add(new Point(data.viewRow(i)));
        }
        if (this.isVerbose()) {
            System.out.println("# thresholded points = " + this.points.size());
        }

        // Slot i holds the cluster that currently owns row/column i of the
        // dissimilarity matrix, or null once that slot has been merged away.
        // ArrayList because the loops below use random access by index.
        List<TreeNode> nodes = new ArrayList<TreeNode>(numPoints);
        for (Point point : this.points) {
            nodes.add(new LeafTreeNode(point));
        }

        DoubleMatrix2D dissimilarities = new DenseDoubleMatrix2D(numPoints, numPoints);
        for (int i = 0; i < numPoints; ++i) {
            DoubleMatrix1D vectorA = this.points.get(i).getVector();
            for (int j = 0; j < i; ++j) {
                DoubleMatrix1D vectorB = this.points.get(j).getVector();
                double dissimilarity = this.metric.dissimilarity(vectorA, vectorB);
                dissimilarities.set(i, j, dissimilarity);
                dissimilarities.set(j, i, dissimilarity);
            }
        }
        if (this.isVerbose()) {
            System.out.println("Matrix constructed.");
        }

        TreeNode resultNode = null;
        // n clusters need at most n - 1 merges; stop early if no mergeable pair is left.
        for (int merge = 0; merge < numPoints - 1; ++merge) {
            DissimilarityResult closest = this.findLeastDissimilarity(dissimilarities);
            int aIndex = closest.getAIndex();
            int bIndex = closest.getBIndex();
            if (aIndex == -1) {
                break;
            }
            TreeNode a = nodes.get(aIndex);
            TreeNode b = nodes.get(bIndex);
            if (this.isVerbose()) {
                System.out.println("leastDissimilarity = " + closest.getDissimilarity());
            }

            // Snapshot the old rows before they are cleared; the update needs them.
            DoubleMatrix1D aRowCopy = this.copyRow(dissimilarities, aIndex);
            DoubleMatrix1D bRowCopy = this.copyRow(dissimilarities, bIndex);
            double deltaAB = aRowCopy.get(bIndex);

            JoinTreeNode aUb = new JoinTreeNode(a, b);
            resultNode = aUb;

            // The merged cluster takes over slot bIndex; slot aIndex is retired.
            nodes.set(aIndex, null);
            nodes.set(bIndex, aUb);
            this.clearIndex(nodes, dissimilarities, aIndex);
            this.clearIndex(nodes, dissimilarities, bIndex);

            double wA = a.getWeight();
            double wB = b.getWeight();
            for (int cIndex = 0; cIndex < nodes.size(); ++cIndex) {
                TreeNode c = nodes.get(cIndex);
                if (c == null || cIndex == aIndex || cIndex == bIndex) {
                    continue;
                }
                double wC = c.getWeight();
                double deltaAC = aRowCopy.get(cIndex);
                double deltaBC = bRowCopy.get(cIndex);
                // Lance-Williams recurrence for Ward's method.
                double dissimilarity =
                        ((wA + wC) * deltaAC + (wB + wC) * deltaBC - wC * deltaAB) / (wA + wB + wC);
                dissimilarities.set(cIndex, bIndex, dissimilarity);
                dissimilarities.set(bIndex, cIndex, dissimilarity);
            }
        }

        // With a single point no merge happens; its leaf is then the whole tree.
        if (resultNode == null && numPoints == 1) {
            resultNode = nodes.get(0);
        }
        this.resultTreeNode = resultNode;
    }

    /**
     * Returns the clusters obtained by descending {@link #getDepth()} levels from the
     * root of the last clustering.
     *
     * @return one list of row indices per cluster
     * @throws IllegalArgumentException if no tree has been built yet
     */
    @Override
    public List<List<Integer>> getClusters() {
        return this.getClustersAtDepth(this.resultTreeNode, this.depth);
    }

    /**
     * Returns the clusters obtained by descending {@code depth} levels from {@code node}.
     *
     * @param node  the subtree root to cut
     * @param depth number of levels to descend
     * @return one list of row indices (into the data last passed to cluster()) per
     *         cluster; a point not found among those rows is reported as -1
     * @throws IllegalArgumentException if {@code node} is null
     */
    public List<List<Integer>> getClustersAtDepth(TreeNode node, int depth) {
        List<TreeNode> nodesAtDepth = this.getNodesAtDepth(node, depth);

        // Leaves hold the very Point instances stored in this.points, so an identity
        // map gives the exact row index (even for rows with equal values) in O(1).
        Map<Point, Integer> rowOf = new IdentityHashMap<Point, Integer>();
        if (this.points != null) {
            for (int i = 0; i < this.points.size(); ++i) {
                rowOf.put(this.points.get(i), i);
            }
        }

        List<List<Integer>> clusters = new ArrayList<List<Integer>>(nodesAtDepth.size());
        for (TreeNode clusterNode : nodesAtDepth) {
            List<Point> clusterPoints = clusterNode.getPoints();
            List<Integer> cluster = new ArrayList<Integer>(clusterPoints.size());
            for (Point p : clusterPoints) {
                Integer index = rowOf.get(p);
                cluster.add(index == null ? -1 : index);
            }
            clusters.add(cluster);
        }
        return clusters;
    }

    /**
     * Ward's method does not compute cluster prototypes.
     *
     * @return always null
     */
    @Override
    public DoubleMatrix2D getPrototypes() {
        return null;
    }

    /**
     * @return the root of the merge tree from the last clustering, or null if none
     */
    public TreeNode clusterResult() {
        return this.resultTreeNode;
    }

    /**
     * Summarizes cluster sizes at depths 0 through 5 of the merge tree. Before any
     * tree has been built, only the header line is returned.
     */
    @Override
    public String toString() {
        final int summaryDepths = 6;
        StringBuilder buf = new StringBuilder();
        buf.append("Ward's method clustering.");
        if (this.resultTreeNode == null) {
            return buf.toString();
        }
        for (int i = 0; i < summaryDepths; ++i) {
            List<TreeNode> nodes = this.getNodesAtDepth(this.resultTreeNode, i);
            buf.append("\n At depth = ").append(i).append(" there are ").append(nodes.size()).append(" clusters with these sizes:");
            for (int j = 0; j < nodes.size(); ++j) {
                int size = nodes.get(j).getPoints().size();
                buf.append("\n\t").append(j).append(". ").append(size);
            }
        }
        return buf.toString();
    }

    /**
     * Returns the nodes reached by descending {@code depth} levels from {@code treeNode}.
     * At each level every join node is replaced by its two children; leaves are carried
     * down unchanged, so a shallow leaf appears at all deeper levels.
     *
     * @param treeNode the subtree root
     * @param depth    number of levels to descend (values &lt;= 0 return just the root)
     * @return the nodes at that depth, left to right
     * @throws IllegalArgumentException if {@code treeNode} is null
     */
    public List<TreeNode> getNodesAtDepth(TreeNode treeNode, int depth) {
        if (treeNode == null) {
            throw new IllegalArgumentException("The given tree node is null.");
        }
        List<TreeNode> nodes = Collections.singletonList(treeNode);
        for (int i = 0; i < depth; ++i) {
            List<TreeNode> next = new ArrayList<TreeNode>(nodes.size() * 2);
            for (TreeNode node : nodes) {
                if (node instanceof JoinTreeNode) {
                    JoinTreeNode join = (JoinTreeNode) node;
                    next.add(join.getNode1());
                    next.add(join.getNode2());
                } else {
                    next.add(node);
                }
            }
            nodes = next;
        }
        return nodes;
    }

    /** Returns an independent copy of row {@code aIndex}. */
    private DoubleMatrix1D copyRow(DoubleMatrix2D dissimilarities, int aIndex) {
        return dissimilarities.viewRow(aIndex).copy();
    }

    /** Sets row and column {@code index} to NaN so the slot is never selected again. */
    private void clearIndex(List<TreeNode> nodes, DoubleMatrix2D dissimilarities, int index) {
        for (int i = 0; i < nodes.size(); ++i) {
            dissimilarities.set(i, index, Double.NaN);
            dissimilarities.set(index, i, Double.NaN);
        }
    }

    /**
     * Finds the strictly-lower-triangle entry with the smallest value. NaN entries
     * (cleared slots) never compare less and are skipped.
     *
     * @return the (row, column, value) of the minimum, or indices -1 if every entry is
     *         NaN or +infinity
     */
    private DissimilarityResult findLeastDissimilarity(DoubleMatrix2D dissimilarities) {
        double leastDissimilarity = Double.POSITIVE_INFINITY;
        int aIndex = -1;
        int bIndex = -1;
        int size = dissimilarities.rows();
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < i; ++j) {
                double value = dissimilarities.get(i, j);
                if (value < leastDissimilarity) {
                    leastDissimilarity = value;
                    aIndex = i;
                    bIndex = j;
                }
            }
        }
        return new DissimilarityResult(aIndex, bIndex, leastDissimilarity);
    }

    /** @return the number of levels descended by {@link #getClusters()} */
    public int getDepth() {
        return this.depth;
    }

    /** @param depth the number of levels {@link #getClusters()} should descend */
    public void setDepth(int depth) {
        this.depth = depth;
    }

    /** @return whether progress messages are printed to standard output */
    public boolean isVerbose() {
        return this.verbose;
    }

    /** @param verbose whether progress messages are printed to standard output */
    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** The closest pair found in the dissimilarity matrix, with its value. */
    private static class DissimilarityResult {
        private final int aIndex;
        private final int bIndex;
        private final double dissimilarity;

        public DissimilarityResult(int aIndex, int bIndex, double dissimilarity) {
            this.aIndex = aIndex;
            this.bIndex = bIndex;
            this.dissimilarity = dissimilarity;
        }

        public int getAIndex() {
            return this.aIndex;
        }

        public int getBIndex() {
            return this.bIndex;
        }

        public double getDissimilarity() {
            return this.dissimilarity;
        }

        @Override
        public String toString() {
            return "aIndex = " + this.aIndex + " bIndex = " + this.bIndex + " dissimilarity = " + this.dissimilarity;
        }
    }

    /** Internal tree node: the union of two child clusters. */
    public static class JoinTreeNode implements TreeNode {
        private final TreeNode node1;
        private final TreeNode node2;
        /** Total number of points under this node. */
        private final int weight;

        /**
         * @throws IllegalArgumentException if either child is null
         */
        public JoinTreeNode(TreeNode node1, TreeNode node2) {
            if (node1 == null || node2 == null) {
                throw new IllegalArgumentException();
            }
            this.node1 = node1;
            this.node2 = node2;
            this.weight = node1.getWeight() + node2.getWeight();
        }

        public TreeNode getNode1() {
            return this.node1;
        }

        public TreeNode getNode2() {
            return this.node2;
        }

        @Override
        public int getWeight() {
            return this.weight;
        }

        /** @return a new list of all points under this node, node1's points first */
        @Override
        public List<Point> getPoints() {
            List<Point> points = new ArrayList<Point>(this.weight);
            points.addAll(this.node1.getPoints());
            points.addAll(this.node2.getPoints());
            return points;
        }

        @Override
        public String toString() {
            return this.getPoints().toString();
        }
    }

    /** Leaf tree node wrapping a single input point; its weight is always 1. */
    public static class LeafTreeNode implements TreeNode {
        private final Point point;

        public LeafTreeNode(Point point) {
            this.point = point;
        }

        @Override
        public int getWeight() {
            return 1;
        }

        /** @return a new, mutable one-element list containing this leaf's point */
        @Override
        public List<Point> getPoints() {
            List<Point> points = new ArrayList<Point>(1);
            points.add(this.point);
            return points;
        }

        @Override
        public String toString() {
            return this.point.toString();
        }
    }

    /** A node of the Ward's-method merge tree. */
    public interface TreeNode {
        /** @return the number of input points under this node */
        int getWeight();

        /** @return the input points under this node */
        List<Point> getPoints();

        @Override
        String toString();
    }
}

