/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Fas;
import edu.cmu.tetrad.search.IndTestScore;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.search.SemBicScoreMultiFas;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.SublistGenerator;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

/**
 * A multi-data-set variant of FASK-style causal search.
 *
 * <p>Adjacencies are learned with {@link Fas} using an {@link IndTestScore} built on a
 * {@link SemBicScoreMultiFas}. A pair that FAS leaves non-adjacent is still considered when the
 * absolute difference of its summed conditional covariances, divided by the number of data sets,
 * exceeds {@link #SKEW_ADJACENCY_THRESHOLD}. Each considered pair is then oriented, in priority
 * order, by background knowledge, a two-cycle test (majority vote over data sets) and a
 * left-right rule (sum of per-data-set scores).
 *
 * <p>All data sets are assumed to share the variable (column) order of the first data set.
 */
public class MultiFaskV1 {
    /** Per-data-set mean |c1 - c2| above which a pair is kept even if FAS found it non-adjacent. */
    private static final double SKEW_ADJACENCY_THRESHOLD = 0.3;

    private final SemBicScoreMultiFas score;
    /** Stored and returned by its accessors; not read by {@link #search()}. */
    private Graph externalGraph;
    /** Passed to FAS and to the conditioning-set generator; fixed at -1 here. */
    private final int depth = -1;
    /** Stored and returned by its accessors; not read by {@link #search()}. */
    private double penaltyDiscount = 1.0;
    private double alpha = 1.0E-6;
    private Knowledge knowledge = new Knowledge();
    /** Two-sided z cutoff derived from {@link #alpha} at the start of each search. */
    private double cutoff;
    /** Threshold on the skew-adjusted correlation below which the left-right score is flipped. */
    private double delta = -0.2;
    private List<DataSet> dataSets;
    /** data[k][j] is column j of data set k; filled from the standardized data by search(). */
    private final double[][][] data;

    /**
     * Creates the search.
     *
     * @param dataSets the data sets; all are assumed to share the first set's column order
     * @param score    the pooled score used by the FAS adjacency phase
     */
    public MultiFaskV1(List<DataSet> dataSets, SemBicScoreMultiFas score) {
        this.dataSets = dataSets;
        this.score = score;
        // Columns are filled in search(), after the data have been standardized.
        this.data = new double[dataSets.size()][][];
    }

    /**
     * Runs the search: standardizes the data, finds adjacencies with FAS, then orients each
     * considered pair.
     *
     * @return a graph over the first data set's variables; a two-cycle is represented by two
     *         opposing directed edges
     * @throws IllegalArgumentException if alpha lies outside [0, 1]
     */
    public Graph search() {
        this.setCutoff(this.alpha);

        List<DataSet> standardSets = new ArrayList<>();
        for (DataSet set : this.dataSets) {
            standardSets.add(DataUtils.standardizeData(set));
        }
        this.dataSets = standardSets;

        // The orientation statistics condition on "value > 0", which assumes centered data, so the
        // column cache must come from the standardized sets, not the raw constructor input.
        for (int k = 0; k < this.dataSets.size(); ++k) {
            this.data[k] = this.dataSets.get(k).getDoubleData().transpose().toArray();
        }

        List<Node> variables = this.dataSets.get(0).getVariables();
        IndTestScore test = new IndTestScore(this.score);
        System.out.println("FAS");
        Fas fas = new Fas(test);
        fas.setStable(true);
        fas.setDepth(this.getDepth());
        fas.setVerbose(false);
        fas.setKnowledge(this.knowledge);
        Graph G0 = fas.search();
        SearchGraphUtils.pcOrientbk(this.knowledge, G0, G0.getNodes());
        G0 = GraphUtils.replaceNodes(G0, this.dataSets.get(0).getVariables());

        System.out.println("Orientation");
        int numSets = this.dataSets.size();
        EdgeListGraph graph = new EdgeListGraph(variables);
        for (int i = 0; i < variables.size(); ++i) {
            for (int j = i + 1; j < variables.size(); ++j) {
                Node X = variables.get(i);
                Node Y = variables.get(j);
                double[][] _x = new double[numSets][];
                double[][] _y = new double[numSets][];
                double c1 = 0.0;
                double c2 = 0.0;
                for (int k = 0; k < numSets; ++k) {
                    double[] x = this.data[k][i];
                    double[] y = this.data[k][j];
                    _x[k] = x;
                    _y[k] = y;
                    c1 += StatUtils.cov(x, y, x, 0.0, 1.0)[1];
                    c2 += StatUtils.cov(x, y, y, 0.0, 1.0)[1];
                }

                boolean skewAdjacent =
                        FastMath.abs(c1 - c2) / (double) numSets > SKEW_ADJACENCY_THRESHOLD;
                if (!G0.isAdjacentTo(X, Y) && !skewAdjacent) {
                    continue;
                }

                if (this.knowledgeOrients(X, Y)) {
                    graph.addDirectedEdge(X, Y);
                } else if (this.knowledgeOrients(Y, X)) {
                    graph.addDirectedEdge(Y, X);
                } else if (this.bidirected(_x, _y, G0, X, Y)) {
                    graph.addEdge(Edges.directedEdge(X, Y));
                    graph.addEdge(Edges.directedEdge(Y, X));
                } else if (this.leftright(_x, _y)) {
                    graph.addDirectedEdge(X, Y);
                } else {
                    graph.addDirectedEdge(Y, X);
                }
            }
        }
        System.out.println();
        System.out.println("Done");
        return graph;
    }

