/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.SemOptimizerEm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.TetradSerializable;
import java.util.ArrayList;
import java.util.List;

public final class GesOptimizationEm
implements TetradSerializable {
    static final long serialVersionUID = 23L;
    private static final double FUNC_TOLERANCE = 1.0E-4;
    private ArrayList<Node> nodes;
    private int numNodes;
    private DoubleMatrix2D sampleCovar;
    private DoubleMatrix2D edgeCoef;
    private DoubleMatrix2D errorCovar;
    private DoubleMatrix2D implCovar;
    private double logDetSample;
    private double[][] expectedCovariance;
    private Graph graph;
    private transient double[][] yCov;
    private transient double[][] yCovModel;
    private transient double[][] yzCovModel;
    private transient double[][] zCovModel;
    private transient int numObserved;
    private transient int numLatent;
    private transient int[] idxLatent;
    private transient int[] idxObserved;
    private transient int[][] parents;
    private transient Node[] errorParent;
    private transient double[][] nodeParentsCov;
    private transient double[][][] parentsCov;
    private int sampleSize;

    public GesOptimizationEm() {
    }

    public static SemOptimizerEm serializableInstance() {
        return new SemOptimizerEm();
    }

    public GesOptimizationEm(List<Node> nodes, DoubleMatrix2D sampleCovar, int sampleSize) {
        this.nodes = new ArrayList<Node>(nodes);
        this.sampleSize = sampleSize;
        this.graph = new EdgeListGraph(this.nodes);
        this.numNodes = this.nodes.size();
        this.sampleCovar = sampleCovar;
        this.edgeCoef = new DenseDoubleMatrix2D(this.graph.getNumNodes(), this.graph.getNumNodes());
        this.errorCovar = DoubleFactory2D.dense.identity(this.graph.getNumNodes());
    }

    public void setGraph(Graph graph) {
        if (!((Object)graph.getNodes()).equals(this.graph.getNodes())) {
            throw new IllegalArgumentException("Nodes of graph must be identical and in the same order.");
        }
        this.graph = graph;
    }

    public void optimize() {
        double score;
        this.edgeCoef = new DenseDoubleMatrix2D(this.graph.getNumNodes(), this.graph.getNumNodes());
        this.errorCovar = DoubleFactory2D.dense.identity(this.graph.getNumNodes());
        this.initialize();
        this.updateMatrices();
        double newScore = this.scoreSemIm();
        do {
            score = newScore;
            this.expectation();
            this.maximization();
            this.updateMatrices();
        } while ((newScore = this.scoreSemIm()) - score > 1.0E-4);
    }

    public void optimizeRegression() {
        List<Node> nodes = this.graph.getNodes();
        Algebra algebra = new Algebra();
        for (Node node : nodes) {
            if (node.getNodeType() != NodeType.MEASURED) continue;
            int idx = nodes.indexOf(node);
            List<Node> parents = this.graph.getParents(node);
            Node errorParent = node;
            for (int i = 0; i < parents.size(); ++i) {
                Node nextParent = parents.get(i);
                if (nextParent.getNodeType() != NodeType.ERROR) continue;
                errorParent = nextParent;
                parents.remove(nextParent);
                break;
            }
            double variance = this.sampleCovar.get(idx, idx);
            if (parents.size() > 0) {
                DenseDoubleMatrix1D nodeParentsCov = new DenseDoubleMatrix1D(parents.size());
                DenseDoubleMatrix2D parentsCov = new DenseDoubleMatrix2D(parents.size(), parents.size());
                for (int i = 0; i < parents.size(); ++i) {
                    int idx2 = nodes.indexOf(parents.get(i));
                    nodeParentsCov.set(i, this.sampleCovar.get(idx, idx2));
                    for (int j = i; j < parents.size(); ++j) {
                        int idx3 = nodes.indexOf(parents.get(j));
                        parentsCov.set(i, j, this.sampleCovar.get(idx2, idx3));
                        parentsCov.set(j, i, this.sampleCovar.get(idx2, idx3));
                    }
                }
                DoubleMatrix1D edges = algebra.mult(algebra.inverse(parentsCov), (DoubleMatrix1D)nodeParentsCov);
                for (int i = 0; i < edges.size(); ++i) {
                    int idx2 = nodes.indexOf(parents.get(i));
                    this.edgeCoef.set(idx2, idx, edges.get(i));
                }
                variance -= algebra.mult(nodeParentsCov, edges);
            }
            this.errorCovar.set(idx, idx, variance);
            TetradLogger.getInstance().log("info", "FML = " + this.getFml());
        }
    }

    public double[][] getExpectedCovarianceMatrix() {
        return this.expectedCovariance;
    }

    private int getDof() {
        return this.nodes.size() * (this.nodes.size() + 1) / 2 - (this.graph.getNumNodes() + this.graph.getNumEdges());
    }

    public double getChiSquare() {
        return (double)(this.sampleSize - 1) * this.getFml();
    }

    private void initialize() {
        int i;
        this.yCov = this.sampleCovar.toArray();
        this.numLatent = 0;
        this.numObserved = 0;
        for (int i2 = 0; i2 < this.graph.getNodes().size(); ++i2) {
            Node node = this.graph.getNodes().get(i2);
            if (node.getNodeType() == NodeType.LATENT) {
                ++this.numLatent;
                continue;
            }
            if (node.getNodeType() == NodeType.MEASURED) {
                ++this.numObserved;
                continue;
            }
            if (node.getNodeType() == NodeType.ERROR) break;
        }
        if (this.numLatent == 0) {
            this.numLatent = 1;
        }
        this.idxLatent = new int[this.numLatent];
        this.idxObserved = new int[this.numObserved];
        int countLatent = 0;
        int countObserved = 0;
        for (i = 0; i < this.graph.getNodes().size(); ++i) {
            Node node = this.graph.getNodes().get(i);
            if (node.getNodeType() == NodeType.LATENT) {
                this.idxLatent[countLatent++] = i;
                continue;
            }
            if (node.getNodeType() == NodeType.MEASURED) {
                this.idxObserved[countObserved++] = i;
                continue;
            }
            if (node.getNodeType() == NodeType.ERROR) break;
        }
        this.expectedCovariance = new double[this.numObserved + this.numLatent][this.numObserved + this.numLatent];
        for (i = 0; i < this.numObserved; ++i) {
            for (int j = i; j < this.numObserved; ++j) {
                double d = this.yCov[i][j];
                this.expectedCovariance[this.idxObserved[j]][this.idxObserved[i]] = d;
                this.expectedCovariance[this.idxObserved[i]][this.idxObserved[j]] = d;
            }
        }
        this.yCovModel = new double[this.numObserved][this.numObserved];
        this.yzCovModel = new double[this.numObserved][this.numLatent];
        this.zCovModel = new double[this.numLatent][this.numLatent];
        this.parents = new int[this.numLatent + this.numObserved][];
        this.errorParent = new Node[this.numLatent + this.numObserved];
        this.nodeParentsCov = new double[this.numLatent + this.numObserved][];
        this.parentsCov = new double[this.numLatent + this.numObserved][][];
        for (Node node : this.graph.getNodes()) {
            int i3;
            if (node.getNodeType() == NodeType.ERROR) break;
            int idx = this.graph.getNodes().indexOf(node);
            List<Node> parents = this.graph.getParents(node);
            this.errorParent[idx] = node;
            for (i3 = 0; i3 < parents.size(); ++i3) {
                Node nextParent = parents.get(i3);
                if (nextParent.getNodeType() != NodeType.ERROR) continue;
                this.errorParent[idx] = nextParent;
                parents.remove(nextParent);
                break;
            }
            if (parents.size() > 0) {
                this.parents[idx] = new int[parents.size()];
                this.nodeParentsCov[idx] = new double[parents.size()];
                this.parentsCov[idx] = new double[parents.size()][parents.size()];
                for (i3 = 0; i3 < parents.size(); ++i3) {
                    this.parents[idx][i3] = this.graph.getNodes().indexOf(parents.get(i3));
                }
                continue;
            }
            this.parents[idx] = null;
        }
    }

    private void expectation() {
        double[][] delta = MatrixUtils.product(MatrixUtils.inverse(this.yCovModel), this.yzCovModel);
        double[][] Delta = MatrixUtils.subtract(this.zCovModel, MatrixUtils.product(MatrixUtils.transpose(this.yzCovModel), delta));
        double[][] yzE = MatrixUtils.product(this.yCov, delta);
        double[][] zzE = MatrixUtils.sum(MatrixUtils.product(MatrixUtils.product(MatrixUtils.transpose(delta), this.yCov), delta), Delta);
        for (int i = 0; i < this.numLatent; ++i) {
            int j;
            for (j = i; j < this.numLatent; ++j) {
                double d = zzE[i][j];
                this.expectedCovariance[this.idxLatent[j]][this.idxLatent[i]] = d;
                this.expectedCovariance[this.idxLatent[i]][this.idxLatent[j]] = d;
            }
            for (j = 0; j < this.numObserved; ++j) {
                double d = yzE[j][i];
                this.expectedCovariance[this.idxObserved[j]][this.idxLatent[i]] = d;
                this.expectedCovariance[this.idxLatent[i]][this.idxObserved[j]] = d;
            }
        }
    }

    private void maximization() {
        for (Node node : this.graph.getNodes()) {
            if (node.getNodeType() == NodeType.ERROR) break;
            int idx = this.graph.getNodes().indexOf(node);
            double variance = this.expectedCovariance[idx][idx];
            if (this.parents[idx] != null) {
                for (int i = 0; i < this.parents[idx].length; ++i) {
                    int idx2 = this.parents[idx][i];
                    this.nodeParentsCov[idx][i] = this.expectedCovariance[idx][idx2];
                    for (int j = i; j < this.parents[idx].length; ++j) {
                        int idx3 = this.parents[idx][j];
                        double d = this.expectedCovariance[idx2][idx3];
                        this.parentsCov[idx][j][i] = d;
                        this.parentsCov[idx][i][j] = d;
                    }
                }
                double[] edges = MatrixUtils.product(MatrixUtils.inverse(this.parentsCov[idx]), this.nodeParentsCov[idx]);
                for (int i = 0; i < edges.length; ++i) {
                    int idx2 = this.parents[idx][i];
                    try {
                        this.edgeCoef.set(idx2, idx, edges[i]);
                        continue;
                    }
                    catch (IllegalArgumentException e) {
                        // empty catch block
                    }
                }
                variance -= MatrixUtils.innerProduct(this.nodeParentsCov[idx], edges);
            }
            try {
                this.errorCovar.set(idx, idx, variance);
            }
            catch (IllegalArgumentException e) {}
        }
    }

    private void updateMatrices() {
        int j;
        int i;
        this.computeImpliedCovar();
        double[][] impliedCovar = this.implCovar.toArray();
        for (i = 0; i < this.numObserved; ++i) {
            for (j = i; j < this.numObserved; ++j) {
                double d = impliedCovar[this.idxObserved[i]][this.idxObserved[j]];
                this.yCovModel[j][i] = d;
                this.yCovModel[i][j] = d;
            }
            for (j = 0; j < this.numLatent; ++j) {
                this.yzCovModel[i][j] = impliedCovar[this.idxObserved[i]][this.idxLatent[j]];
            }
        }
        for (i = 0; i < this.numLatent; ++i) {
            for (j = i; j < this.numLatent; ++j) {
                double d = impliedCovar[this.idxLatent[i]][this.idxLatent[j]];
                this.zCovModel[j][i] = d;
                this.zCovModel[i][j] = d;
            }
        }
    }

    private double scoreSemIm() {
        double score = this.getFml();
        if (Double.isNaN(score)) {
            score = Double.POSITIVE_INFINITY;
        }
        TetradLogger.getInstance().log("info", "FML = " + score);
        return -score;
    }

    public double getFml() {
        try {
            this.computeImpliedCovar();
        }
        catch (Exception e) {
            e.printStackTrace();
            return Double.NaN;
        }
        return this.logDet(this.implCovar) + this.traceSSigmaInv(this.sampleCovar, this.implCovar) - this.logDetSample() - (double)this.numNodes;
    }

    private void computeImpliedCovar() {
        DoubleMatrix2D edgeCoefT = new Algebra().transpose(this.edgeCoef);
        this.implCovar = MatrixUtils.impliedCovarC(edgeCoefT, this.errorCovar);
    }

    private double logDet(DoubleMatrix2D implCovarMeas) {
        return Math.log(MatrixUtils.determinant(implCovarMeas));
    }

    private double traceSSigmaInv(DoubleMatrix2D s, DoubleMatrix2D sigma) {
        DoubleMatrix2D inverse = new Algebra().inverse(sigma);
        DoubleMatrix2D product = new Algebra().mult(s, inverse);
        double v = MatrixUtils.trace(product);
        if (v < 0.0) {
            throw new IllegalArgumentException("Trace was negative.");
        }
        return v;
    }

    private double logDetSample() {
        if (this.logDetSample == 0.0 && this.sampleCovar != null) {
            double det = MatrixUtils.determinant(this.sampleCovar);
            this.logDetSample = Math.log(det);
        }
        return this.logDetSample;
    }
}

