/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.SemEvidence;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.util.TetradSerializable;
import java.util.ArrayList;
import java.util.List;

public class SemUpdater
implements TetradSerializable {
    static final long serialVersionUID = 23L;
    private SemEvidence evidence;
    private SemIm semIm;

    public SemUpdater(SemIm semIm) {
        if (semIm == null) {
            throw new NullPointerException();
        }
        this.setEvidence(new SemEvidence(semIm));
        this.semIm = semIm;
    }

    public static SemUpdater serializableInstance() {
        return new SemUpdater(SemIm.serializableInstance());
    }

    public SemEvidence getEvidence() {
        return this.evidence;
    }

    public void setEvidence(SemEvidence evidence) {
        if (evidence == null) {
            throw new NullPointerException();
        }
        this.evidence = evidence;
    }

    public SemIm getSemIm() {
        return this.semIm;
    }

    public SemIm getUpdatedSemIm() {
        int i;
        Algebra algebra = new Algebra();
        SemIm manipulatedSemIm = this.getManipulatedSemIm();
        double[] means = new double[manipulatedSemIm.getVariableNodes().size()];
        for (int i2 = 0; i2 < means.length; ++i2) {
            means[i2] = manipulatedSemIm.getMean(manipulatedSemIm.getVariableNodes().get(i2));
        }
        DenseDoubleMatrix1D mu = new DenseDoubleMatrix1D(means);
        DoubleMatrix2D sigma = manipulatedSemIm.getImplCovar();
        SemEvidence evidence = this.getEvidence();
        List<Node> nodesInEvidence = evidence.getNodesInEvidence();
        int[] x2 = new int[nodesInEvidence.size()];
        DenseDoubleMatrix1D a = new DenseDoubleMatrix1D(nodesInEvidence.size());
        for (i = 0; i < nodesInEvidence.size(); ++i) {
            Node _node = nodesInEvidence.get(i);
            x2[i] = evidence.getNodeIndex(_node);
        }
        for (i = 0; i < nodesInEvidence.size(); ++i) {
            int j = evidence.getNodeIndex(nodesInEvidence.get(i));
            a.set(i, evidence.getProposition().getValue(j));
        }
        int[] x1 = new int[sigma.rows()];
        for (int i3 = 0; i3 < sigma.rows(); ++i3) {
            x1[i3] = i3;
        }
        DoubleMatrix2D sigma12 = sigma.viewSelection(x1, x2);
        DoubleMatrix2D sigma22 = sigma.viewSelection(x2, x2);
        DoubleMatrix2D inv_sigma22 = algebra.inverse(sigma22);
        DoubleMatrix2D temp1 = algebra.mult(sigma12, inv_sigma22);
        DoubleMatrix1D mu1 = mu.viewSelection(x1);
        DoubleMatrix1D mu2 = mu.viewSelection(x2);
        DoubleMatrix1D temp4 = a.copy().assign(mu2, Functions.minus);
        DoubleMatrix1D temp5 = algebra.mult(temp1, temp4);
        DoubleMatrix1D muBar = mu1.copy().assign(temp5, Functions.plus);
        DoubleMatrix2D sigma2 = manipulatedSemIm.getErrCovar();
        System.out.println("Restricted sigma: " + sigma2);
        return manipulatedSemIm.updatedIm(sigma2, muBar);
    }

    public Graph getManipulatedGraph() {
        return this.createManipulatedGraph(this.getSemIm().getSemPm().getGraph());
    }

    public SemIm getManipulatedSemIm() {
        SemGraph graph = this.getSemIm().getSemPm().getGraph();
        SemGraph manipulatedGraph = this.createManipulatedGraph(graph);
        return SemIm.retainValues(this.getSemIm(), manipulatedGraph);
    }

    private SemGraph createManipulatedGraph(Graph graph) {
        SemGraph updatedGraph = new SemGraph(graph);
        for (int i = 0; i < this.evidence.getNumNodes(); ++i) {
            if (!this.evidence.isManipulated(i)) continue;
            Node node = this.evidence.getNode(i);
            List<Node> parents = updatedGraph.getParents(node);
            for (Node parent : parents) {
                if (parent.getNodeType() == NodeType.ERROR) continue;
                updatedGraph.removeEdge(node, parent);
            }
        }
        return updatedGraph;
    }
}

