/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.List;

public class SemOptimizerRegression
implements SemOptimizer {
    static final long serialVersionUID = 23L;

    public static SemOptimizerRegression serializableInstance() {
        return new SemOptimizerRegression();
    }

    @Override
    public void optimize(SemIm semIm) {
        DoubleMatrix2D covar = semIm.getSampleCovar();
        SemGraph graph = semIm.getSemPm().getGraph();
        List<Node> nodes = graph.getNodes();
        Algebra algebra = new Algebra();
        for (Node node : nodes) {
            if (node.getNodeType() != NodeType.MEASURED) continue;
            int idx = nodes.indexOf(node);
            List<Node> parents = graph.getParents(node);
            Node errorParent = node;
            for (int i = 0; i < parents.size(); ++i) {
                Node nextParent = parents.get(i);
                if (nextParent.getNodeType() != NodeType.ERROR) continue;
                errorParent = nextParent;
                parents.remove(nextParent);
                break;
            }
            double variance = covar.get(idx, idx);
            if (parents.size() > 0) {
                DenseDoubleMatrix1D nodeParentsCov = new DenseDoubleMatrix1D(parents.size());
                DenseDoubleMatrix2D parentsCov = new DenseDoubleMatrix2D(parents.size(), parents.size());
                for (int i = 0; i < parents.size(); ++i) {
                    int idx2 = nodes.indexOf(parents.get(i));
                    nodeParentsCov.set(i, covar.get(idx, idx2));
                    for (int j = i; j < parents.size(); ++j) {
                        int idx3 = nodes.indexOf(parents.get(j));
                        parentsCov.set(i, j, covar.get(idx2, idx3));
                        parentsCov.set(j, i, covar.get(idx2, idx3));
                    }
                }
                DoubleMatrix1D edges = algebra.mult(algebra.inverse(parentsCov), (DoubleMatrix1D)nodeParentsCov);
                for (int i = 0; i < edges.size(); ++i) {
                    int idx2 = nodes.indexOf(parents.get(i));
                    semIm.setParamValue(nodes.get(idx2), node, edges.get(i));
                }
                variance -= algebra.mult(nodeParentsCov, edges);
            }
            semIm.setParamValue(errorParent, errorParent, variance);
            TetradLogger.getInstance().log("optimization", "FML = " + semIm.getFml());
        }
    }
}

