/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.list.DoubleArrayList;
import cern.jet.stat.Descriptive;
import edu.cmu.tetrad.data.AndersonDarlingTest;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.MeekRules;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.TreeMap;
import org.apache.commons.math3.util.FastMath;

public class Lofs {
    /** Input graph supplied to the constructor; {@code orient()} starts from its undirected version. */
    private final Graph CPDAG;
    /** Data sets supplied to the constructor; residuals from all of them are pooled when scoring. */
    private final List<DataSet> dataSets;
    /** Threshold compared against the Anderson-Darling p-value in {@code normal}; defaults to 0.05. */
    private double alpha = 0.05;
    /** One regression per data set, in the same order as {@link #dataSets}. */
    private final ArrayList<Regression> regressions;
    /** Variables of the first data set; nodes are looked up in this list by name before regressing. */
    private final List<Node> variables;
    /** Passed to {@code resolveOneEdgeMax} to select its "strong" branch. */
    private final boolean strongR2;
    /** Whether {@code orient()} finishes by applying {@link MeekRules}. */
    private final boolean meekDone;
    /** Whether residuals are mean-centered per data set in the Anderson-Darling score and p-value. */
    private final boolean meanCenterResiduals;
    /** Statistic used by {@code score(Node, List)}; defaults to Anderson-Darling. */
    private Score score = Score.andersonDarling;

    /**
     * Builds an orienter for the given graph and data.
     *
     * <p>One {@link RegressionDataset} is created per data set, and the variable list is taken
     * from the first data set.
     *
     * @param CPDAG               graph whose adjacencies are to be oriented; must not be null
     * @param dataSets            data sets to regress on; must not be null or empty
     * @param strongR2            whether rule R2 uses its "strong" variant
     * @param meekDone            whether Meek rules are applied at the end of {@link #orient()}
     * @param meanCenterResiduals whether residuals are mean-centered before the Anderson-Darling
     *                            statistic and p-value are computed
     * @throws IllegalArgumentException if {@code CPDAG} or {@code dataSets} is null, or if
     *                                  {@code dataSets} is empty
     */
    public Lofs(Graph CPDAG, List<DataSet> dataSets, boolean strongR2, boolean meekDone, boolean meanCenterResiduals) throws IllegalArgumentException {
        // Validate everything up front so no partially initialised state is written on failure.
        if (CPDAG == null) {
            throw new IllegalArgumentException("CPDAG must be specified.");
        }
        if (dataSets == null) {
            throw new IllegalArgumentException("Data set must be specified.");
        }
        if (dataSets.isEmpty()) {
            // Previously this surfaced as an IndexOutOfBoundsException from dataSets.get(0).
            throw new IllegalArgumentException("At least one data set must be specified.");
        }
        this.strongR2 = strongR2;
        this.meekDone = meekDone;
        this.meanCenterResiduals = meanCenterResiduals;
        this.CPDAG = CPDAG;
        this.dataSets = dataSets;
        this.regressions = new ArrayList<Regression>();
        this.variables = dataSets.get(0).getVariables();
        for (DataSet dataSet : dataSets) {
            this.regressions.add(new RegressionDataset(dataSet));
        }
    }

    /**
     * Orients the adjacencies of the input graph.
     *
     * <p>Steps, in order: take the undirected version of the input graph; run rule R1 (when
     * enabled) to add directed edges into a fresh graph over the same nodes; copy every
     * adjacency R1 did not cover as an undirected edge; run rule R2 (when enabled); finally
     * apply Meek rules (when enabled).
     *
     * @return the newly built graph; the input graph itself is not returned
     */
    public Graph orient() {
        Graph undirected = GraphUtils.undirectedGraph(getCPDAG());
        EdgeListGraph oriented = new EdgeListGraph(undirected.getNodes());
        List<Node> nodeList = undirected.getNodes();

        if (isR1Done()) {
            ruleR1(undirected, oriented, nodeList);
        }

        // Carry over every adjacency that R1 left untouched as an undirected edge.
        for (Edge skeletonEdge : undirected.getEdges()) {
            Node first = skeletonEdge.getNode1();
            Node second = skeletonEdge.getNode2();
            if (!oriented.isAdjacentTo(first, second)) {
                oriented.addUndirectedEdge(first, second);
            }
        }

        if (isR2Done()) {
            ruleR2(undirected, oriented);
        }

        if (isMeekDone()) {
            new MeekRules().orientImplied(oriented);
        }

        return oriented;
    }

