/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.List;

/**
 * Optimizes the free parameters of a {@link SemIm} with the expectation-maximization (EM)
 * algorithm, treating the latent variables as missing data.
 *
 * <p>Each restart copies the model, randomizes its free parameters and runs EM until the score
 * stops improving by more than {@link #FUNC_TOLERANCE}. The copy with the lowest chi-square wins,
 * provided it beats the model's current fit, and its values are written back into the caller's
 * model.
 *
 * <p>The E-step and M-step work on the mutable fields below, which are rebuilt by
 * {@link #initialize(SemIm)} for every restart. A single instance should therefore not be used
 * from several threads at once.
 */
public class SemOptimizerEm implements SemOptimizer {
    private static final long serialVersionUID = 23L;

    /** EM stops once an iteration improves the score by no more than this amount. */
    private static final double FUNC_TOLERANCE = 1.0E-6;

    /** The model currently being fitted (one restart's copy). */
    private SemIm semIm;
    /** Graph of {@link #semIm}, with error terms shown while EM runs. */
    private SemGraph graph;
    /** Sample covariance of the measured variables. */
    private Matrix yCov;
    /** Model-implied covariance among measured variables. */
    private Matrix yCovModel;
    /** Model-implied covariance between measured (rows) and latent (columns) variables. */
    private Matrix yzCovModel;
    /** Model-implied covariance among latent variables. */
    private Matrix zCovModel;
    /**
     * Expected sufficient statistics (covariances) over measured and latent nodes.
     * Rows and columns are positions in {@code graph.getNodes()}.
     */
    private Matrix expectedCov;
    private int numObserved;
    private int numLatent;
    /** Positions in {@code graph.getNodes()} of the latent nodes, in list order. */
    private int[] idxLatent;
    /** Positions in {@code graph.getNodes()} of the measured nodes, in list order. */
    private int[] idxObserved;
    /** For each node position: positions of its non-error parents, or null if it has none. */
    private int[][] parents;
    /** For each node position: its error-term parent, if any. */
    private Node[] errorParent;
    /** Scratch space: covariance of each node with each of its non-error parents. */
    private double[][] nodeParentsCov;
    /** Scratch space: covariance matrix among each node's non-error parents. */
    private double[][][] parentsCov;
    private int numRestarts = 1;

    /**
     * Returns an instance suitable for serialization tests.
     *
     * @return a new optimizer with default settings
     */
    public static SemOptimizerEm serializableInstance() {
        return new SemOptimizerEm();
    }

    /**
     * Fits the free parameters of {@code semIm} by EM with random restarts.
     *
     * <p>If the restart count is below 1, it is raised to 1 (and stored). Each trial copies
     * {@code semIm} and draws random starting values: variances between 0 and 3, all other
     * parameters between -2 and 2. It then runs EM to convergence. The trial with the strictly
     * lowest chi-square wins, provided it also beats the chi-square of {@code semIm}'s current
     * values. The winning values are written back into {@code semIm}.
     *
     * @param semIm the model to fit; its parameter values are updated in place
     * @throws NullPointerException     if the sample covariance has not been set
     * @throws IllegalArgumentException if the sample covariance has missing values, or the
     *                                  model has no latent variables
     * @throws RuntimeException         if an optimized value cannot be copied back
     */
    @Override
    public void optimize(SemIm semIm) {
        if (this.numRestarts < 1) {
            this.numRestarts = 1;
        }

        Matrix sampleCovar = semIm.getSampleCovar();
        if (sampleCovar == null) {
            throw new NullPointerException("Sample covar has not been set.");
        }
        if (DataUtils.containsMissingValue(sampleCovar)) {
            throw new IllegalArgumentException("Please remove or impute missing values.");
        }

        // The caller's current values are the baseline; a trial replaces them only if it
        // reaches a strictly lower chi-square (NaN never wins).
        double min = semIm.getChiSquare();
        SemIm best = semIm;

        for (int count = 0; count < this.numRestarts; ++count) {
            TetradLogger.getInstance().log("details", "Trial " + (count + 1));

            SemIm candidate = new SemIm(semIm);
            randomizeFreeParameters(candidate);
            optimize2(candidate);

            double chisq = candidate.getChiSquare();
            TetradLogger.getInstance().log("details", "chisq = " + chisq);
            if (chisq < min) {
                min = chisq;
                best = candidate;
            }
        }

        copyFreeParameterValues(best, semIm);
    }

