/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.utils.GrowShrinkTree;
import edu.cmu.tetrad.util.StatUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;

/**
 * Causal search in the style of DirectLiNGAM.
 *
 * <p>The search builds a causal order one variable at a time. At each step it:
 * <ol>
 *   <li>picks, among the variables not yet ordered, the one that looks most exogenous according to
 *       a pairwise comparison of entropies approximated by {@link StatUtils#maxEntApprox};</li>
 *   <li>regresses that variable out of every variable that is still unordered;</li>
 *   <li>asks the chosen variable's {@link GrowShrinkTree} to select its parents from the variables
 *       ordered so far, adding a directed edge from each selected parent.</li>
 * </ol>
 */
public class DirectLingam {
    /** The data being searched over. */
    private final DataSet dataset;
    /** The dataset's variables, in column order. */
    private final List<Node> variables;
    /** One grow-shrink tree per variable, used to select that variable's parents. */
    private final Map<Node, GrowShrinkTree> gsts;

    /**
     * Creates a search over the given data.
     *
     * @param dataset the data; its variables become the nodes of the output graph.
     * @param score   the score handed to each per-variable {@link GrowShrinkTree}.
     */
    public DirectLingam(DataSet dataset, Score score) {
        this.dataset = dataset;
        this.variables = dataset.getVariables();
        this.gsts = new HashMap<>();

        // Every tree receives the same index map, which is still being filled while the trees are
        // created. NOTE(review): this relies on GrowShrinkTree keeping a reference to the map rather
        // than a copy of it - confirm before reordering these statements.
        Map<Node, Integer> index = new HashMap<>();
        int i = 0;
        for (Node node : this.variables) {
            index.put(node, i++);
            this.gsts.put(node, new GrowShrinkTree(score, index, node));
        }
    }

    /**
     * Runs the search.
     *
     * @return a graph over the dataset's variables. It contains a directed edge {@code x --> m} for
     *     every parent {@code x} that the grow-shrink tree of {@code m} selects from the set of
     *     variables ordered up to and including {@code m}.
     */
    public Graph search() {
        List<Node> remaining = new ArrayList<>(this.variables);
        Map<Node, double[]> residualsByNode = new HashMap<>();

        // One row per variable after transposing. NOTE(review): standardize() works in place; this
        // assumes toArray() returns a copy so the DataSet itself is not modified - confirm.
        double[][] columns = this.dataset.getDoubleData().transpose().toArray();
        for (int i = 0; i < columns.length; i++) {
            standardize(columns[i]);
            residualsByNode.put(this.variables.get(i), columns[i]);
        }

        Set<Node> ordered = new HashSet<>();
        Graph graph = new EdgeListGraph(this.variables);

        while (!remaining.isEmpty()) {
            Node next = getNext(remaining, residualsByNode);
            remaining.remove(next);

            // Remove the newly ordered variable's linear contribution from every unordered variable.
            double[] nextResidual = residualsByNode.get(next);
            for (Node x : remaining) {
                residualsByNode.put(x, residuals(residualsByNode.get(x), nextResidual));
            }

            // `ordered` already contains `next` when its tree is traced.
            // NOTE(review): presumably GrowShrinkTree never reports its own node as a parent - confirm.
            ordered.add(next);
            Set<Node> parents = new HashSet<>();
            this.gsts.get(next).trace(ordered, ordered, parents);
            for (Node parent : parents) {
                graph.addDirectedEdge(parent, next);
            }
        }
        return graph;
    }