    /**
     * Rule R1: for each node, picks the candidate parent set with the highest score and adds
     * directed edges from those parents into the node.
     *
     * <p>Candidate sets are produced by a {@link SublistGenerator} over the node's neighbours in
     * {@code skeleton}. Every candidate's score is logged, keyed by the negated score so that
     * the ascending {@link TreeMap} order lists the highest score first (candidates with equal
     * scores share a key, so only the last one is logged). No edges are added when no candidate
     * beat negative infinity or when the residuals for the best set are judged normal.
     *
     * @param skeleton undirected graph supplying the adjacencies
     * @param graph    graph that receives the directed edges
     * @param nodes    nodes to process
     */
    private void ruleR1(Graph skeleton, Graph graph, List<Node> nodes) {
        for (Node child : nodes) {
            TreeMap<Double, String> reports = new TreeMap<Double, String>();
            List<Node> adjacent = skeleton.getAdjacentNodes(child);
            SublistGenerator subsets = new SublistGenerator(adjacent.size(), adjacent.size());

            List<Node> bestParents = null;
            double bestScore = Double.NEGATIVE_INFINITY;

            for (int[] pick = subsets.next(); pick != null; pick = subsets.next()) {
                List<Node> candidate = GraphUtils.asList(pick, adjacent);
                double candidateScore = this.score(child, candidate);
                reports.put(-candidateScore, candidate.toString());
                if (candidateScore > bestScore) {
                    bestScore = candidateScore;
                    bestParents = candidate;
                }
            }

            for (Double negatedKey : reports.keySet()) {
                double negated = negatedKey;
                TetradLogger.getInstance().log("score", "For " + child + " parents = " + reports.get(negated) + " score = " + -negated);
            }
            TetradLogger.getInstance().log("score", "");

            if (bestParents == null || this.normal(child, bestParents)) {
                continue;
            }

            for (Node neighbour : adjacent) {
                if (!bestParents.contains(neighbour)) {
                    continue;
                }
                Edge incoming = Edges.directedEdge(neighbour, child);
                if (!graph.containsEdge(incoming)) {
                    graph.addEdge(incoming);
                }
            }
        }
    }

    /**
     * Rule R2: revisits each skeleton edge and re-resolves its orientation in {@code graph}.
     *
     * <p>A pair is processed when it currently forms a two-cycle in {@code graph} and two-cycle
     * orientation is enabled, or when it is not a two-cycle and is joined by a single
     * undirected edge. All other pairs are left as they are.
     *
     * @param skeleton undirected graph whose edges are iterated
     * @param graph    graph being oriented; modified in place
     */
    private void ruleR2(Graph skeleton, Graph graph) {
        Set<Edge> skeletonEdges = skeleton.getEdges();
        for (Edge candidate : skeletonEdges) {
            Node x = candidate.getNode1();
            Node y = candidate.getNode2();

            boolean eligible;
            if (this.isTwoCycle(graph, x, y)) {
                eligible = this.isR2Orient2Cycles();
            } else {
                eligible = this.isUndirected(graph, x, y);
            }

            if (eligible) {
                this.resolveOneEdgeMax(graph, x, y, this.isStrongR2());
            }
        }
    }

    /**
     * Reports whether exactly two edges connect {@code x} and {@code y} in {@code graph}.
     *
     * @param graph graph to inspect
     * @param x     one endpoint
     * @param y     the other endpoint
     * @return true when {@code graph.getEdges(x, y)} has size 2
     */
    private boolean isTwoCycle(Graph graph, Node x, Node y) {
        return graph.getEdges(x, y).size() == 2;
    }

    /**
     * Reports whether {@code x} and {@code y} are joined by exactly one edge and that edge is
     * undirected.
     *
     * @param graph graph to inspect
     * @param x     one endpoint
     * @param y     the other endpoint
     * @return true only for a single undirected edge; false for zero, two or more edges, or a
     *         single edge that is not undirected
     */
    private boolean isUndirected(Graph graph, Node x, Node y) {
        if (graph.getEdges(x, y).size() != 1) {
            return false;
        }
        return Edges.isUndirectedEdge(graph.getEdge(x, y));
    }