    /**
     * Gives each free parameter of {@code sem} a random starting value: variances are drawn
     * between 0 and 3, all other parameter types between -2 and 2.
     */
    private static void randomizeFreeParameters(SemIm sem) {
        List<Parameter> freeParameters = sem.getFreeParameters();
        double[] values = new double[freeParameters.size()];
        for (int i = 0; i < values.length; ++i) {
            values[i] = freeParameters.get(i).getType() == ParamType.VAR
                    ? RandomUtil.getInstance().nextUniform(0.0, 3.0)
                    : RandomUtil.getInstance().nextUniform(-2.0, 2.0);
        }
        sem.setFreeParamValues(values);
    }

    /**
     * Copies the value of every free parameter of {@code target} from the parameter with the
     * same endpoint node names in {@code source}.
     */
    private static void copyFreeParameterValues(SemIm source, SemIm target) {
        for (Parameter param : target.getFreeParameters()) {
            try {
                Node nodeA = source.getVariableNode(param.getNodeA().getName());
                Node nodeB = source.getVariableNode(param.getNodeB().getName());
                target.setParamValue(param, source.getParamValue(nodeA, nodeB));
            } catch (Exception e) {
                throw new RuntimeException(
                        "Could not copy optimized value for parameter " + param, e);
            }
        }
    }

    /**
     * Runs EM on {@code semIm} from its current parameter values until the score improves by
     * no more than {@link #FUNC_TOLERANCE}.
     *
     * <p>Error terms are shown on the model's graph while EM runs, because the M-step reads and
     * writes error variances through the error nodes. The caller's original setting is always
     * restored, even if EM throws.
     */
    private void optimize2(SemIm semIm) {
        SemGraph semGraph = semIm.getSemPm().getGraph();
        boolean showErrors = semGraph.isShowErrorTerms();
        semGraph.setShowErrorTerms(true);
        try {
            initialize(semIm);
            updateMatrices();

            double score;
            double newScore = scoreSemIm();
            do {
                score = newScore;
                expectation();
                maximization();
                updateMatrices();
                newScore = scoreSemIm();
            } while (newScore > score + FUNC_TOLERANCE);
        } finally {
            semGraph.setShowErrorTerms(showErrors);
        }
    }

    /**
     * Returns the number of random restarts.
     *
     * @return the number of random restarts
     */
    @Override
    public int getNumRestarts() {
        return this.numRestarts;
    }

    /**
     * Sets the number of random restarts. Values below 1 are raised to 1 when
     * {@link #optimize(SemIm)} runs.
     *
     * @param numRestarts the number of random restarts
     */
    @Override
    public void setNumRestarts(int numRestarts) {
        this.numRestarts = numRestarts;
    }

