/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.predict;

import cern.colt.matrix.DoubleMatrix2D;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.predict.ManipulatedVariable;
import edu.cmu.tetrad.predict.OrderingGenerator;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradSerializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/**
 * Computes a prediction for a target variable Y, given a conditioning set Z, after a
 * manipulation of a variable X, starting from a PAG and a continuous data set.
 *
 * <p>For each ordering returned by {@link OrderingGenerator#getOrders()} for the PAG, an I-map
 * DAG is built. The first I-map that passes {@link #invarianceTest(Graph)} is used as follows:
 * <ol>
 *   <li>it is pruned to the edges among the nodes relevant to Y and Z;</li>
 *   <li>a SEM is estimated on it;</li>
 *   <li>the manipulation of X is applied to the estimated model;</li>
 *   <li>the result is computed from the model-implied covariance matrix.</li>
 * </ol>
 */
public final class Prediction implements TetradSerializable {
    static final long serialVersionUID = 23L;

    /** The PAG from which I-maps are generated; never null. */
    private Graph pag;
    /** The data used for SEM estimation; never null. */
    private DataSet dataSet;
    /** X; its getVariance() is used as the error variance of X after manipulation. */
    private ManipulatedVariable manipulatedVariable;
    /** The conditioning set Z. Stored by reference; never modified by this class. */
    private Set<Node> zSet;
    /** The target variable Y. */
    private Node yNode;

    /**
     * Creates a prediction problem.
     *
     * @param pag            the PAG; must not be null
     * @param dataContinuous the data set used for SEM estimation; must not be null
     * @param manipulated    the manipulated variable X
     * @param pred           the target variable Y
     * @param condSet        the conditioning set Z; expected to contain {@link Node}s. It is
     *                       stored by reference and is not modified by this class.
     * @throws IllegalArgumentException if {@code pag} or {@code dataContinuous} is null
     */
    @SuppressWarnings("unchecked")
    public Prediction(Graph pag, DataSet dataContinuous, ManipulatedVariable manipulated, Node pred, Set condSet) {
        if (pag == null) {
            throw new IllegalArgumentException("PAG must not be null");
        }
        if (dataContinuous == null) {
            throw new IllegalArgumentException("Data set must not be null");
        }
        this.pag = pag;
        this.dataSet = dataContinuous;
        this.manipulatedVariable = manipulated;
        this.zSet = (Set<Node>) condSet;
        this.yNode = pred;
    }

    /**
     * Returns an instance built from the serializable instances of its component types, with
     * an empty conditioning set.
     */
    public static Prediction serializableInstance() {
        return new Prediction(EdgeListGraph.serializableInstance(), DataUtils.continuousSerializableInstance(), ManipulatedVariable.serializableInstance(), GraphNode.serializableInstance(), new HashSet<Node>());
    }

    /** Returns the PAG this prediction was constructed with. */
    public Graph getPag() {
        return this.pag;
    }

    /**
     * Computes Sigma_YY - Sigma_YZ * inv(Sigma_ZZ) * Sigma_ZY, where Sigma is the implied
     * covariance matrix of the manipulated, estimated SEM. The SEM is taken from the first
     * I-map ordering that passes the invariance test.
     *
     * <p>This method does not modify the conditioning set, so it may be called repeatedly.
     *
     * @return the value above; it equals Sigma_YY when Z is empty. Returns {@code Double.NaN}
     *         when no ordering passes the invariance test, or when Sigma_ZZ is exactly singular.
     * @throws IllegalStateException if Y is in Z, if Y is X, or if Y or some member of Z is not
     *                               a variable of the estimated model
     */
    public double predict() {
        Node x = this.manipulatedVariable.getNode();
        if (this.zSet.contains(this.yNode) || x.equals(this.yNode)) {
            throw new IllegalStateException("yNode,zSet not disjoint or X = yNode");
        }
        Set<Node> ySet = Collections.singleton(this.yNode);
        List<List<Node>> orders = new OrderingGenerator(this.pag).getOrders();
        for (List<Node> order : orders) {
            Dag imap = this.genImap(order);
            if (!this.invarianceTest(imap)) {
                continue;
            }
            this.pruneToRelevantSubgraph(imap, ySet, x);
            SemIm est = this.estimatedSemIm(imap, this.dataSet);
            System.out.println("subgraph estimated");
            this.applyManipulation(est, imap, x);
            // Only the first admissible ordering is used.
            return this.conditionalVarianceOfY(est);
        }
        return Double.NaN;
    }

