/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.SemEvidence;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.Vector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

public class SemUpdater
implements TetradSerializable {
    private static final long serialVersionUID = 23L;
    private final SemIm semIm;
    private SemEvidence evidence;

    public SemUpdater(SemIm semIm) {
        if (semIm == null) {
            throw new NullPointerException();
        }
        this.semIm = semIm;
        SemEvidence evidence = new SemEvidence(this.semIm);
        this.setEvidence(evidence);
    }

    public static SemUpdater serializableInstance() {
        return new SemUpdater(SemIm.serializableInstance());
    }

    public SemEvidence getEvidence() {
        return this.evidence;
    }

    public void setEvidence(SemEvidence evidence) {
        if (evidence == null) {
            throw new NullPointerException();
        }
        this.evidence = evidence;
    }

    public SemIm getSemIm() {
        return this.semIm;
    }

    public SemIm getUpdatedSemIm() {
        int i;
        SemIm manipulatedSemIm = this.getManipulatedSemIm();
        Vector means = new Vector(manipulatedSemIm.getVariableNodes().size());
        for (int i2 = 0; i2 < means.size(); ++i2) {
            means.set(i2, manipulatedSemIm.getMean(manipulatedSemIm.getVariableNodes().get(i2)));
        }
        Matrix implcov = manipulatedSemIm.getImplCovar(true);
        SemEvidence evidence = this.getEvidence();
        ArrayList<Node> nodesInEvidence = new ArrayList<Node>(evidence.getNodesInEvidence());
        ArrayList<Node> XVars = new ArrayList<Node>(evidence.getNodesInEvidence());
        ArrayList<Node> YVars = new ArrayList<Node>(manipulatedSemIm.getVariableNodes());
        YVars.removeAll(nodesInEvidence);
        int[] xIndices = new int[XVars.size()];
        int[] yIndices = new int[YVars.size()];
        for (i = 0; i < XVars.size(); ++i) {
            xIndices[i] = manipulatedSemIm.getVariableNodes().indexOf(XVars.get(i));
        }
        for (i = 0; i < YVars.size(); ++i) {
            yIndices[i] = manipulatedSemIm.getVariableNodes().indexOf(YVars.get(i));
        }
        Matrix covyx = implcov.getSelection(yIndices, xIndices);
        Matrix varx = implcov.getSelection(xIndices, xIndices);
        Vector EX = means.viewSelection(xIndices);
        Vector EY = means.viewSelection(yIndices);
        Vector X = new Vector(nodesInEvidence.size());
        for (int i3 = 0; i3 < nodesInEvidence.size(); ++i3) {
            int j = evidence.getNodeIndex((Node)nodesInEvidence.get(i3));
            X.set(i3, evidence.getProposition().getValue(j));
        }
        Vector xminusex = X.minus(EX);
        Vector mu = new Vector(manipulatedSemIm.getVariableNodes().size());
        DenseDoubleMatrix2D sigma2 = new DenseDoubleMatrix2D(manipulatedSemIm.getErrCovar().toArray());
        if (xminusex.size() == 0) {
            mu = new Vector(means.toArray());
        } else {
            int i4;
            Vector times = covyx.times(varx.inverse()).times(xminusex);
            Vector YHatX = EY.plus(times);
            for (i4 = 0; i4 < xIndices.length; ++i4) {
                mu.set(xIndices[i4], X.get(i4));
            }
            for (i4 = 0; i4 < yIndices.length; ++i4) {
                mu.set(yIndices[i4], YHatX.get(i4));
            }
        }
        return manipulatedSemIm.updatedIm(new Matrix(sigma2.toArray()), mu);
    }

    public Graph getManipulatedGraph() {
        return this.createManipulatedGraph(this.getSemIm().getSemPm().getGraph());
    }

    public SemIm getManipulatedSemIm() {
        SemGraph graph = this.getSemIm().getSemPm().getGraph();
        SemGraph manipulatedGraph = this.createManipulatedGraph(graph);
        return SemIm.retainValues(this.getSemIm(), manipulatedGraph);
    }

    private SemGraph createManipulatedGraph(Graph graph) {
        SemGraph updatedGraph = new SemGraph(graph);
        for (int i = 0; i < this.evidence.getNumNodes(); ++i) {
            if (!this.evidence.isManipulated(i)) continue;
            Node node = this.evidence.getNode(i);
            List<Node> parents = updatedGraph.getParents(node);
            for (Node parent : parents) {
                if (parent.getNodeType() == NodeType.ERROR) continue;
                updatedGraph.removeEdge(node, parent);
            }
        }
        return updatedGraph;
    }
}