    /**
     * Builds the working state for one EM run on {@code semIm}.
     *
     * <p>This method:
     * <ul>
     *   <li>indexes the latent and measured nodes;</li>
     *   <li>seeds {@link #expectedCov} with the sample covariance;</li>
     *   <li>records, for each non-error node, its error-term parent and its other parents.</li>
     * </ul>
     *
     * <p>NOTE(review): the per-node arrays and {@link #expectedCov} are sized
     * {@code numLatent + numObserved} but indexed by position in {@code graph.getNodes()}, which
     * also contains error nodes here. This assumes all non-error nodes precede the error nodes in
     * that list. It also assumes the sample covariance's rows/columns follow the order of the
     * measured nodes in that list. TODO confirm both against SemGraph / SemIm.
     *
     * @throws NullPointerException     if the sample covariance has not been set
     * @throws IllegalArgumentException if the graph has no latent nodes
     */
    private void initialize(SemIm semIm) {
        this.semIm = semIm;
        this.graph = semIm.getSemPm().getGraph();
        this.yCov = semIm.getSampleCovar();
        if (this.yCov == null) {
            throw new NullPointerException("Sample covar has not been set.");
        }

        List<Node> nodes = this.graph.getNodes();

        this.numObserved = 0;
        this.numLatent = 0;
        for (Node node : nodes) {
            if (node.getNodeType() == NodeType.LATENT) {
                ++this.numLatent;
            } else if (node.getNodeType() == NodeType.MEASURED) {
                ++this.numObserved;
            }
        }
        if (this.numLatent == 0) {
            throw new IllegalArgumentException("Need at least one latent for the EM estimator.");
        }

        this.idxLatent = new int[this.numLatent];
        this.idxObserved = new int[this.numObserved];
        int countLatent = 0;
        int countObserved = 0;
        for (int i = 0; i < nodes.size(); ++i) {
            NodeType type = nodes.get(i).getNodeType();
            if (type == NodeType.LATENT) {
                this.idxLatent[countLatent++] = i;
            } else if (type == NodeType.MEASURED) {
                this.idxObserved[countObserved++] = i;
            }
        }

        int numVariables = this.numObserved + this.numLatent;

        // The observed block of the expected covariance is fixed to the sample covariance; only
        // the latent rows/columns are re-estimated by the E-step.
        this.expectedCov = new Matrix(numVariables, numVariables);
        for (int i = 0; i < this.numObserved; ++i) {
            for (int j = i; j < this.numObserved; ++j) {
                double c = this.yCov.get(i, j);
                this.expectedCov.set(this.idxObserved[i], this.idxObserved[j], c);
                this.expectedCov.set(this.idxObserved[j], this.idxObserved[i], c);
            }
        }

        this.yCovModel = new Matrix(this.numObserved, this.numObserved);
        this.yzCovModel = new Matrix(this.numObserved, this.numLatent);
        this.zCovModel = new Matrix(this.numLatent, this.numLatent);
        this.parents = new int[numVariables][];
        this.errorParent = new Node[numVariables];
        this.nodeParentsCov = new double[numVariables][];
        this.parentsCov = new double[numVariables][][];

        for (int idx = 0; idx < nodes.size(); ++idx) {
            Node node = nodes.get(idx);
            if (node.getNodeType() == NodeType.ERROR) {
                continue;
            }

            // Separate the first error-term parent from the structural parents.
            List<Node> nodeParents = new ArrayList<>(this.graph.getParents(node));
            for (int k = 0; k < nodeParents.size(); ++k) {
                if (nodeParents.get(k).getNodeType() == NodeType.ERROR) {
                    this.errorParent[idx] = nodeParents.remove(k);
                    break;
                }
            }

            if (nodeParents.isEmpty()) {
                this.parents[idx] = null;
                continue;
            }

            int numParents = nodeParents.size();
            this.parents[idx] = new int[numParents];
            this.nodeParentsCov[idx] = new double[numParents];
            this.parentsCov[idx] = new double[numParents][numParents];
            for (int k = 0; k < numParents; ++k) {
                this.parents[idx][k] = nodes.indexOf(nodeParents.get(k));
            }
        }
    }

    @Override
    public String toString() {
        return "Sem Optimizer EM";
    }

    /**
     * E-step: updates the latent blocks of {@link #expectedCov} from the current model-implied
     * covariances and the sample covariance.
     *
     * <p>With B = Cov(Y)^-1 Cov(Y,Z) computed under the model:
     * <ul>
     *   <li>the observed/latent block becomes S_yy B;</li>
     *   <li>the latent block becomes B' S_yy B + (Cov(Z) - Cov(Y,Z)' B).</li>
     * </ul>
     */
    private void expectation() {
        Matrix regressionZonY = this.yCovModel.inverse().times(this.yzCovModel);
        Matrix yzCovPred = this.yCov.times(regressionZonY);
        Matrix zCovExplained = this.yzCovModel.transpose().times(regressionZonY);
        Matrix zCovResidual = this.zCovModel.minus(zCovExplained);
        Matrix zCovPred = yzCovPred.transpose().times(regressionZonY);
        Matrix newZCov = zCovPred.plus(zCovResidual);

        for (int i = 0; i < this.numLatent; ++i) {
            for (int j = i; j < this.numLatent; ++j) {
                this.expectedCov.set(this.idxLatent[i], this.idxLatent[j], newZCov.get(i, j));
                this.expectedCov.set(this.idxLatent[j], this.idxLatent[i], newZCov.get(j, i));
            }
        }

        for (int i = 0; i < this.numLatent; ++i) {
            for (int j = 0; j < this.numObserved; ++j) {
                double v = yzCovPred.get(j, i);
                this.expectedCov.set(this.idxLatent[i], this.idxObserved[j], v);
                this.expectedCov.set(this.idxObserved[j], this.idxLatent[i], v);
            }
        }
    }

