/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleMatrix2D;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.List;

public class SemOptimizerEm
implements SemOptimizer {
    static final long serialVersionUID = 23L;
    private static final double FUNC_TOLERANCE = 1.0E-4;
    private SemIm semIm;
    private double[][] expectedCovariance;
    private transient Graph graph;
    private transient double[][] yCov;
    private transient double[][] yCovModel;
    private transient double[][] yzCovModel;
    private transient double[][] zCovModel;
    private transient int numObserved;
    private transient int numLatent;
    private transient int[] idxLatent;
    private transient int[] idxObserved;
    private transient int[][] parents;
    private transient Node[] errorParent;
    private transient double[][] nodeParentsCov;
    private transient double[][][] parentsCov;

    public static SemOptimizerEm serializableInstance() {
        return new SemOptimizerEm();
    }

    @Override
    public void optimize(SemIm semIm) {
        double score;
        if (semIm == null) {
            throw new IllegalArgumentException();
        }
        this.initialize(semIm);
        this.updateMatrices();
        double newScore = this.scoreSemIm();
        do {
            score = newScore;
            this.expectation();
            this.maximization();
            this.updateMatrices();
        } while ((newScore = this.scoreSemIm()) - score > 1.0E-4);
    }

    public double[][] getExpectedCovarianceMatrix() {
        return this.expectedCovariance;
    }

    private void initialize(SemIm semIm) {
        int i;
        this.semIm = semIm;
        this.graph = semIm.getSemPm().getGraph();
        this.yCov = semIm.getSampleCovar().toArray();
        this.numLatent = 0;
        this.numObserved = 0;
        for (int i2 = 0; i2 < this.graph.getNodes().size(); ++i2) {
            Node node = this.graph.getNodes().get(i2);
            if (node.getNodeType() == NodeType.LATENT) {
                ++this.numLatent;
                continue;
            }
            if (node.getNodeType() == NodeType.MEASURED) {
                ++this.numObserved;
                continue;
            }
            if (node.getNodeType() == NodeType.ERROR) break;
        }
        if (this.numLatent == 0) {
            this.numLatent = 1;
        }
        this.idxLatent = new int[this.numLatent];
        this.idxObserved = new int[this.numObserved];
        int countLatent = 0;
        int countObserved = 0;
        for (i = 0; i < this.graph.getNodes().size(); ++i) {
            Node node = this.graph.getNodes().get(i);
            if (node.getNodeType() == NodeType.LATENT) {
                this.idxLatent[countLatent++] = i;
                continue;
            }
            if (node.getNodeType() == NodeType.MEASURED) {
                this.idxObserved[countObserved++] = i;
                continue;
            }
            if (node.getNodeType() == NodeType.ERROR) break;
        }
        this.expectedCovariance = new double[this.numObserved + this.numLatent][this.numObserved + this.numLatent];
        for (i = 0; i < this.numObserved; ++i) {
            for (int j = i; j < this.numObserved; ++j) {
                double d = this.yCov[i][j];
                this.expectedCovariance[this.idxObserved[j]][this.idxObserved[i]] = d;
                this.expectedCovariance[this.idxObserved[i]][this.idxObserved[j]] = d;
            }
        }
        this.yCovModel = new double[this.numObserved][this.numObserved];
        this.yzCovModel = new double[this.numObserved][this.numLatent];
        this.zCovModel = new double[this.numLatent][this.numLatent];
        this.parents = new int[this.numLatent + this.numObserved][];
        this.errorParent = new Node[this.numLatent + this.numObserved];
        this.nodeParentsCov = new double[this.numLatent + this.numObserved][];
        this.parentsCov = new double[this.numLatent + this.numObserved][][];
        for (Node node : this.graph.getNodes()) {
            int i3;
            if (node.getNodeType() == NodeType.ERROR) break;
            int idx = this.graph.getNodes().indexOf(node);
            List<Node> parents = this.graph.getParents(node);
            this.errorParent[idx] = node;
            for (i3 = 0; i3 < parents.size(); ++i3) {
                Node nextParent = parents.get(i3);
                if (nextParent.getNodeType() != NodeType.ERROR) continue;
                this.errorParent[idx] = nextParent;
                parents.remove(nextParent);
                break;
            }
            if (parents.size() > 0) {
                this.parents[idx] = new int[parents.size()];
                this.nodeParentsCov[idx] = new double[parents.size()];
                this.parentsCov[idx] = new double[parents.size()][parents.size()];
                for (i3 = 0; i3 < parents.size(); ++i3) {
                    this.parents[idx][i3] = this.graph.getNodes().indexOf(parents.get(i3));
                }
                continue;
            }
            this.parents[idx] = null;
        }
    }

    private void expectation() {
        double[][] delta = MatrixUtils.product(MatrixUtils.inverse(this.yCovModel), this.yzCovModel);
        double[][] Delta = MatrixUtils.subtract(this.zCovModel, MatrixUtils.product(MatrixUtils.transpose(this.yzCovModel), delta));
        double[][] yzE = MatrixUtils.product(this.yCov, delta);
        double[][] zzE = MatrixUtils.sum(MatrixUtils.product(MatrixUtils.product(MatrixUtils.transpose(delta), this.yCov), delta), Delta);
        for (int i = 0; i < this.numLatent; ++i) {
            int j;
            for (j = i; j < this.numLatent; ++j) {
                double d = zzE[i][j];
                this.expectedCovariance[this.idxLatent[j]][this.idxLatent[i]] = d;
                this.expectedCovariance[this.idxLatent[i]][this.idxLatent[j]] = d;
            }
            for (j = 0; j < this.numObserved; ++j) {
                double d = yzE[j][i];
                this.expectedCovariance[this.idxObserved[j]][this.idxLatent[i]] = d;
                this.expectedCovariance[this.idxLatent[i]][this.idxObserved[j]] = d;
            }
        }
    }

    private void maximization() {
        for (Node node : this.graph.getNodes()) {
            if (node.getNodeType() == NodeType.ERROR) break;
            int idx = this.graph.getNodes().indexOf(node);
            double variance = this.expectedCovariance[idx][idx];
            if (this.parents[idx] != null) {
                for (int i = 0; i < this.parents[idx].length; ++i) {
                    int idx2 = this.parents[idx][i];
                    this.nodeParentsCov[idx][i] = this.expectedCovariance[idx][idx2];
                    for (int j = i; j < this.parents[idx].length; ++j) {
                        int idx3 = this.parents[idx][j];
                        double d = this.expectedCovariance[idx2][idx3];
                        this.parentsCov[idx][j][i] = d;
                        this.parentsCov[idx][i][j] = d;
                    }
                }
                double[] edges = MatrixUtils.product(MatrixUtils.inverse(this.parentsCov[idx]), this.nodeParentsCov[idx]);
                for (int i = 0; i < edges.length; ++i) {
                    int idx2 = this.parents[idx][i];
                    try {
                        this.semIm.setParamValue(this.graph.getNodes().get(idx2), node, edges[i]);
                        continue;
                    }
                    catch (IllegalArgumentException e) {
                        // empty catch block
                    }
                }
                variance -= MatrixUtils.innerProduct(this.nodeParentsCov[idx], edges);
            }
            try {
                this.semIm.setParamValue(this.errorParent[idx], this.errorParent[idx], variance);
            }
            catch (IllegalArgumentException e) {}
        }
    }

    private void updateMatrices() {
        int j;
        int i;
        DoubleMatrix2D implCovarC = this.semIm.getImplCovar();
        double[][] impliedCovar = implCovarC.toArray();
        for (i = 0; i < this.numObserved; ++i) {
            for (j = i; j < this.numObserved; ++j) {
                double d = impliedCovar[this.idxObserved[i]][this.idxObserved[j]];
                this.yCovModel[j][i] = d;
                this.yCovModel[i][j] = d;
            }
            for (j = 0; j < this.numLatent; ++j) {
                this.yzCovModel[i][j] = impliedCovar[this.idxObserved[i]][this.idxLatent[j]];
            }
        }
        for (i = 0; i < this.numLatent; ++i) {
            for (j = i; j < this.numLatent; ++j) {
                double d = impliedCovar[this.idxLatent[i]][this.idxLatent[j]];
                this.zCovModel[j][i] = d;
                this.zCovModel[i][j] = d;
            }
        }
    }

    private double scoreSemIm() {
        double score = this.semIm.getFml();
        if (Double.isNaN(score)) {
            score = Double.POSITIVE_INFINITY;
        }
        TetradLogger.getInstance().log("optimization", "FML = " + score);
        return -score;
    }
}