    /**
     * Decides whether the residuals of {@code node} regressed on {@code parents} are treated as
     * normal.
     *
     * <p>When alpha exceeds 0.999 the test is bypassed and nothing is treated as normal.
     * Otherwise residuals count as normal when the Anderson-Darling p-value is strictly greater
     * than alpha.
     *
     * @param node    regression target
     * @param parents regressors
     * @return true when the residuals are judged normal
     */
    private boolean normal(Node node, List<Node> parents) {
        double threshold = this.getAlpha();
        return !(threshold > 0.999) && this.pValue(node, parents) > threshold;
    }

    /**
     * Re-decides the orientation of the adjacency between {@code x} and {@code y} in
     * {@code graph}.
     *
     * <p>For every pair of conditioning sets (one drawn from the other neighbours of x, one from
     * the other neighbours of y) four scores are computed: each endpoint regressed on its
     * conditioning set with ("Plus") and without ("Minus") the other endpoint. Pairs for which
     * any of the four regressions yields residuals judged normal are skipped. Among the
     * remaining pairs, the one with the highest combined score that supports a direction decides
     * the orientation. All existing edges between x and y are then removed and replaced by the
     * chosen directed edge, or by an undirected edge when no direction was chosen.
     *
     * @param graph  graph being oriented; modified in place
     * @param x      one endpoint (may be swapped with {@code y} at random)
     * @param y      the other endpoint
     * @param strong when true, a direction is only accepted if the additional "strong"
     *               inequalities hold; otherwise the first pair of inequalities suffices
     */
    private void resolveOneEdgeMax(Graph graph, Node x, Node y, boolean strong) {
        int[] choicex;
        // Swap the roles of x and y whenever the random draw exceeds 0.5.
        if (RandomUtil.getInstance().nextDouble() > 0.5) {
            Node temp = x;
            x = y;
            y = temp;
        }
        TetradLogger.getInstance().log("info", "\nEDGE " + x + " --- " + y);
        // Reports are keyed by the negated score, so ascending key order lists the highest score
        // first. Two reports with the same score share a key; the later one replaces the earlier.
        TreeMap<Double, String> scoreReports = new TreeMap<Double, String>();
        List<Node> neighborsx = graph.getAdjacentNodes(x);
        neighborsx.remove(y);
        double max = Double.NEGATIVE_INFINITY;
        // left means orient y -> x; right means orient x -> y. At most one is true at a time.
        boolean left = false;
        boolean right = false;
        SublistGenerator genx = new SublistGenerator(neighborsx.size(), neighborsx.size());
        while ((choicex = genx.next()) != null) {
            int[] choicey;
            List<Node> condxMinus = GraphUtils.asList(choicex, neighborsx);
            ArrayList<Node> condxPlus = new ArrayList<Node>(condxMinus);
            condxPlus.add(y);
            double xPlus = this.score(x, condxPlus);
            double xMinus = this.score(x, condxMinus);
            // NOTE(review): neighborsy does not depend on choicex; it is rebuilt on every outer pass.
            List<Node> neighborsy = graph.getAdjacentNodes(y);
            neighborsy.remove(x);
            SublistGenerator geny = new SublistGenerator(neighborsy.size(), neighborsy.size());
            while ((choicey = geny.next()) != null) {
                String s;
                double score;
                List<Node> condyMinus = GraphUtils.asList(choicey, neighborsy);
                ArrayList<Node> condyPlus = new ArrayList<Node>(condyMinus);
                condyPlus.add(x);
                double yPlus = this.score(y, condyPlus);
                double yMinus = this.score(y, condyMinus);
                // Skip this pair of conditioning sets if any of the four residual sets is judged normal.
                if (this.normal(y, condyPlus) || this.normal(x, condxMinus) || this.normal(x, condxPlus) || this.normal(y, condyMinus)) continue;
                // NOTE(review): delta is never read; the "+ 0.0" terms below look like this zero
                // tolerance inlined by the decompiler - confirm against the original source.
                double delta = 0.0;
                if (strong) {
                    // Candidate y -> x: x scores higher with y included, y scores higher without x.
                    if (yPlus <= xPlus + 0.0 && xMinus <= yMinus + 0.0) {
                        score = this.combinedScore(xPlus, yMinus);
                        // Strong variant additionally requires yPlus <= yMinus and xMinus <= xPlus.
                        if (yPlus <= yMinus + 0.0 && xMinus <= xPlus + 0.0) {
                            s = "\nStrong " + y + "->" + x + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                            scoreReports.put(-score, s);
                            if (!(score > max)) continue;
                            max = score;
                            left = true;
                            right = false;
                            continue;
                        }
                        s = "\nNo directed edge " + x + "--" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                        scoreReports.put(-score, s);
                        continue;
                    }
                    // Candidate x -> y: the mirror image of the case above.
                    if (xPlus <= yPlus + 0.0 && yMinus <= xMinus + 0.0) {
                        score = this.combinedScore(yPlus, xMinus);
                        if (yMinus <= yPlus + 0.0 && xPlus <= xMinus + 0.0) {
                            s = "\nStrong " + x + "->" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                            scoreReports.put(-score, s);
                            if (!(score > max)) continue;
                            max = score;
                            left = false;
                            right = true;
                            continue;
                        }
                        s = "\nNo directed edge " + x + "--" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                        scoreReports.put(-score, s);
                        continue;
                    }
                    // Remaining cases only log a report; they never change left/right/max.
                    if (yPlus <= xPlus + 0.0 && yMinus <= xMinus + 0.0) {
                        score = this.combinedScore(yPlus, xMinus);
                        s = "\nNo directed edge " + x + "--" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                        scoreReports.put(-score, s);
                        continue;
                    }
                    if (!(xPlus <= yPlus + 0.0) || !(xMinus <= yMinus + 0.0)) continue;
                    score = this.combinedScore(yPlus, xMinus);
                    s = "\nNo directed edge " + x + "--" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                    scoreReports.put(-score, s);
                    continue;
                }
                // Weak variant: the first pair of inequalities alone decides the direction.
                if (yPlus <= xPlus + 0.0 && xMinus <= yMinus + 0.0) {
                    score = this.combinedScore(xPlus, yMinus);
                    s = "\nWeak " + y + "->" + x + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                    scoreReports.put(-score, s);
                    if (!(score > max)) continue;
                    max = score;
                    left = true;
                    right = false;
                    continue;
                }
                if (xPlus <= yPlus + 0.0 && yMinus <= xMinus + 0.0) {
                    score = this.combinedScore(yPlus, xMinus);
                    s = "\nWeak " + x + "->" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                    scoreReports.put(-score, s);
                    if (!(score > max)) continue;
                    max = score;
                    left = false;
                    right = true;
                    continue;
                }
                // Remaining cases only log a report; they never change left/right/max.
                if (yPlus <= xPlus + 0.0 && yMinus <= xMinus + 0.0) {
                    score = this.combinedScore(yPlus, xMinus);
                    s = "\nNo directed edge " + x + "--" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                    scoreReports.put(-score, s);
                    continue;
                }
                if (!(xPlus <= yPlus + 0.0) || !(xMinus <= yMinus + 0.0)) continue;
                score = this.combinedScore(yPlus, xMinus);
                s = "\nNo directed edge " + x + "--" + y + " " + score + "\n   Parents(" + x + ") = " + condxMinus + "\n   Parents(" + y + ") = " + condyMinus;
                scoreReports.put(-score, s);
            }
        }
        // Emit the collected reports in ascending key order (highest score first).
        Iterator iterator = scoreReports.keySet().iterator();
        while (iterator.hasNext()) {
            double score = (Double)iterator.next();
            TetradLogger.getInstance().log("info", (String)scoreReports.get(score));
        }
        // Replace whatever currently joins x and y with the chosen orientation.
        graph.removeEdges(x, y);
        if (left) {
            graph.addDirectedEdge(y, x);
        }
        if (right) {
            graph.addDirectedEdge(x, y);
        }
        // No direction was chosen: restore the adjacency as an undirected edge.
        if (!graph.isAdjacentTo(x, y)) {
            graph.addUndirectedEdge(x, y);
        }
    }