    /**
     * M-step: for each non-error node, regresses it on its structural parents using
     * {@link #expectedCov}.
     *
     * <p>The node's free edge coefficients are set to the regression coefficients. Unless it is
     * fixed, its error variance is set to the residual variance.
     *
     * <p>NOTE(review): the residual variance always uses the unconstrained regression
     * coefficients, even when some of the node's edge coefficients are fixed and were not
     * updated.
     */
    private void maximization() {
        List<Node> nodes = this.graph.getNodes();
        for (int idx = 0; idx < nodes.size(); ++idx) {
            Node node = nodes.get(idx);
            if (node.getNodeType() == NodeType.ERROR) {
                continue;
            }

            double variance = this.expectedCov.get(idx, idx);
            int[] nodeParents = this.parents[idx];

            if (nodeParents != null) {
                for (int i = 0; i < nodeParents.length; ++i) {
                    int pi = nodeParents[i];
                    this.nodeParentsCov[idx][i] = this.expectedCov.get(idx, pi);
                    for (int j = i; j < nodeParents.length; ++j) {
                        int pj = nodeParents[j];
                        this.parentsCov[idx][i][j] = this.expectedCov.get(pi, pj);
                        this.parentsCov[idx][j][i] = this.expectedCov.get(pj, pi);
                    }
                }

                Vector covWithParents = new Vector(this.nodeParentsCov[idx]);
                Vector coefs = new Matrix(this.parentsCov[idx]).inverse().times(covWithParents);

                for (int i = 0; i < coefs.size(); ++i) {
                    Node parent = nodes.get(nodeParents[i]);
                    if (!this.semIm.getSemPm().getParameter(parent, node).isFixed()) {
                        this.semIm.setEdgeCoef(parent, node, coefs.get(i));
                    }
                }

                variance -= covWithParents.dotProduct(coefs);
            }

            Node error = this.errorParent[idx];
            if (!this.semIm.getSemPm().getParameter(error, error).isFixed()) {
                this.semIm.setErrCovar(error, variance);
            }
        }
    }

    /**
     * Refreshes {@link #yCovModel}, {@link #yzCovModel} and {@link #zCovModel} from the current
     * model-implied covariance.
     *
     * <p>NOTE(review): the implied covariance is indexed by position in
     * {@code graph.getNodes()}. This assumes its rows/columns follow that node order. TODO
     * confirm.
     */
    private void updateMatrices() {
        Matrix impliedCovar = this.semIm.getImplCovar(true);

        for (int i = 0; i < this.numObserved; ++i) {
            int yi = this.idxObserved[i];
            for (int j = i; j < this.numObserved; ++j) {
                double c = impliedCovar.get(yi, this.idxObserved[j]);
                this.yCovModel.set(i, j, c);
                this.yCovModel.set(j, i, c);
            }
            for (int j = 0; j < this.numLatent; ++j) {
                this.yzCovModel.set(i, j, impliedCovar.get(yi, this.idxLatent[j]));
            }
        }

        for (int i = 0; i < this.numLatent; ++i) {
            for (int j = i; j < this.numLatent; ++j) {
                double c = impliedCovar.get(this.idxLatent[i], this.idxLatent[j]);
                this.zCovModel.set(i, j, c);
                this.zCovModel.set(j, i, c);
            }
        }
    }

    /**
     * Returns the negated {@link SemIm#getScore()} of the model being fitted. EM iterates while
     * this value increases by more than {@link #FUNC_TOLERANCE}.
     */
    private double scoreSemIm() {
        return -this.semIm.getScore();
    }
}

