/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodePair;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.Fas;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.GraphSearch;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

public final class Fask
implements GraphSearch {
    // Independence test handed to Fas when the adjacency method is FAS_STABLE.
    private final IndependenceTest test;
    // Score handed to Fges when the adjacency method is FGES.
    private final Score score;
    // Data set supplied to the constructor (checked there to be non-null and continuous).
    private final DataSet dataSet;
    // Regression helper built over the supplied data set; used by getB().
    private final RegressionDataset regressionDataset;
    // Standardized data, transposed so that D[i] holds the values of variable i.
    // Assigned by getLrScores(); null until that method (or search()) has run. Package-private.
    double[][] D;
    // Graph whose adjacencies are used when the adjacency method is EXTERNAL_GRAPH.
    private Graph externalGraph;
    // Wall-clock duration of the last search() call, in milliseconds.
    private long elapsed;
    // Depth passed to the SublistGenerator in twoCycleTest(); defaults to -1.
    private int depth = -1;
    // Background knowledge consulted for forbidden/required orientations.
    private Knowledge knowledge = new Knowledge();
    // A non-adjacent pair is still considered when |correxp(x,y,x) - correxp(x,y,y)| exceeds this value.
    private double skewEdgeThreshold;
    // Cutoff used by the 2-cycle pre-screen in search() and by zeroDiff(); must be >= 0.
    private double twoCycleScreeningCutoff;
    // Z cutoff derived from orientationAlpha via StatUtils.getZForAlpha.
    private double orientationCutoff;
    // Alpha for the 2-cycle test; 0 means screened pairs become 2-cycles without testing.
    private double orientationAlpha;
    // Correlation threshold below which the FASK1/FASK2 left-right score is negated.
    private double delta;
    // When true, the left-right rules adjust for the signs of the variables' skewnesses.
    private boolean empirical = true;
    // How the initial adjacencies are obtained in search().
    private AdjacencyMethod adjacencyMethod = AdjacencyMethod.FAS_STABLE;
    // Which left-right rule leftRight(double[], double[]) dispatches to.
    private LeftRight leftRight = LeftRight.RSKEW;
    // Result of the most recent search(); null before the first search.
    private Graph graph;

    public Fask(DataSet dataSet, Score score, IndependenceTest test) {
        if (dataSet == null) {
            throw new NullPointerException("Data set not provided.");
        }
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("For FASK, the dataset must be entirely continuous");
        }
        this.dataSet = dataSet;
        this.test = test;
        this.score = score;
        this.regressionDataset = new RegressionDataset(dataSet);
        this.orientationCutoff = StatUtils.getZForAlpha(0.01);
        this.orientationAlpha = 0.01;
    }

    /**
     * Averages {@code x[k] * y[k]} over the rows where {@code condition[k] > 0}.
     *
     * @param x         first series
     * @param y         second series
     * @param condition series selecting the rows (strictly positive entries are kept)
     * @return the mean product over the selected rows; NaN when no row is selected
     */
    private static double cu(double[] x, double[] y, double[] condition) {
        double productSum = 0.0;
        int selected = 0;
        for (int row = 0; row < x.length; row++) {
            if (condition[row] > 0.0) {
                productSum += x[row] * y[row];
                selected++;
            }
        }
        return productSum / (double) selected;
    }

    /**
     * Computes E(xy | z &gt; 0) / sqrt(E(xx | z &gt; 0) * E(yy | z &gt; 0)), where each
     * expectation is the mean product returned by {@link #E(double[], double[], double[])}.
     *
     * @param x first series
     * @param y second series
     * @param z conditioning series (rows with z &gt; 0 are used)
     * @return the ratio described above
     */
    private static double correxp(double[] x, double[] y, double[] z) {
        double exy = Fask.E(x, y, z);
        double exx = Fask.E(x, x, z);
        double eyy = Fask.E(y, y, z);
        return exy / FastMath.sqrt(exx * eyy);
    }

    /**
     * Mean of {@code x[k] * y[k]} restricted to the rows where {@code z[k] > 0}.
     *
     * @param x first series
     * @param y second series
     * @param z conditioning series
     * @return the conditional mean product; NaN when no row satisfies the condition
     */
    private static double E(double[] x, double[] y, double[] z) {
        double total = 0.0;
        int kept = 0;
        int index = 0;
        while (index < x.length) {
            if (z[index] > 0.0) {
                total += x[index] * y[index];
                kept++;
            }
            index++;
        }
        return total / (double) kept;
    }

    /**
     * Runs the FASK search and returns the resulting graph.
     * <p>
     * In order, the method: (1) standardizes the data and computes the pairwise
     * left-right scores (which also populates {@code D}); (2) obtains adjacencies with the
     * configured {@link AdjacencyMethod}; (3) for every variable pair that is adjacent, or
     * whose conditional-correlation difference exceeds the skew edge threshold, adds an
     * edge oriented by knowledge or by the sign of the left-right score; (4) turns screened
     * pairs into 2-cycles, either directly (orientation alpha == 0) or only when
     * {@code twoCycleTest} accepts them (orientation alpha &gt; 0).
     * <p>
     * Side effects: writes progress to the TetradLogger and to standard output, and stores
     * the elapsed time and the result graph on this instance.
     *
     * @return the oriented graph
     * @throws IllegalStateException if the adjacency method is EXTERNAL_GRAPH and no external
     *                               graph was set, or if the adjacency method is unrecognised
     */
    @Override
    public Graph search() {
        Node Y;
        Node X;
        Graph G;
        Object fas;
        long start = MillisecondTimes.timeMillis();
        DecimalFormat nf = new DecimalFormat("0.000");
        // Local standardized copy; its variables are the nodes used for the rest of the search.
        DataSet dataSet = DataUtils.standardizeData(this.dataSet);
        List<Node> variables = dataSet.getVariables();
        // Also assigns this.D, which is indexed below in the same order as 'variables'.
        double[][] lrs = this.getLrScores();
        for (int i = 0; i < variables.size(); ++i) {
            System.out.println("Skewness of " + variables.get(i) + " = " + StatUtils.skewness(this.D[i]));
        }
        TetradLogger.getInstance().forceLogMessage("FASK v. 2.0");
        TetradLogger.getInstance().forceLogMessage("");
        TetradLogger.getInstance().forceLogMessage("# variables = " + dataSet.getNumColumns());
        TetradLogger.getInstance().forceLogMessage("N = " + dataSet.getNumRows());
        TetradLogger.getInstance().forceLogMessage("Skewness edge threshold = " + this.skewEdgeThreshold);
        TetradLogger.getInstance().forceLogMessage("Orientation Alpha = " + this.orientationAlpha);
        TetradLogger.getInstance().forceLogMessage("2-cycle threshold = " + this.twoCycleScreeningCutoff);
        TetradLogger.getInstance().forceLogMessage("");
        // Step 2: build the adjacency graph G with the configured method.
        if (this.adjacencyMethod == AdjacencyMethod.FAS_STABLE) {
            fas = new Fas(this.test);
            ((Fas)fas).setStable(true);
            ((Fas)fas).setVerbose(false);
            ((Fas)fas).setKnowledge(this.knowledge);
            G = ((Fas)fas).search();
        } else if (this.adjacencyMethod == AdjacencyMethod.FGES) {
            fas = new Fges(this.score);
            ((Fges)fas).setVerbose(false);
            ((Fges)fas).setKnowledge(this.knowledge);
            G = ((Fges)fas).search();
        } else if (this.adjacencyMethod == AdjacencyMethod.EXTERNAL_GRAPH) {
            if (this.getExternalGraph() == null) {
                throw new IllegalStateException("An external graph was not supplied.");
            }
            // Rebuild the external graph with undirected edges only, one per adjacent pair.
            Graph g1 = new EdgeListGraph(this.getExternalGraph().getNodes());
            for (Edge edge : this.getExternalGraph().getEdges()) {
                Node y;
                Node x = edge.getNode1();
                if (g1.isAdjacentTo(x, y = edge.getNode2())) continue;
                g1.addUndirectedEdge(x, y);
            }
            G = g1 = GraphUtils.replaceNodes(g1, dataSet.getVariables());
        } else if (this.adjacencyMethod == AdjacencyMethod.NONE) {
            // No adjacency search: only the skew edge threshold can introduce pairs below.
            G = new EdgeListGraph(variables);
        } else {
            throw new IllegalStateException("That method was not configured: " + (Object)((Object)this.adjacencyMethod));
        }
        // Make sure G refers to the standardized data set's node objects.
        G = GraphUtils.replaceNodes(G, dataSet.getVariables());
        TetradLogger.getInstance().forceLogMessage("");
        SearchGraphUtils.pcOrientbk(this.knowledge, G, G.getNodes());
        // The result graph starts with G's nodes and no edges; G is only queried for adjacency below.
        EdgeListGraph graph = new EdgeListGraph(G.getNodes());
        TetradLogger.getInstance().forceLogMessage("X\tY\tMethod\tLR\tEdge");
        int V = variables.size();
        // Pairs that passed the 2-cycle pre-screen; resolved after the main loop.
        ArrayList<NodePair> twoCycles = new ArrayList<NodePair>();
        // Step 3: visit each unordered pair once (i < j).
        for (int i = 0; i < V; ++i) {
            for (int j = i + 1; j < V; ++j) {
                X = variables.get(i);
                Y = variables.get(j);
                double[] x = this.D[i];
                double[] y = this.D[j];
                double cx = Fask.correxp(x, y, x);
                double cy = Fask.correxp(x, y, y);
                // Skip pairs that are neither adjacent in G nor beyond the skew edge threshold.
                if (!G.isAdjacentTo(X, Y) && !(FastMath.abs(cx - cy) > this.skewEdgeThreshold)) continue;
                double lr = lrs[i][j];
                // Knowledge forbids both directions: log it and add no edge.
                if (this.edgeForbiddenByKnowledge(X, Y) && this.edgeForbiddenByKnowledge(Y, X)) {
                    TetradLogger.getInstance().forceLogMessage(X + "\t" + Y + "\tknowledge_forbidden\t" + nf.format(lr) + "\t" + X + "<->" + Y);
                    continue;
                }
                // Knowledge fixes the direction X --> Y.
                if (this.knowledgeOrients(X, Y)) {
                    TetradLogger.getInstance().forceLogMessage(X + "\t" + Y + "\tknowledge\t" + nf.format(lr) + "\t" + X + "-->" + Y);
                    graph.addDirectedEdge(X, Y);
                    continue;
                }
                // Knowledge fixes the direction Y --> X.
                if (this.knowledgeOrients(Y, X)) {
                    TetradLogger.getInstance().forceLogMessage(X + "\t" + Y + "\tknowledge\t" + nf.format(lr) + "\t" + X + "<--" + Y);
                    graph.addDirectedEdge(Y, X);
                    continue;
                }
                // zeroDiff pairs are logged and skipped entirely: no edge, and not a 2-cycle candidate.
                if (this.zeroDiff(i, j, this.D)) {
                    TetradLogger.getInstance().forceLogMessage(X + "\t" + Y + "\t2-cycle Prescreen\t" + nf.format(lr) + "\t" + X + "...TC?..." + Y);
                    System.out.println(X + " " + Y + " lr = " + lr + " zero");
                    continue;
                }
                // Pre-screen: record the pair as a 2-cycle candidate. There is no 'continue' here,
                // so the pair is still oriented by the left-right score below.
                if (this.twoCycleScreeningCutoff > 0.0 && FastMath.abs(this.faskLeftRightV2(x, y)) < this.twoCycleScreeningCutoff) {
                    TetradLogger.getInstance().forceLogMessage(X + "\t" + Y + "\t2-cycle Prescreen\t" + nf.format(lr) + "\t" + X + "...TC?..." + Y);
                    twoCycles.add(new NodePair(X, Y));
                    System.out.println(X + " " + Y + " lr = " + lr + " zero");
                }
                // Positive score orients X --> Y, negative orients Y --> X; zero or NaN adds no edge.
                if (lr > 0.0) {
                    TetradLogger.getInstance().forceLogMessage(X + "\t" + Y + "\tleft-right\t" + nf.format(lr) + "\t" + X + "-->" + Y);
                    graph.addDirectedEdge(X, Y);
                    continue;
                }
                if (!(lr < 0.0)) continue;
                TetradLogger.getInstance().forceLogMessage(Y + "\t" + X + "\tleft-right\t" + nf.format(lr) + "\t" + Y + "-->" + X);
                graph.addDirectedEdge(Y, X);
            }
        }
        // Step 4: resolve the screened pairs into 2-cycles.
        if (this.twoCycleScreeningCutoff > 0.0 && this.orientationAlpha == 0.0) {
            // No testing requested: every screened pair becomes a 2-cycle.
            for (NodePair edge : twoCycles) {
                X = edge.getFirst();
                Y = edge.getSecond();
                graph.removeEdges(X, Y);
                graph.addDirectedEdge(X, Y);
                graph.addDirectedEdge(Y, X);
                this.logTwoCycle(nf, variables, this.D, X, Y, "2-cycle Pre-screen");
            }
        } else if (this.twoCycleScreeningCutoff > 0.0 && this.orientationAlpha > 0.0) {
            // Testing requested: only pairs accepted by twoCycleTest become 2-cycles.
            for (NodePair edge : twoCycles) {
                int j;
                X = edge.getFirst();
                Y = edge.getSecond();
                int i = variables.indexOf(X);
                if (!this.twoCycleTest(i, j = variables.indexOf(Y), this.D, graph, variables)) continue;
                graph.removeEdges(X, Y);
                graph.addDirectedEdge(X, Y);
                graph.addDirectedEdge(Y, X);
                this.logTwoCycle(nf, variables, this.D, X, Y, "2-cycle Screened then Tested");
            }
        }
        long stop = MillisecondTimes.timeMillis();
        this.elapsed = stop - start;
        this.graph = graph;
        return graph;
    }

    /**
     * Writes one tab-separated log line announcing a 2-cycle between X and Y.
     *
     * @param nf        formatter for the left-right score
     * @param variables variable list used to locate the rows of X and Y in {@code d}
     * @param d         data, one row per variable
     * @param X         first node of the 2-cycle
     * @param Y         second node of the 2-cycle
     * @param type      label written in the "method" column of the log line
     */
    private void logTwoCycle(NumberFormat nf, List<Node> variables, double[][] d, Node X, Node Y, String type) {
        double[] xData = d[variables.indexOf(X)];
        double[] yData = d[variables.indexOf(Y)];
        String score = nf.format(this.leftRight(xData, yData));
        String line = X + "\t" + Y + "\t" + type + "\t" + score + "\t" + X + "<=>" + Y;
        TetradLogger.getInstance().forceLogMessage(line);
    }

    /**
     * Returns a coefficient matrix B for the searched graph, where {@code B[i][j]} is the
     * regression coefficient of variable i in the regression of variable j on its parents.
     * Entries for non-parent pairs stay 0. Runs {@link #search()} first if no graph is
     * available yet.
     *
     * @return a square matrix indexed by the data set's variable order
     */
    public double[][] getB() {
        if (this.graph == null) {
            this.search();
        }

        List<Node> vars = this.dataSet.getVariables();
        double[][] coefficients = new double[vars.size()][vars.size()];

        for (int child = 0; child < vars.size(); child++) {
            Node target = vars.get(child);
            List<Node> parents = this.graph.getParents(target);
            double[] coef = this.regressionDataset.regress(target, parents).getCoef();

            // coef[0] is skipped; parent p uses coef[p + 1].
            for (int p = 0; p < parents.size(); p++) {
                int parentIndex = vars.indexOf(parents.get(p));
                coefficients[parentIndex][child] = coef[p + 1];
            }
        }
        return coefficients;
    }

    /**
     * Computes the left-right score for every ordered pair of variables (diagonal included)
     * using the currently configured rule, on a standardized copy of the data.
     * <p>
     * As a side effect, stores the standardized, transposed data in {@code D} once all
     * scores have been computed.
     *
     * @return a square matrix whose entry [i][j] is the left-right score of (variable i, variable j)
     */
    public double[][] getLrScores() {
        List<Node> vars = this.dataSet.getVariables();
        double[][] standardized = DataUtils.standardizeData(this.dataSet).getDoubleData().transpose().toArray();
        double[][] scores = new double[vars.size()][vars.size()];

        for (int row = 0; row < vars.size(); row++) {
            for (int col = 0; col < vars.size(); col++) {
                double[] first = standardized[row];
                double[] second = standardized[col];
                scores[row][col] = this.leftRight(first, second);
            }
        }

        this.D = standardized;
        return scores;
    }

    /**
     * @return the depth passed to the subset generator in the 2-cycle test (-1 by default)
     */
    public int getDepth() {
        return depth;
    }

    /**
     * Sets the depth passed to the subset generator in the 2-cycle test.
     *
     * @param depth the new depth; stored without validation
     */
    public void setDepth(int depth) {
        final int newDepth = depth;
        this.depth = newDepth;
    }

    /**
     * @return the duration of the most recent {@code search()} call in milliseconds
     *         (0 if no search has completed)
     */
    public long getElapsedTime() {
        return elapsed;
    }

    /**
     * @return the knowledge object currently held by this search (not a copy)
     */
    public Knowledge getKnowledge() {
        return knowledge;
    }

    /**
     * Replaces the background knowledge with a copy of the given knowledge.
     *
     * @param knowledge the knowledge to copy
     */
    public void setKnowledge(Knowledge knowledge) {
        Knowledge copy = new Knowledge(knowledge);
        this.knowledge = copy;
    }

    /**
     * @return the graph used for adjacencies under EXTERNAL_GRAPH, or null if none was set
     */
    public Graph getExternalGraph() {
        return externalGraph;
    }

    /**
     * Sets the graph whose adjacencies are used when the adjacency method is EXTERNAL_GRAPH.
     *
     * @param externalGraph the graph to use; stored by reference
     */
    public void setExternalGraph(Graph externalGraph) {
        final Graph supplied = externalGraph;
        this.externalGraph = supplied;
    }

    /**
     * Sets the threshold on |correxp(x,y,x) - correxp(x,y,y)| above which a non-adjacent
     * pair is still considered for orientation in {@code search()}.
     *
     * @param skewEdgeThreshold the new threshold; stored without validation
     */
    public void setSkewEdgeThreshold(double skewEdgeThreshold) {
        final double threshold = skewEdgeThreshold;
        this.skewEdgeThreshold = threshold;
    }

    /**
     * Sets the cutoff used by the 2-cycle pre-screen.
     *
     * @param twoCycleScreeningCutoff the new cutoff; must not be negative
     * @throws IllegalStateException if the cutoff is negative
     */
    public void setTwoCycleScreeningCutoff(double twoCycleScreeningCutoff) {
        boolean negative = twoCycleScreeningCutoff < 0.0;
        if (negative) {
            throw new IllegalStateException("Two cycle screening threshold must be >= 0");
        }
        this.twoCycleScreeningCutoff = twoCycleScreeningCutoff;
    }

    /**
     * Sets the alpha used for the 2-cycle test and recomputes the matching z cutoff.
     *
     * @param orientationAlpha the new alpha, in [0, 1]
     * @throws IllegalArgumentException if the value lies outside [0, 1]
     */
    public void setOrientationAlpha(double orientationAlpha) {
        boolean tooSmall = orientationAlpha < 0.0;
        if (tooSmall || orientationAlpha > 1.0) {
            throw new IllegalArgumentException("Two cycle testing alpha should be in [0, 1].");
        }
        // Compute the cutoff first so neither field changes if the conversion fails.
        double cutoff = StatUtils.getZForAlpha(orientationAlpha);
        this.orientationCutoff = cutoff;
        this.orientationAlpha = orientationAlpha;
    }

    /**
     * Selects the left-right rule used to score pairs.
     *
     * @param leftRight the rule to use; stored without validation
     */
    public void setLeftRight(LeftRight leftRight) {
        final LeftRight rule = leftRight;
        this.leftRight = rule;
    }

    /**
     * Selects how {@code search()} obtains its initial adjacencies.
     *
     * @param adjacencyMethod the method to use; stored without validation
     */
    public void setAdjacencyMethod(AdjacencyMethod adjacencyMethod) {
        final AdjacencyMethod method = adjacencyMethod;
        this.adjacencyMethod = method;
    }

    /**
     * Sets the correlation threshold below which the FASK1/FASK2 left-right score is negated.
     *
     * @param delta the new threshold; stored without validation
     */
    public void setDelta(double delta) {
        final double threshold = delta;
        this.delta = threshold;
    }

    /**
     * Enables or disables the skewness-sign adjustment applied by the left-right rules.
     *
     * @param empirical true to apply the adjustment
     */
    public void setEmpirical(boolean empirical) {
        final boolean enabled = empirical;
        this.empirical = enabled;
    }

    /**
     * Returns the left-right score for the pair (X, Y) under the configured rule, locating
     * both variables in the data set by name.
     * <p>
     * If the standardized data matrix has not been built yet (neither {@code search()} nor
     * {@code getLrScores()} has run), it is built here on first use instead of failing with
     * a NullPointerException.
     *
     * @param X the first variable
     * @param Y the second variable
     * @return the left-right score of (X, Y)
     * @throws IllegalArgumentException if X or Y does not name a variable of the data set
     */
    public double leftRight(Node X, Node Y) {
        List<Node> variables = this.dataSet.getVariables();
        int i = Fask.lastIndexOfName(variables, X.getName());
        if (i < 0) {
            throw new IllegalArgumentException("Variable not found in the data set: " + X.getName());
        }
        int j = Fask.lastIndexOfName(variables, Y.getName());
        if (j < 0) {
            throw new IllegalArgumentException("Variable not found in the data set: " + Y.getName());
        }
        if (this.D == null) {
            // Same construction as getLrScores(): standardized data, one row per variable.
            this.D = DataUtils.standardizeData(this.dataSet).getDoubleData().transpose().toArray();
        }
        return this.leftRight(this.D[i], this.D[j]);
    }

    /** Returns the index of the last variable whose name equals {@code name}, or -1 if there is none. */
    private static int lastIndexOfName(List<Node> variables, String name) {
        int index = -1;
        for (int k = 0; k < variables.size(); ++k) {
            if (name.equals(variables.get(k).getName())) {
                index = k;
            }
        }
        return index;
    }

    /**
     * Dispatches to the left-right rule selected by {@code setLeftRight}.
     *
     * @param x data of the first variable
     * @param y data of the second variable
     * @return the score produced by the selected rule
     * @throws IllegalStateException if no known rule is configured (including null)
     */
    private double leftRight(double[] x, double[] y) {
        LeftRight rule = this.leftRight;
        // Guard against null so the switch cannot throw a NullPointerException.
        if (rule != null) {
            switch (rule) {
                case FASK1:
                    return this.faskLeftRightV1(x, y);
                case FASK2:
                    return this.faskLeftRightV2(x, y);
                case RSKEW:
                    return this.robustSkew(x, y);
                case SKEW:
                    return this.skew(x, y);
                case TANH:
                    return this.tanh(x, y);
                default:
                    break;
            }
        }
        throw new IllegalStateException("Left right rule not configured: " + rule);
    }

    /**
     * FASK2 left-right rule: the difference correxp(x,y,x) - correxp(x,y,y), multiplied by
     * the product of the skewness signs when the empirical adjustment is on, and negated
     * when the correlation of x and y is below {@code delta}.
     *
     * @param x data of the first variable
     * @param y data of the second variable
     * @return the signed left-right score
     */
    private double faskLeftRightV2(double[] x, double[] y) {
        double skewX = StatUtils.skewness(x);
        double skewY = StatUtils.skewness(y);
        double corr = StatUtils.correlation(x, y);

        double score = Fask.correxp(x, y, x) - Fask.correxp(x, y, y);
        if (this.empirical) {
            score = score * (FastMath.signum(skewX) * FastMath.signum(skewY));
        }
        return corr < this.delta ? score * -1.0 : score;
    }

    /**
     * FASK1 left-right rule: the difference between the normalised mean product of x and y
     * on rows with x &gt; 0 and on rows with y &gt; 0, multiplied by the sign of the
     * (optionally skewness-adjusted) correlation, and negated when that correlation is
     * below {@code delta}.
     *
     * @param x data of the first variable
     * @param y data of the second variable
     * @return the signed left-right score
     */
    private double faskLeftRightV1(double[] x, double[] y) {
        double givenX = Fask.cu(x, y, x) / FastMath.sqrt(Fask.cu(x, x, x) * Fask.cu(y, y, x));
        double givenY = Fask.cu(x, y, y) / FastMath.sqrt(Fask.cu(x, x, y) * Fask.cu(y, y, y));
        double score = givenX - givenY;

        double corr = StatUtils.correlation(x, y);
        double skewX = StatUtils.skewness(x);
        double skewY = StatUtils.skewness(y);
        if (this.empirical) {
            corr = corr * (FastMath.signum(skewX) * FastMath.signum(skewY));
        }

        score = score * FastMath.signum(corr);
        return corr < this.delta ? score * -1.0 : score;
    }

    /**
     * RSKEW left-right rule: correlation(x, y) times the mean of g(x)*y - x*g(y), where g is
     * {@link #g(double)}. With the empirical adjustment on, each series is first multiplied
     * by the sign of its own skewness.
     *
     * @param x data of the first variable
     * @param y data of the second variable
     * @return the left-right score
     */
    private double robustSkew(double[] x, double[] y) {
        double[] xs = this.empirical ? this.correctSkewness(x, StatUtils.skewness(x)) : x;
        double[] ys = this.empirical ? this.correctSkewness(y, StatUtils.skewness(y)) : y;

        double[] terms = new double[xs.length];
        for (int row = 0; row < xs.length; row++) {
            terms[row] = this.g(xs[row]) * ys[row] - xs[row] * this.g(ys[row]);
        }
        return StatUtils.correlation(xs, ys) * StatUtils.mean(terms);
    }

    /**
     * SKEW left-right rule: correlation(x, y) times the mean of x*x*y - x*y*y. With the
     * empirical adjustment on, each series is first multiplied by the sign of its own
     * skewness.
     *
     * @param x data of the first variable
     * @param y data of the second variable
     * @return the left-right score
     */
    private double skew(double[] x, double[] y) {
        double[] xs = this.empirical ? this.correctSkewness(x, StatUtils.skewness(x)) : x;
        double[] ys = this.empirical ? this.correctSkewness(y, StatUtils.skewness(y)) : y;

        double[] terms = new double[xs.length];
        for (int row = 0; row < xs.length; row++) {
            terms[row] = xs[row] * xs[row] * ys[row] - xs[row] * ys[row] * ys[row];
        }
        return StatUtils.correlation(xs, ys) * StatUtils.mean(terms);
    }

    /**
     * TANH left-right rule: correlation(x, y) times the mean of x*tanh(y) - tanh(x)*y. With
     * the empirical adjustment on, each series is first multiplied by the sign of its own
     * skewness.
     *
     * @param x data of the first variable
     * @param y data of the second variable
     * @return the left-right score
     */
    private double tanh(double[] x, double[] y) {
        double[] xs = this.empirical ? this.correctSkewness(x, StatUtils.skewness(x)) : x;
        double[] ys = this.empirical ? this.correctSkewness(y, StatUtils.skewness(y)) : y;

        double[] terms = new double[xs.length];
        for (int row = 0; row < xs.length; row++) {
            terms[row] = xs[row] * FastMath.tanh(ys[row]) - FastMath.tanh(xs[row]) * ys[row];
        }
        return StatUtils.correlation(xs, ys) * StatUtils.mean(terms);
    }

    /**
     * Nonlinearity used by the RSKEW rule: log(cosh(max(x, 0))).
     *
     * @param x the input value
     * @return log(cosh(x)) for positive x, and log(cosh(0)) otherwise
     */
    private double g(double x) {
        double clipped = FastMath.max(x, 0.0);
        return FastMath.log(FastMath.cosh(clipped));
    }

    /**
     * Tells whether knowledge fixes the orientation X --> Y, i.e. Y --> X is forbidden or
     * X --> Y is required.
     *
     * @param X the candidate tail node
     * @param Y the candidate head node
     * @return true if knowledge orients the edge as X --> Y
     */
    private boolean knowledgeOrients(Node X, Node Y) {
        if (this.knowledge.isForbidden(Y.getName(), X.getName())) {
            return true;
        }
        return this.knowledge.isRequired(X.getName(), Y.getName());
    }

    /**
     * Tells whether knowledge forbids the edge between X and Y in both directions.
     *
     * @param X one endpoint
     * @param Y the other endpoint
     * @return true only if both Y --> X and X --> Y are forbidden
     */
    private boolean edgeForbiddenByKnowledge(Node X, Node Y) {
        if (!this.knowledge.isForbidden(Y.getName(), X.getName())) {
            return false;
        }
        return this.knowledge.isForbidden(X.getName(), Y.getName());
    }

    /**
     * Returns a copy of {@code data} with every value multiplied by the sign of {@code sk},
     * so that a negatively skewed series is flipped. The input array is not modified.
     * <p>
     * Note that {@code signum(0) == 0} and {@code signum(NaN)} is NaN, so a skewness of
     * exactly zero yields an all-zero copy and a NaN skewness yields an all-NaN copy.
     *
     * @param data the series to adjust
     * @param sk   the skewness of the series (only its sign is used)
     * @return a new array of the same length as {@code data}
     */
    private double[] correctSkewness(double[] data, double sk) {
        // The sign does not depend on the element, so compute it once outside the loop.
        final double sign = FastMath.signum(sk);
        double[] data2 = new double[data.length];
        for (int i = 0; i < data.length; ++i) {
            data2[i] = data[i] * sign;
        }
        return data2;
    }

    /**
     * Tests whether the pair (V[i], V[j]) is consistent with a 2-cycle.
     * <p>
     * For every subset Z (generated with the configured depth) of the nodes adjacent to X
     * or Y in {@code G0}, three partial correlations of x and y given Z are compared after a
     * Fisher z-transform: one with threshold negative infinity, and two with threshold 0
     * conditioned on x and on y respectively. The method returns false as soon as one
     * subset fails the 2-cycle pattern, and true if no evaluated subset fails (including
     * the case where every subset was skipped because of a singular matrix).
     *
     * @param i  index of X in {@code V} and row of X in {@code D}
     * @param j  index of Y in {@code V} and row of Y in {@code D}
     * @param D  data, one row per variable
     * @param G0 graph supplying the adjacencies of X and Y
     * @param V  the variable list
     * @return true if the pair is judged a possible 2-cycle for all evaluated subsets
     */
    private boolean twoCycleTest(int i, int j, double[][] D, Graph G0, List<Node> V) {
        int[] choice;
        Node X = V.get(i);
        Node Y = V.get(j);
        double[] x = D[i];
        double[] y = D[j];
        // Candidate conditioning nodes: everything adjacent to X or Y, except X and Y themselves.
        HashSet<Node> adjSet = new HashSet<Node>(G0.getAdjacentNodes(X));
        adjSet.addAll(G0.getAdjacentNodes(Y));
        ArrayList<Node> adj = new ArrayList<Node>(adjSet);
        adj.remove(X);
        adj.remove(Y);
        // NOTE(review): depth defaults to -1; presumably SublistGenerator treats that as "no limit" - confirm.
        SublistGenerator gen = new SublistGenerator(adj.size(), FastMath.min(this.depth, adj.size()));
        while ((choice = gen.next()) != null) {
            double pc2;
            double pc1;
            double pc;
            List<Node> _adj = GraphUtils.asList(choice, adj);
            // Collect the data rows of the chosen conditioning nodes.
            double[][] _Z = new double[_adj.size()][];
            for (int f = 0; f < _adj.size(); ++f) {
                Node _z = _adj.get(f);
                int column = this.dataSet.getColumn(_z);
                _Z[f] = D[column];
            }
            try {
                // pc: threshold -infinity (presumably all rows); pc1/pc2: threshold 0 on x and on y.
                pc = this.partialCorrelation(x, y, _Z, x, Double.NEGATIVE_INFINITY);
                pc1 = this.partialCorrelation(x, y, _Z, x, 0.0);
                pc2 = this.partialCorrelation(x, y, _Z, y, 0.0);
            }
            catch (SingularMatrixException e) {
                // A singular covariance matrix makes this subset unusable: report it and try the next one.
                System.out.println("Singularity X = " + X + " Y = " + Y + " adj = " + adj);
                TetradLogger.getInstance().log("info", "Singularity X = " + X + " Y = " + Y + " adj = " + adj);
                continue;
            }
            // Row counts used as sample sizes in the z-statistics below.
            int nc = StatUtils.getRows(x, x, 0.0, Double.NEGATIVE_INFINITY).size();
            int nc1 = StatUtils.getRows(x, x, 0.0, 1.0).size();
            int nc2 = StatUtils.getRows(y, y, 0.0, 1.0).size();
            // Fisher z-transform of each partial correlation.
            double z = 0.5 * (FastMath.log(1.0 + pc) - FastMath.log(1.0 - pc));
            double z1 = 0.5 * (FastMath.log(1.0 + pc1) - FastMath.log(1.0 - pc1));
            double z2 = 0.5 * (FastMath.log(1.0 + pc2) - FastMath.log(1.0 - pc2));
            // Standardised differences between the first and each conditional estimate.
            double zv1 = (z - z1) / FastMath.sqrt(1.0 / ((double)nc - 3.0) + 1.0 / ((double)nc1 - 3.0));
            double zv2 = (z - z2) / FastMath.sqrt(1.0 / ((double)nc - 3.0) + 1.0 / ((double)nc2 - 3.0));
            boolean rejected1 = FastMath.abs(zv1) > this.orientationCutoff;
            boolean rejected2 = FastMath.abs(zv2) > this.orientationCutoff;
            // A 2-cycle stays possible if the differences have opposite signs and the relevant one
            // is significant, or if both differences are significant.
            boolean possibleTwoCycle = false;
            if (zv1 < 0.0 && zv2 > 0.0 && rejected1) {
                possibleTwoCycle = true;
            } else if (zv1 > 0.0 && zv2 < 0.0 && rejected2) {
                possibleTwoCycle = true;
            } else if (rejected1 && rejected2) {
                possibleTwoCycle = true;
            }
            if (possibleTwoCycle) continue;
            // One subset that rules out the 2-cycle is enough to reject it.
            return false;
        }
        return true;
    }

    /**
     * Tests whether the correlation of x and y conditioned on x (threshold 0) and the one
     * conditioned on y (threshold 0) differ by no more than the 2-cycle screening cutoff,
     * measured as a z-statistic of the Fisher-transformed values.
     *
     * @param i row of the first variable in {@code D}
     * @param j row of the second variable in {@code D}
     * @param D data, one row per variable
     * @return true if the absolute z-statistic is at most the screening cutoff
     * @throws RuntimeException wrapping a SingularMatrixException raised while computing
     *                          either correlation
     */
    private boolean zeroDiff(int i, int j, double[][] D) {
        double[] first = D[i];
        double[] second = D[j];
        double[][] noConditioning = new double[0][];

        double corrGivenFirst;
        double corrGivenSecond;
        try {
            corrGivenFirst = this.partialCorrelation(first, second, noConditioning, first, 0.0);
            corrGivenSecond = this.partialCorrelation(first, second, new double[0][], second, 0.0);
        }
        catch (SingularMatrixException e) {
            throw new RuntimeException(e);
        }

        int rowsFirst = StatUtils.getRows(first, first, 0.0, 1.0).size();
        int rowsSecond = StatUtils.getRows(second, second, 0.0, 1.0).size();

        double zFirst = Fask.fisherZ(corrGivenFirst);
        double zSecond = Fask.fisherZ(corrGivenSecond);
        double stdErr = FastMath.sqrt(1.0 / ((double) rowsFirst - 3.0) + 1.0 / ((double) rowsSecond - 3.0));
        double statistic = (zFirst - zSecond) / stdErr;

        return FastMath.abs(statistic) <= this.twoCycleScreeningCutoff;
    }

    /** Fisher z-transform of a correlation: 0.5 * (log(1 + r) - log(1 - r)). */
    private static double fisherZ(double r) {
        return 0.5 * (FastMath.log(1.0 + r) - FastMath.log(1.0 - r));
    }

    /**
     * Computes the partial correlation of x and y given z from the covariance matrix that
     * {@code StatUtils.covMatrix} builds for the given condition series and threshold
     * (the last argument of that call is fixed at 1.0).
     *
     * @param x         first series
     * @param y         second series
     * @param z         conditioning series, one array per variable (may be empty)
     * @param condition series used to select rows
     * @param threshold threshold applied to {@code condition}
     * @return the partial correlation
     * @throws SingularMatrixException if the underlying computation meets a singular matrix
     */
    private double partialCorrelation(double[] x, double[] y, double[][] z, double[] condition, double threshold) throws SingularMatrixException {
        Matrix transposedCov = new Matrix(StatUtils.covMatrix(x, y, z, condition, threshold, 1.0)).transpose();
        return StatUtils.partialCorrelation(transposedCov);
    }

    /**
     * How {@code search()} obtains the adjacencies it then orients.
     */
    public enum AdjacencyMethod {
        /** Run Fas (stable variant) with the supplied independence test. */
        FAS_STABLE,
        /** Run Fges with the supplied score. */
        FGES,
        /** Take the adjacencies of the graph given to {@code setExternalGraph}. */
        EXTERNAL_GRAPH,
        /** Start from an empty graph over the variables. */
        NONE
    }

    /**
     * The left-right rule used to score the direction of an edge.
     */
    public enum LeftRight {
        /** Rule implemented by {@code faskLeftRightV1}. */
        FASK1,
        /** Rule implemented by {@code faskLeftRightV2}. */
        FASK2,
        /** Rule implemented by {@code robustSkew}. */
        RSKEW,
        /** Rule implemented by {@code skew}. */
        SKEW,
        /** Rule implemented by {@code tanh}. */
        TANH
    }
}