    /**
     * Combines two individual scores into one by adding them.
     *
     * @param score1 first score
     * @param score2 second score
     * @return {@code score1 + score2}
     */
    private double combinedScore(double score1, double score2) {
        final double total = score1 + score2;
        return total;
    }

    /**
     * Scores the residuals of {@code y} regressed on {@code parents} using the currently
     * selected {@link Score}.
     *
     * <p>Supported selections are {@code andersonDarling}, {@code kurtosis}, {@code skew},
     * {@code fifthMoment} and {@code absoluteValue}. The kurtosis, skewness and fifth-moment
     * scores are returned as absolute values.
     *
     * @param y       regression target
     * @param parents regressors
     * @return the score for this target and parent set
     * @throws IllegalStateException if the selected {@link Score} has no implementation here
     */
    private double score(Node y, List<Node> parents) {
        switch (this.score) {
            case andersonDarling:
                return this.andersonDarlingPASquareStar(y, parents);
            case kurtosis:
                return FastMath.abs(StatUtils.kurtosis(this.residual(y, parents)));
            case skew:
                return FastMath.abs(StatUtils.skewness(this.residual(y, parents)));
            case fifthMoment:
                return FastMath.abs(StatUtils.standardizedFifthMoment(this.residual(y, parents)));
            case absoluteValue:
                return this.localScoreA(y, parents);
            default:
                // Name the offending selection so the failure can be traced back to setScore(...).
                throw new IllegalStateException("Unsupported score: " + this.score);
        }
    }