    /**
     * Removes from {@code imap} every edge that does not have both endpoints in the relevant
     * node set. That set is IV(Y, Z) plus IP(Y, Z) plus {Y}, together with the parents of all
     * those nodes, plus X.
     */
    private void pruneToRelevantSubgraph(Dag imap, Set<Node> ySet, Node x) {
        Set<Node> keep = new HashSet<Node>();
        keep.addAll(this.IV(ySet, this.zSet, imap));
        keep.addAll(this.IP(ySet, this.zSet, imap));
        keep.add(this.yNode);
        keep.addAll(this.getParentsSet(imap, keep));
        keep.add(x);
        // Iterate over a copy, because edges are removed from the graph during the loop.
        for (Edge e : new LinkedList<Edge>(imap.getEdges())) {
            if (!keep.contains(e.getNode1()) || !keep.contains(e.getNode2())) {
                imap.removeEdge(e);
            }
        }
    }

    /**
     * Applies the manipulation of X to {@code est}. The value for each (X, parent) pair is set
     * to 0, and the value for (X, X) is set to the manipulated variable's variance.
     */
    private void applyManipulation(SemIm est, Graph imap, Node x) {
        for (Node parent : imap.getParents(x)) {
            try {
                est.setParamValue(x, parent, 0.0);
            } catch (IllegalArgumentException ignored) {
                // Best effort: a parent for which the model has no matching parameter is skipped.
                // NOTE(review): confirm that the (x, parent) argument order matches the
                // (from, to) convention of SemIm.setParamValue.
            }
        }
        est.setParamValue(x, x, this.manipulatedVariable.getVariance());
    }

    /**
     * Extracts the Y/Z blocks of the implied covariance of {@code est} and returns
     * Sigma_YY - Sigma_YZ * inv(Sigma_ZZ) * Sigma_ZY.
     *
     * <p>This assumes that the rows and columns of getImplCovar() follow the order of
     * getVariableNodes(). The pre-existing code made the same assumption.
     */
    private double conditionalVarianceOfY(SemIm est) {
        DoubleMatrix2D implCovar = est.getImplCovar();
        double[][] cov = implCovar.toArray();
        List<Node> vars = est.getVariableNodes();

        int y = vars.indexOf(this.yNode);
        if (y < 0) {
            throw new IllegalStateException("Y is not a variable of the estimated model: " + this.yNode);
        }
        int[] z = new int[this.zSet.size()];
        int k = 0;
        for (Node zNode : this.zSet) {
            int index = vars.indexOf(zNode);
            if (index < 0) {
                throw new IllegalStateException("Conditioning variable is not in the estimated model: " + zNode);
            }
            z[k++] = index;
        }
        System.out.println("submatrix made");

        double sigmaYY = cov[y][y];
        if (z.length == 0) {
            // With no conditioning variables the result is simply the variance of Y.
            return sigmaYY;
        }
        double[][] sigmaZZ = new double[z.length][z.length];
        double[] sigmaZY = new double[z.length];
        double[] sigmaYZ = new double[z.length];
        for (int i = 0; i < z.length; ++i) {
            sigmaZY[i] = cov[z[i]][y];
            sigmaYZ[i] = cov[y][z[i]];
            for (int j = 0; j < z.length; ++j) {
                sigmaZZ[i][j] = cov[z[i]][z[j]];
            }
        }
        // Solve Sigma_ZZ * w = Sigma_ZY rather than forming the inverse explicitly.
        double[] w = solve(sigmaZZ, sigmaZY);
        if (w == null) {
            return Double.NaN;
        }
        double explained = 0.0;
        for (int i = 0; i < z.length; ++i) {
            explained += sigmaYZ[i] * w[i];
        }
        return sigmaYY - explained;
    }