    /**
     * Chooses the unordered variable that looks most exogenous.
     *
     * <p>Let H be {@link StatUtils#maxEntApprox} and r(a|b) the residual of a regressed on b. For
     * each candidate x and each other variable y, the method computes
     * {@code lr = [H(y) + H(r(x|y))] - [H(x) + H(r(y|x))]}. It adds {@code min(0, lr)^2} to x's
     * penalty, so only pairs where lr is negative count against x. The candidate with the smallest
     * penalty wins. Ties go to the earliest candidate in {@code remaining}. If every penalty is NaN,
     * the first candidate is returned.
     *
     * @param remaining       the unordered variables; must be non-empty.
     * @param residualsByNode the current residual column for each variable in {@code remaining}.
     * @return the selected variable.
     */
    private Node getNext(List<Node> remaining, Map<Node, double[]> residualsByNode) {
        int p = remaining.size();

        // Each entropy below depends only on its arguments, so compute it once instead of once per
        // pair. This assumes maxEntApprox is deterministic and does not modify its input.
        double[][] column = new double[p][];
        double[] marginalEntropy = new double[p];
        for (int i = 0; i < p; i++) {
            column[i] = residualsByNode.get(remaining.get(i));
            marginalEntropy[i] = StatUtils.maxEntApprox(column[i]);
        }

        // residualEntropy[i][j] = H(residual of variable i regressed on variable j), for i != j.
        double[][] residualEntropy = new double[p][p];
        for (int i = 0; i < p; i++) {
            for (int j = 0; j < p; j++) {
                if (i != j) {
                    residualEntropy[i][j] = StatUtils.maxEntApprox(residuals(column[i], column[j]));
                }
            }
        }

        Node best = remaining.get(0);
        double bestPenalty = Double.POSITIVE_INFINITY;
        for (int i = 0; i < p; i++) {
            double penalty = 0.0;
            for (int j = 0; j < p; j++) {
                if (i == j) {
                    continue;
                }
                double lr = (marginalEntropy[j] - marginalEntropy[i])
                        + (residualEntropy[i][j] - residualEntropy[j][i]);
                penalty += FastMath.pow(FastMath.min(0.0, lr), 2);
            }
            // Strict comparison: the first minimum wins, and a NaN penalty never wins.
            if (penalty < bestPenalty) {
                bestPenalty = penalty;
                best = remaining.get(i);
            }
        }
        return best;
    }

    /**
     * Centers {@code x} to mean zero and scales it to unit population standard deviation, in place.
     *
     * <p>The variance is computed in two passes (mean first, then squared deviations). The one-pass
     * form E[x^2] - mean^2 loses precision when the mean is large relative to the spread, and can
     * even come out negative. A constant column (standard deviation 0) is only centered, which
     * avoids 0/0 = NaN.
     *
     * @param x the values to standardize; modified in place.
     */
    private void standardize(double[] x) {
        int n = x.length;

        double mean = 0.0;
        for (double v : x) {
            mean += v;
        }
        mean /= n;

        double sumSq = 0.0;
        for (double v : x) {
            double d = v - mean;
            sumSq += d * d;
        }
        double sd = FastMath.sqrt(sumSq / n);

        for (int i = 0; i < n; i++) {
            double centered = x[i] - mean;
            x[i] = sd == 0.0 ? centered : centered / sd;
        }
    }

    /**
     * Returns the residual of {@code x} after a least-squares regression on {@code y} with no
     * intercept, i.e. {@code x - b*y} with {@code b = <x,y> / <y,y>}.
     *
     * <p>Omitting the intercept is appropriate here because the columns are mean-centered by
     * {@link #standardize}, and the residual of centered columns stays centered. If {@code y} is
     * identically zero, {@code b} is taken as 0 and a copy of {@code x} is returned. The original
     * division would have produced NaN in that case.
     *
     * @param x the response values.
     * @param y the regressor values; must have the same length as {@code x}.
     * @return a new array holding the residuals.
     */
    private double[] residuals(double[] x, double[] y) {
        int n = x.length;
        double cov = 0.0;
        double var = 0.0;
        for (int i = 0; i < n; i++) {
            cov += x[i] * y[i];
            var += y[i] * y[i];
        }
        double b = var == 0.0 ? 0.0 : cov / var;

        double[] r = new double[n];
        for (int i = 0; i < n; i++) {
            r[i] = x[i] - b * y[i];
        }
        return r;
    }
}