    /**
     * Absolute-value score for {@code node} regressed on {@code parents}.
     *
     * <p>For each data set the regression residuals are standardized (mean subtracted, divided
     * by the standard deviation derived from {@code Descriptive.variance}) and pooled. A data
     * set whose residuals contain a NaN is skipped entirely. The score is the squared
     * difference between the mean absolute pooled residual and {@code sqrt(0.6366197723675814)};
     * that constant equals 2/pi to double precision, and sqrt(2/pi) is the mean absolute value
     * of a standard normal variable.
     *
     * @param node    regression target
     * @param parents regressors
     * @return squared deviation of the mean absolute standardized residual from sqrt(2/pi)
     */
    private double localScoreA(Node node, List<Node> parents) {
        Node target = this.getVariable(this.variables, node.getName());
        ArrayList<Node> regressors = new ArrayList<Node>();
        for (Node parent : parents) {
            regressors.add(this.getVariable(this.variables, parent.getName()));
        }

        // Absolute standardized residuals pooled across all usable data sets.
        ArrayList<Double> pooledAbs = new ArrayList<Double>();
        for (int m = 0; m < this.dataSets.size(); ++m) {
            Vector raw = this.regressions.get(m).regress(target, regressors).getResiduals();

            boolean hasNaN = false;
            for (int h = 0; h < raw.size() && !hasNaN; ++h) {
                hasNaN = Double.isNaN(raw.get(h));
            }
            if (hasNaN) {
                continue;
            }

            DoubleArrayList values = new DoubleArrayList(raw.toArray());
            double mean = Descriptive.mean(values);
            double variance = Descriptive.variance(values.size(), Descriptive.sum(values), Descriptive.sumOfSquares(values));
            double std = Descriptive.standardDeviation(variance);
            for (int i = 0; i < values.size(); ++i) {
                pooledAbs.add(FastMath.abs((values.get(i) - mean) / std));
            }
        }

        double[] flat = new double[pooledAbs.size()];
        int next = 0;
        for (Double value : pooledAbs) {
            flat[next++] = value;
        }

        double meanAbs = Descriptive.mean(new DoubleArrayList(flat));
        double diff = meanAbs - FastMath.sqrt(0.6366197723675814);
        return diff * diff;
    }