    /**
     * Solves {@code a * x = b} by Gaussian elimination with partial pivoting. The inputs are
     * not modified.
     *
     * @return the solution, or {@code null} if a pivot is exactly zero (singular matrix)
     */
    private static double[] solve(double[][] a, double[] b) {
        int n = b.length;
        double[][] m = new double[n][];
        for (int i = 0; i < n; ++i) {
            m[i] = a[i].clone();
        }
        double[] rhs = b.clone();
        for (int col = 0; col < n; ++col) {
            int pivot = col;
            for (int r = col + 1; r < n; ++r) {
                if (Math.abs(m[r][col]) > Math.abs(m[pivot][col])) {
                    pivot = r;
                }
            }
            if (m[pivot][col] == 0.0) {
                return null;
            }
            double[] rowTmp = m[col];
            m[col] = m[pivot];
            m[pivot] = rowTmp;
            double rhsTmp = rhs[col];
            rhs[col] = rhs[pivot];
            rhs[pivot] = rhsTmp;
            for (int r = col + 1; r < n; ++r) {
                double factor = m[r][col] / m[col][col];
                for (int c = col; c < n; ++c) {
                    m[r][c] -= factor * m[col][c];
                }
                rhs[r] -= factor * rhs[col];
            }
        }
        double[] x = new double[n];
        for (int i = n - 1; i >= 0; --i) {
            double s = rhs[i];
            for (int c = i + 1; c < n; ++c) {
                s -= m[i][c] * x[c];
            }
            x[i] = s / m[i][i];
        }
        return x;
    }

    /**
     * Builds a DAG over the nodes of {@code order}. Each node receives an edge from every node
     * in its {@link #definiteSP(Node, List)} set. Because that set only contains nodes earlier
     * in the ordering, every edge points forward in the ordering.
     */
    private Dag genImap(List<Node> order) {
        Dag imap = new Dag();
        for (Node n : order) {
            imap.addNode(n);
        }
        for (Node n : order) {
            for (Node parent : this.definiteSP(n, order)) {
                imap.addEdge(new Edge(parent, n, Endpoint.TAIL, Endpoint.ARROW));
            }
        }
        return imap;
    }

    /**
     * Returns true when X is not "possibly IP" of Y given Z in {@code G} (if X is in Z), or not
     * "possibly IV" of Y given Z in {@code G} (otherwise).
     */
    private boolean invarianceTest(Graph G) {
        Set<Node> ySet = Collections.singleton(this.yNode);
        Node x = this.manipulatedVariable.getNode();
        if (this.zSet.contains(x)) {
            return !this.possibly_IP(ySet, this.zSet, G).contains(x);
        }
        return !this.possibly_IV(ySet, this.zSet, G).contains(x);
    }

    /** Estimates a SEM with graph {@code G} on data {@code d}. */
    private SemIm estimatedSemIm(Dag G, DataSet d) {
        SemPm sempm = new SemPm(G);
        SemEstimator est = new SemEstimator(d, sempm);
        est.estimate();
        return est.getEstimatedSem();
    }

    /**
     * Returns the nodes v of {@code G} that satisfy all of the following:
     * <ul>
     *   <li>v is not in Z;</li>
     *   <li>v is possibly d-connected to Y given Z;</li>
     *   <li>v has a semi-directed path into Y or Z.</li>
     * </ul>
     */
    private Set<Node> possibly_IV(Set<Node> Y, Set<Node> Z, Graph G) {
        Set<Node> result = new HashSet<Node>();
        Set<Node> yuz = new HashSet<Node>(Y);
        yuz.addAll(Z);
        for (Node v : G.getNodes()) {
            if (Z.contains(v)) {
                continue;
            }
            if (this.possDConnectedToAltered(G, v, Y, Z, false) && this.existsSemiDirectedPathFromToSet(G, v, yuz)) {
                result.add(v);
            }
        }
        return result;
    }