    /** Returns the stored penalty discount (not read by {@link #search()}). */
    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    /** Stores a penalty discount (not read by {@link #search()}). */
    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
    }

    /** Sets the significance level used to derive the two-cycle z cutoff; checked in search(). */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** Returns the background knowledge. */
    public Knowledge getKnowledge() {
        return this.knowledge;
    }

    /** Sets the background knowledge used by FAS and by orientation. */
    public void setKnowledge(Knowledge knowledge) {
        this.knowledge = knowledge;
    }

    /** Returns the depth passed to FAS and the conditioning-set generator (always -1 here). */
    public int getDepth() {
        return this.depth;
    }

    /**
     * Majority vote over data sets on whether X and Y form a two-cycle.
     *
     * @param x  x[k] is X's column in data set k
     * @param y  y[k] is Y's column in data set k
     * @param G0 the FAS adjacency graph; adjacents of X or Y form the candidate conditioning set
     * @return true when strictly more data sets vote for a two-cycle than against
     */
    private boolean bidirected(double[][] x, double[][] y, Graph G0, Node X, Node Y) {
        HashSet<Node> adjSet = new HashSet<>(G0.getAdjacentNodes(X));
        adjSet.addAll(G0.getAdjacentNodes(Y));
        List<Node> adj = new ArrayList<>(adjSet);
        adj.remove(X);
        adj.remove(Y);

        int trueCounter = 0;
        int falseCounter = 0;
        for (int i = 0; i < this.dataSets.size(); ++i) {
            if (this.twoCycleInDataSet(i, x[i], y[i], adj)) {
                ++trueCounter;
            } else {
                ++falseCounter;
            }
        }
        return trueCounter > falseCounter;
    }

    /**
     * Data set i's vote: true only if every conditioning subset produced by the generator shows
     * two-cycle evidence; stops at the first subset that does not. False if no subset is produced.
     */
    private boolean twoCycleInDataSet(int i, double[] x, double[] y, List<Node> adj) {
        // Row counts do not depend on the conditioning set; compute the standard errors once.
        int nc = StatUtils.getRows(x, Double.NEGATIVE_INFINITY, 1.0).size();
        int nc1 = StatUtils.getRows(x, 0.0, 1.0).size();
        int nc2 = StatUtils.getRows(y, 0.0, 1.0).size();
        double se1 = FastMath.sqrt(1.0 / ((double) nc - 3.0) + 1.0 / ((double) nc1 - 3.0));
        double se2 = FastMath.sqrt(1.0 / ((double) nc - 3.0) + 1.0 / ((double) nc2 - 3.0));

        SublistGenerator gen = new SublistGenerator(adj.size(), this.depth);
        boolean possibleTwoCycle = false;
        int[] choice;
        while ((choice = gen.next()) != null) {
            List<Node> cond = GraphUtils.asList(choice, adj);
            double[][] z = new double[cond.size()][];
            for (int f = 0; f < cond.size(); ++f) {
                // Column indices come from the first data set, as for every other column lookup.
                z[f] = this.data[i][this.dataSets.get(0).getColumn(cond.get(f))];
            }
            // Reassigned (not OR-ed) so a later subset without evidence overturns an earlier one.
            possibleTwoCycle = this.twoCycleEvidence(x, y, z, se1, se2);
            if (!possibleTwoCycle) {
                break;
            }
        }
        return possibleTwoCycle;
    }

    /**
     * Compares the full-sample partial correlation of x and y given z with the partial
     * correlations restricted to x > 0 and to y > 0, via Fisher z-differences.
     */
    private boolean twoCycleEvidence(double[] x, double[] y, double[][] z, double se1, double se2) {
        double pc = this.partialCorrelation(x, y, z, x, Double.NEGATIVE_INFINITY);
        double pc1 = this.partialCorrelation(x, y, z, x, 0.0);
        double pc2 = this.partialCorrelation(x, y, z, y, 0.0);
        double zv1 = (fisherZ(pc) - fisherZ(pc1)) / se1;
        double zv2 = (fisherZ(pc) - fisherZ(pc2)) / se2;
        boolean rejected1 = FastMath.abs(zv1) > this.cutoff;
        boolean rejected2 = FastMath.abs(zv2) > this.cutoff;
        return (zv1 < 0.0 && zv2 > 0.0 && rejected1)
                || (zv1 > 0.0 && zv2 < 0.0 && rejected2)
                || (rejected1 && rejected2);
    }

    /** Fisher z-transform 0.5 * (ln(1 + r) - ln(1 - r)). */
    private static double fisherZ(double r) {
        return 0.5 * (FastMath.log(1.0 + r) - FastMath.log(1.0 - r));
    }

    /**
     * Sums skew-adjusted left-right scores over data sets.
     *
     * @return true to orient X -> Y, false to orient Y -> X
     */
    private boolean leftright(double[][] x, double[][] y) {
        double lrSum = 0.0;
        for (int i = 0; i < this.dataSets.size(); ++i) {
            double[] xi = x[i];
            double[] yi = y[i];
            double left = cu(xi, yi, xi) / FastMath.sqrt(cu(xi, xi, xi) * cu(yi, yi, xi));
            double right = cu(xi, yi, yi) / FastMath.sqrt(cu(xi, xi, yi) * cu(yi, yi, yi));
            double lr = left - right;

            // Correct the correlation's sign by the signs of both skewnesses.
            double r = StatUtils.correlation(xi, yi);
            r *= FastMath.signum(StatUtils.skewness(xi)) * FastMath.signum(StatUtils.skewness(yi));
            lr *= FastMath.signum(r);
            if (r < this.getDelta()) {
                lr *= -1.0;
            }
            lrSum += lr;
        }
        return lrSum > 0.0;
    }

    /**
     * Mean of x[k] * y[k] over rows where condition[k] > 0. Returns NaN when no row qualifies
     * (0.0 / 0).
     */
    private static double cu(double[] x, double[] y, double[] condition) {
        double exy = 0.0;
        int n = 0;
        for (int k = 0; k < x.length; ++k) {
            if (!(condition[k] > 0.0)) {
                continue;
            }
            exy += x[k] * y[k];
            ++n;
        }
        return exy / (double) n;
    }

    /**
     * Partial correlation of x and y given the columns of z, over rows selected by comparing
     * condition with threshold (row selection is delegated to StatUtils.covMatrix).
     */
    private double partialCorrelation(double[] x, double[] y, double[][] z, double[] condition, double threshold) throws SingularMatrixException {
        double[][] cv = StatUtils.covMatrix(x, y, z, condition, threshold, 1.0);
        Matrix m = new Matrix(cv).transpose();
        return StatUtils.partialCorrelation(m);
    }

    /**
     * Derives the z cutoff from a significance level.
     *
     * @throws IllegalArgumentException if alpha is outside [0, 1]
     */
    private void setCutoff(double alpha) {
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("Significance out of range: " + alpha);
        }
        this.cutoff = StatUtils.getZForAlpha(alpha);
    }

    /** True if knowledge forbids right -> left or requires left -> right. */
    private boolean knowledgeOrients(Node left, Node right) {
        return this.knowledge.isForbidden(right.getName(), left.getName()) || this.knowledge.isRequired(left.getName(), right.getName());
    }

    /** Returns the stored external graph (not read by {@link #search()}). */
    public Graph getExternalGraph() {
        return this.externalGraph;
    }

    /** Stores an external graph (not read by {@link #search()}). */
    public void setExternalGraph(Graph externalGraph) {
        this.externalGraph = externalGraph;
    }

    /** Returns the left-right flip threshold. */
    public double getDelta() {
        return this.delta;
    }

    /** Sets the left-right flip threshold. */
    public void setDelta(double delta) {
        this.delta = delta;
    }
}