    /**
     * Anderson-Darling A-squared-star statistic of the pooled residuals of {@code node}
     * regressed on {@code parents}.
     *
     * <p>Residuals are computed per data set, mean-centered per data set when
     * {@link #isMeanCenterResiduals()} is true, and concatenated. A data set whose residuals
     * contain a NaN is skipped entirely.
     *
     * @param node    regression target
     * @param parents regressors
     * @return {@code getASquaredStar()} of an {@link AndersonDarlingTest} on the pooled residuals
     */
    private double andersonDarlingPASquareStar(Node node, List<Node> parents) {
        Node target = this.getVariable(this.variables, node.getName());
        ArrayList<Node> regressors = new ArrayList<Node>();
        for (Node parent : parents) {
            regressors.add(this.getVariable(this.variables, parent.getName()));
        }

        ArrayList<Double> pooled = new ArrayList<Double>();
        for (int m = 0; m < this.dataSets.size(); ++m) {
            Vector raw = this.regressions.get(m).regress(target, regressors).getResiduals();

            boolean hasNaN = false;
            for (int h = 0; h < raw.size() && !hasNaN; ++h) {
                hasNaN = Double.isNaN(raw.get(h));
            }
            if (hasNaN) {
                continue;
            }

            DoubleArrayList values = new DoubleArrayList(raw.toArray());
            double mean = Descriptive.mean(values);
            boolean center = this.isMeanCenterResiduals();
            for (int i = 0; i < values.size(); ++i) {
                double value = values.get(i);
                if (center) {
                    value = value - mean;
                }
                pooled.add(value);
            }
        }

        double[] flat = new double[pooled.size()];
        int next = 0;
        for (Double value : pooled) {
            flat[next++] = value;
        }
        return new AndersonDarlingTest(flat).getASquaredStar();
    }

    /**
     * Anderson-Darling p-value of the pooled residuals of {@code node} regressed on
     * {@code parents}.
     *
     * <p>Residuals are computed per data set, mean-centered per data set when
     * {@link #isMeanCenterResiduals()} is true, and concatenated. A data set whose residuals
     * contain a NaN is skipped entirely.
     *
     * @param node    regression target
     * @param parents regressors
     * @return {@code getP()} of an {@link AndersonDarlingTest} on the pooled residuals
     */
    private double pValue(Node node, List<Node> parents) {
        Node target = this.getVariable(this.variables, node.getName());
        ArrayList<Node> regressors = new ArrayList<Node>();
        for (Node parent : parents) {
            regressors.add(this.getVariable(this.variables, parent.getName()));
        }

        ArrayList<Double> pooled = new ArrayList<Double>();
        for (int m = 0; m < this.dataSets.size(); ++m) {
            Vector raw = this.regressions.get(m).regress(target, regressors).getResiduals();

            boolean hasNaN = false;
            for (int h = 0; h < raw.size() && !hasNaN; ++h) {
                hasNaN = Double.isNaN(raw.get(h));
            }
            if (hasNaN) {
                continue;
            }

            DoubleArrayList values = new DoubleArrayList(raw.toArray());
            double mean = Descriptive.mean(values);
            boolean center = this.isMeanCenterResiduals();
            for (int i = 0; i < values.size(); ++i) {
                double value = values.get(i);
                if (center) {
                    value = value - mean;
                }
                pooled.add(value);
            }
        }

        double[] flat = new double[pooled.size()];
        int next = 0;
        for (Double value : pooled) {
            flat[next++] = value;
        }
        return new AndersonDarlingTest(flat).getP();
    }

    /**
     * Pooled, mean-centered residuals of {@code node} regressed on {@code parents}.
     *
     * <p>Residuals are computed per data set, always mean-centered per data set (regardless of
     * the mean-centering flag), and concatenated in data-set order. A data set whose residuals
     * contain a NaN is skipped entirely.
     *
     * @param node    regression target
     * @param parents regressors
     * @return the concatenated centered residuals; empty when every data set was skipped
     */
    private double[] residual(Node node, List<Node> parents) {
        Node target = this.getVariable(this.variables, node.getName());
        ArrayList<Node> regressors = new ArrayList<Node>();
        for (Node parent : parents) {
            regressors.add(this.getVariable(this.variables, parent.getName()));
        }

        ArrayList<Double> pooled = new ArrayList<Double>();
        for (int m = 0; m < this.dataSets.size(); ++m) {
            Vector raw = this.regressions.get(m).regress(target, regressors).getResiduals();

            boolean hasNaN = false;
            for (int h = 0; h < raw.size() && !hasNaN; ++h) {
                hasNaN = Double.isNaN(raw.get(h));
            }
            if (hasNaN) {
                continue;
            }

            DoubleArrayList values = new DoubleArrayList(raw.toArray());
            double mean = Descriptive.mean(values);
            for (int i = 0; i < values.size(); ++i) {
                pooled.add(values.get(i) - mean);
            }
        }

        double[] flat = new double[pooled.size()];
        int next = 0;
        for (Double value : pooled) {
            flat[next++] = value;
        }
        return flat;
    }