    /**
     * Returns the members v of Z that are possibly d-connected to Y given Z minus {v}. The
     * search from v skips certain edges at v (see the final parameter of
     * possDConnectedToAltered).
     */
    private Set<Node> possibly_IP(Set<Node> Y, Set<Node> Z, Graph G) {
        Set<Node> result = new HashSet<Node>();
        for (Node v : Z) {
            Set<Node> zMinusV = new HashSet<Node>(Z);
            zMinusV.remove(v);
            if (this.possDConnectedToAltered(G, v, Y, zMinusV, true)) {
                result.add(v);
            }
        }
        return result;
    }

    /**
     * Returns true if some node in {@code nodes2} is reachable from {@code node1}. Each step
     * follows an edge that Edges.traverseSemiDirected accepts.
     */
    private boolean existsSemiDirectedPathFromToSet(Graph G, Node node1, Set<Node> nodes2) {
        Set<Node> visited = new HashSet<Node>();
        visited.add(node1);
        return this.existsSemiDirectedPathVisit(G, node1, nodes2, visited);
    }

    /**
     * Depth-first search helper for existsSemiDirectedPathFromToSet. A single visited set
     * suffices for reachability, which avoids enumerating every simple path.
     */
    private boolean existsSemiDirectedPathVisit(Graph G, Node node1, Set<Node> nodes2, Set<Node> visited) {
        for (Edge edge : G.getEdges(node1)) {
            Node child = Edges.traverseSemiDirected(node1, edge);
            if (child == null) {
                continue;
            }
            if (nodes2.contains(child)) {
                return true;
            }
            if (visited.add(child) && this.existsSemiDirectedPathVisit(G, child, nodes2, visited)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Breadth-first search over edges, starting at {@code node1}, that returns true once a node
     * of {@code nodes2} is reached.
     *
     * <p>A step along the triple (prev, center, next) is allowed when either of these holds:
     * <ul>
     *   <li>the triple is a definite noncollider and center is not in {@code condNodes};</li>
     *   <li>the triple is a definite collider and center is a possible ancestor of some node in
     *       {@code condNodes}.</li>
     * </ul>
     * Each undirected edge is traversed at most once. When {@code special_flag} is set, moves
     * out of {@code node1} exclude the nodes returned by
     * {@code G.getNodesInTo(node1, Endpoint.TAIL)}.
     */
    private boolean possDConnectedToAltered(Graph G, Node node1, Set<Node> nodes2, Set<Node> condNodes, boolean special_flag) {
        List<Node> allNodes = new LinkedList<Node>(G.getNodes());
        int sz = allNodes.size();
        // edgeStage[a][b] != 0 once the undirected edge a-b has been traversed.
        int[][] edgeStage = new int[sz][sz];
        int stage = 1;
        int start = allNodes.indexOf(node1);
        edgeStage[start][start] = 1;
        List<int[]> frontier = new LinkedList<int[]>();
        // A self "edge" seeds the search, so the first real step starts from node1.
        frontier.add(new int[]{start, start});
        while (true) {
            List<int[]> next = new LinkedList<int[]>();
            for (int[] edge : frontier) {
                Node from = allNodes.get(edge[0]);
                Node center = allNodes.get(edge[1]);
                List<Node> adj = new LinkedList<Node>(G.getAdjacentNodes(center));
                if (special_flag && center.equals(node1)) {
                    adj.removeAll(G.getNodesInTo(center, Endpoint.TAIL));
                }
                for (Node candidate : adj) {
                    int testIndex = allNodes.indexOf(candidate);
                    if (edgeStage[edge[1]][testIndex] != 0) {
                        continue;
                    }
                    Node to = allNodes.get(testIndex);
                    boolean passes = (G.isDefNoncollider(from, center, to) && !condNodes.contains(center))
                            || (G.isDefCollider(from, center, to) && this.possibleAncestorSet(G, center, condNodes));
                    if (!passes) {
                        continue;
                    }
                    if (nodes2.contains(candidate)) {
                        return true;
                    }
                    next.add(new int[]{edge[1], testIndex});
                    edgeStage[edge[1]][testIndex] = stage;
                    edgeStage[testIndex][edge[1]] = stage;
                }
            }
            if (next.isEmpty()) {
                break;
            }
            frontier = next;
            ++stage;
        }
        return false;
    }

    /** Returns true if {@code node1} is a possible ancestor of at least one node in {@code nodes2}. */
    private boolean possibleAncestorSet(Graph G, Node node1, Set<Node> nodes2) {
        for (Node n : nodes2) {
            if (G.possibleAncestor(node1, n)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the nodes v of {@code G} that are d-connected to some member of Y given Z, and
     * that have a descendant (as returned by G.getDescendants) in Y or Z.
     */
    private Set<Node> IV(Set<Node> Y, Set<Node> Z, Graph G) {
        Set<Node> result = new HashSet<Node>();
        List<Node> zList = new LinkedList<Node>(Z);
        for (Node v : G.getNodes()) {
            boolean connected = false;
            for (Node nextY : Y) {
                if (G.isDConnectedTo(v, nextY, zList)) {
                    connected = true;
                    break;
                }
            }
            if (!connected) {
                continue;
            }
            for (Node d : G.getDescendants(Collections.singletonList(v))) {
                if (Y.contains(d) || Z.contains(d)) {
                    result.add(v);
                    break;
                }
            }
        }
        return result;
    }

    /** Returns the members of Z that have at least one parent in IV(Y, Z) or in Y. */
    private Set<Node> IP(Set<Node> Y, Set<Node> Z, Graph G) {
        Set<Node> ivUy = new HashSet<Node>(this.IV(Y, Z, G));
        ivUy.addAll(Y);
        Set<Node> result = new HashSet<Node>();
        for (Node w : Z) {
            for (Node parent : G.getParents(w)) {
                if (ivUy.contains(parent)) {
                    result.add(w);
                    break;
                }
            }
        }
        return result;
    }

    /** Returns {@link #generalSP(Node, List, boolean)} in its "definite" (collider) mode. */
    private Set<Node> definiteSP(Node X, List<Node> Ord) {
        return this.generalSP(X, Ord, true);
    }

    /**
     * Returns the nodes that precede X in {@code Ord} and are reachable from X in the PAG. The
     * search visits only nodes that precede X, and each step must pass a test on its triple.
     * With {@code defT_possF} true, every intermediate triple must be a definite collider.
     * Otherwise, every intermediate triple must not be a definite noncollider.
     *
     * <p>The search is breadth-first over (previous, current) pairs, so each triple is tested
     * against its actual predecessor on the path. Each directed edge is expanded at most once,
     * which guarantees termination.
     */
    private Set<Node> generalSP(Node X, List<Node> Ord, boolean defT_possF) {
        Set<Node> result = new HashSet<Node>();
        Set<Node> beforeX = new HashSet<Node>(Ord.subList(0, Ord.indexOf(X)));
        LinkedList<Node[]> queue = new LinkedList<Node[]>();
        Set<List<Node>> expanded = new HashSet<List<Node>>();
        for (Node adjacent : this.pag.getAdjacentNodes(X)) {
            if (beforeX.contains(adjacent)) {
                result.add(adjacent);
                queue.add(new Node[]{X, adjacent});
            }
        }
        while (!queue.isEmpty()) {
            Node[] step = queue.removeFirst();
            Node prev = step[0];
            Node curr = step[1];
            for (Node next : this.pag.getAdjacentNodes(curr)) {
                if (!beforeX.contains(next)) {
                    continue;
                }
                boolean passes = defT_possF
                        ? this.pag.isDefCollider(prev, curr, next)
                        : !this.pag.isDefNoncollider(prev, curr, next);
                if (!passes) {
                    continue;
                }
                result.add(next);
                if (expanded.add(java.util.Arrays.asList(curr, next))) {
                    queue.add(new Node[]{curr, next});
                }
            }
        }
        return result;
    }

    /** Returns the union of the parents in {@code G} of every node in {@code N}. */
    private Set<Node> getParentsSet(Graph G, Set<Node> N) {
        Set<Node> result = new HashSet<Node>();
        for (Node aN : N) {
            result.addAll(G.getParents(aN));
        }
        return result;
    }

    /**
     * Restores the default-serialized fields, then rejects a stream whose PAG or data set is
     * null.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.pag == null) {
            throw new NullPointerException("pag");
        }
        if (this.dataSet == null) {
            throw new NullPointerException("dataSet");
        }
    }
}