    /**
     * Returns the threshold compared against the Anderson-Darling p-value when deciding whether
     * residuals are treated as normal.
     *
     * @return the current alpha
     */
    public double getAlpha() {
        return alpha;
    }

    /**
     * Sets the threshold compared against the Anderson-Darling p-value when deciding whether
     * residuals are treated as normal. Values above 0.999 disable the normality check.
     *
     * @param alpha new threshold, in the closed range [0, 1]
     * @throws IllegalArgumentException if {@code alpha} is outside [0, 1] or is NaN
     */
    public void setAlpha(double alpha) {
        // Written as a negated range test so that NaN, which fails every comparison, is rejected too.
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Alpha is in range [0, 1]: " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Returns the graph that was passed to the constructor.
     *
     * @return the input graph
     */
    private Graph getCPDAG() {
        return CPDAG;
    }

    /**
     * Finds the first node in {@code variables} whose name equals {@code name}.
     *
     * @param variables nodes to search, in order
     * @param name      name to look for
     * @return the first matching node, or {@code null} when none matches
     */
    private Node getVariable(List<Node> variables, String name) {
        Node match = null;
        Iterator<Node> cursor = variables.iterator();
        while (match == null && cursor.hasNext()) {
            Node candidate = cursor.next();
            if (name.equals(candidate.getName())) {
                match = candidate;
            }
        }
        return match;
    }

    /**
     * Reports whether rule R1 is applied by {@code orient()}.
     *
     * @return always {@code true}
     */
    public boolean isR1Done() {
        return true;
    }

    /**
     * Reports whether rule R2 is applied by {@code orient()}.
     *
     * @return always {@code true}
     */
    public boolean isR2Done() {
        return true;
    }

    /**
     * Reports whether Meek rules are applied as the final step of {@code orient()}.
     *
     * @return the value supplied to the constructor
     */
    public boolean isMeekDone() {
        return meekDone;
    }

    /**
     * Reports whether rule R2 uses its "strong" variant.
     *
     * @return the value supplied to the constructor
     */
    public boolean isStrongR2() {
        return strongR2;
    }

    /**
     * Reports whether rule R2 also re-resolves node pairs that currently form a two-cycle.
     *
     * @return always {@code true}
     */
    public boolean isR2Orient2Cycles() {
        return true;
    }

    /**
     * Returns the statistic currently used to score residuals.
     *
     * @return the selected score; never {@code null}
     */
    public Score getScore() {
        return score;
    }

    /**
     * Selects the statistic used to score residuals.
     *
     * <p>Only {@code andersonDarling}, {@code skew}, {@code kurtosis}, {@code fifthMoment} and
     * {@code absoluteValue} are implemented by the scoring method; any other constant is
     * accepted here but makes scoring fail with an {@link IllegalStateException}.
     *
     * @param score the score to use; must not be null
     * @throws NullPointerException if {@code score} is null
     */
    public void setScore(Score score) {
        if (score == null) {
            // Same exception type as before, now with a message identifying the argument.
            throw new NullPointerException("score must not be null");
        }
        this.score = score;
    }

    /**
     * Reports whether residuals are mean-centered per data set before the Anderson-Darling
     * statistic and p-value are computed.
     *
     * @return the value supplied to the constructor
     */
    public boolean isMeanCenterResiduals() {
        return meanCenterResiduals;
    }

    /**
     * Statistics that can be selected for scoring residuals.
     *
     * <p>Within this class only {@code andersonDarling}, {@code skew}, {@code kurtosis},
     * {@code fifthMoment} and {@code absoluteValue} are handled by the scoring method; selecting
     * any other constant makes scoring throw an {@link IllegalStateException}.
     */
    public static enum Score {
        // Anderson-Darling A-squared-star statistic of the pooled residuals.
        andersonDarling,
        // Absolute skewness of the pooled, mean-centered residuals.
        skew,
        // Absolute kurtosis of the pooled, mean-centered residuals.
        kurtosis,
        // Absolute standardized fifth moment of the pooled, mean-centered residuals.
        fifthMoment,
        // Squared deviation of the mean absolute standardized residual from sqrt(2/pi).
        absoluteValue,
        // The constants below have no scoring implementation in this class.
        exp,
        expUnstandardized,
        expUnstandardizedInverted,
        other,
        logcosh,
        entropy;

    }
}

