/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.list.DoubleArrayList;
import cern.jet.stat.Descriptive;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.sem.SemOptimizerEm;
import edu.cmu.tetrad.sem.SemOptimizerPalCds;
import edu.cmu.tetrad.sem.SemOptimizerRegression;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.TetradSerializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.List;
import java.util.Objects;

/**
 * Estimates the free parameters of a SEM parameterized model ({@link SemPm}) from either a data
 * set or a covariance matrix, producing an estimated {@link SemIm}.
 *
 * <p>If no {@link SemOptimizer} is supplied and the model's graph has more than one node,
 * {@link #estimate()} chooses one:
 * <ul>
 *   <li>{@link SemOptimizerPalCds} when the model has fixed parameters, a directed cycle, or
 *       covariance parameters;</li>
 *   <li>otherwise {@link SemOptimizerEm} when the graph contains latent nodes;</li>
 *   <li>otherwise {@link SemOptimizerRegression}.</li>
 * </ul>
 */
public final class SemEstimator
implements TetradSerializable {
    static final long serialVersionUID = 23L;

    /** The parameterized model whose free parameters are estimated. Never null. */
    private SemPm semPm;

    /** Covariance matrix restricted to the measured variables of {@link #semPm}. Never null. */
    private CovarianceMatrix covMatrix;

    /**
     * Optimizer supplied at construction. If none was supplied, this is set to the optimizer
     * chosen by the default strategy once it runs. May be null.
     */
    private SemOptimizer semOptimizer;

    /** Result of the most recent {@link #estimate()} call; null before the first estimate. */
    private SemIm estimatedSem;

    /**
     * Data set restricted to the measured variables of {@link #semPm}. This is null when the
     * estimator was built from a covariance matrix.
     */
    private DataSet dataSet;

    /** Optional reference model set via {@link #setTrueSemIm(SemIm)}; null unless set. */
    private SemIm trueSemIm;

    /** Set via {@link #setCheckPositiveDefinite(boolean)}; not read anywhere in this class. */
    private boolean checkPositiveDefinite;

    /**
     * Creates an estimator for the given data set and model using the default optimizer choice.
     *
     * @param dataSet the data; must contain every measured variable of {@code semPm}.
     * @param semPm   the model to estimate.
     */
    public SemEstimator(DataSet dataSet, SemPm semPm) {
        this(dataSet, semPm, null);
    }

    /**
     * Creates an estimator for the given covariance matrix and model using the default
     * optimizer choice.
     *
     * @param covMatrix the covariance matrix; must contain every measured variable of
     *                  {@code semPm}.
     * @param semPm     the model to estimate.
     */
    public SemEstimator(CovarianceMatrix covMatrix, SemPm semPm) {
        this(covMatrix, semPm, null);
    }

    /**
     * Creates an estimator for the given data set and model.
     *
     * @param dataSet      the data; must contain every measured variable of {@code semPm}.
     * @param semPm        the model to estimate.
     * @param semOptimizer the optimizer to use, or null to choose one by default.
     * @throws NullPointerException if {@code dataSet} or {@code semPm} is null.
     */
    public SemEstimator(DataSet dataSet, SemPm semPm, SemOptimizer semOptimizer) {
        // The null check sits inside the this(...) argument because no statement may precede an
        // explicit constructor call. A null data set is therefore reported with a clear message.
        this(new CovarianceMatrix(Objects.requireNonNull(dataSet, "DataSet must not be null.")),
                semPm, semOptimizer);
        this.setDataSet(this.subset(dataSet, semPm));
    }

    /**
     * Creates an estimator for the given covariance matrix and model.
     *
     * @param covMatrix    the covariance matrix; must contain every measured variable of
     *                     {@code semPm}.
     * @param semPm        the model to estimate.
     * @param semOptimizer the optimizer to use, or null to choose one by default.
     * @throws NullPointerException if {@code covMatrix} or {@code semPm} is null.
     * @throws RuntimeException     if the covariance matrix cannot be restricted to the model's
     *                              measured variables.
     */
    public SemEstimator(CovarianceMatrix covMatrix, SemPm semPm, SemOptimizer semOptimizer) {
        if (covMatrix == null) {
            throw new NullPointerException("CovarianceMatrix must not be null.");
        }
        if (semPm == null) {
            throw new NullPointerException("SemPm must not be null.");
        }
        this.setCovMatrix(this.submatrix(covMatrix, semPm));
        this.setSemPm(semPm);
        this.setSemOptimizer(semOptimizer);
    }

    /**
     * Returns an instance built from the serializable instances of its components, for
     * serialization tests.
     *
     * @return a simple example estimator.
     */
    public static SemEstimator serializableInstance() {
        return new SemEstimator(CovarianceMatrix.serializableInstance(), SemPm.serializableInstance());
    }

    /**
     * Estimates the free parameters of the model and returns the resulting estimated SemIm.
     *
     * <p>If the graph has exactly one node, the free parameter values are set directly to the
     * (0, 0) entry of the covariance matrix. Otherwise the configured optimizer is run, or a
     * default one is chosen when none was configured. Means are then taken from the data set if
     * there is one; otherwise every variable's mean is set to zero.
     *
     * @return the estimated SemIm, which is also available afterwards via
     *         {@link #getEstimatedSem()}.
     */
    public SemIm estimate() {
        this.setEstimatedSem(null);
        SemIm semIm = new SemIm(this.getSemPm(), this.getCovMatrix());
        GraphUtils.arrangeBySourceGraph(semIm.getSemPm().getGraph(), this.getSemPm().getGraph());
        // Parameter bounds are relaxed while optimizing and re-enabled before publishing.
        semIm.setParameterBoundsEnforced(false);
        if (semIm.getSemPm().getGraph().getNumNodes() == 1) {
            double[] params = new double[]{this.getCovMatrix().getValue(0, 0)};
            semIm.setFreeParamValues(params);
        } else if (this.getSemOptimizer() == null) {
            this.doDefaultOptimization(semIm);
        } else {
            this.getSemOptimizer().optimize(semIm);
        }
        semIm.setParameterBoundsEnforced(true);
        this.setMeans(semIm, this.getDataSet());
        semIm.setEstimated(true);
        this.setEstimatedSem(semIm);
        return this.estimatedSem;
    }

    private void setCovMatrix(CovarianceMatrix covMatrix) {
        this.covMatrix = covMatrix;
    }

    /** @return the result of the last {@link #estimate()} call, or null if not yet estimated. */
    public SemIm getEstimatedSem() {
        return this.estimatedSem;
    }

    /**
     * @return the data set restricted to the model's measured variables, or null if this
     *         estimator was built from a covariance matrix.
     */
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /** @return the model being estimated. */
    public SemPm getSemPm() {
        return this.semPm;
    }

    /** @return the covariance matrix restricted to the model's measured variables. */
    public CovarianceMatrix getCovMatrix() {
        return this.covMatrix;
    }

    /**
     * @return the optimizer supplied at construction, or the one chosen by the default strategy
     *         after it has run; may be null.
     */
    public SemOptimizer getSemOptimizer() {
        return this.semOptimizer;
    }

    /** @return the reference model stored by {@link #setTrueSemIm(SemIm)}, or null. */
    public SemIm getTrueSemIm() {
        return this.trueSemIm;
    }

    /**
     * Stores a new SemIm built from {@code semIm} as the reference ("true") model. Its
     * covariance matrix is replaced with this estimator's covariance matrix.
     *
     * @param semIm the reference model.
     */
    public void setTrueSemIm(SemIm semIm) {
        this.trueSemIm = new SemIm(semIm);
        this.trueSemIm.setCovMatrix(this.getCovMatrix());
    }

    /**
     * Returns a multi-line summary of the estimated model, or a notice if it has not been
     * estimated yet.
     */
    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        buf.append("\nSemEstimator");
        if (this.getEstimatedSem() == null) {
            buf.append("\n\t...SemIm has not been estimated yet.");
        } else {
            SemIm sem = this.getEstimatedSem();
            // NOTE(review): no value follows this label. It is kept as-is to preserve the output;
            // confirm whether a fit statistic was meant to be appended here.
            buf.append("\n\n\tfml = ");
            buf.append("\n\n\tmeasuredNodes:\n");
            buf.append("\t").append(sem.getMeasuredNodes());
            buf.append("\n\n\tedgeCoef:\n");
            buf.append(MatrixUtils.toString(sem.getEdgeCoef().toArray()));
            buf.append("\n\n\terrCovar:\n");
            buf.append(MatrixUtils.toString(sem.getErrCovar().toArray()));
        }
        return buf.toString();
    }

    /**
     * Chooses an optimizer from the model's structure, runs it on {@code semIm}, and remembers
     * it as this estimator's optimizer.
     */
    private void doDefaultOptimization(SemIm semIm) {
        boolean containsLatent = false;
        for (Node node : this.getSemPm().getGraph().getNodes()) {
            if (node.getNodeType() == NodeType.LATENT) {
                containsLatent = true;
                break;
            }
        }
        SemOptimizer optimizer;
        if (this.containsFixedParam()
                || this.getSemPm().getGraph().existsDirectedCycle()
                || SemEstimator.containsCovarParam(this.getSemPm())) {
            optimizer = new SemOptimizerPalCds();
        } else if (containsLatent) {
            optimizer = new SemOptimizerEm();
        } else {
            optimizer = new SemOptimizerRegression();
        }
        optimizer.optimize(semIm);
        this.semOptimizer = optimizer;
    }

    /** @return true if a SemIm freshly built from the model reports any fixed parameters. */
    private boolean containsFixedParam() {
        return new SemIm(this.getSemPm()).getNumFixedParams() > 0;
    }

    /**
     * Restricts {@code covMatrix} to the measured variables of {@code semPm}, in that order.
     *
     * @throws RuntimeException if the covariance matrix rejects the variable list, for example
     *                          because a measured variable of the model is missing from it.
     */
    private CovarianceMatrix submatrix(CovarianceMatrix covMatrix, SemPm semPm) {
        String[] measuredVarNames = semPm.getMeasuredVarNames();
        try {
            return covMatrix.getSubmatrix(measuredVarNames);
        }
        catch (IllegalArgumentException e) {
            // The lookup is driven by the model's names, so the failure means a model variable is
            // absent from the data. Extra data variables are simply dropped.
            throw new RuntimeException("All of the measured variables in the SEM parameterized model must be in the data set.", e);
        }
    }

    /**
     * Returns the columns of {@code dataSet} for the measured variables of {@code semPm}, in
     * the model's order.
     */
    private DataSet subset(DataSet dataSet, SemPm semPm) {
        String[] measuredVarNames = semPm.getMeasuredVarNames();
        int[] varIndices = new int[measuredVarNames.length];
        for (int i = 0; i < measuredVarNames.length; ++i) {
            Node variable = dataSet.getVariable(measuredVarNames[i]);
            varIndices[i] = dataSet.getVariables().indexOf(variable);
        }
        return dataSet.subsetColumns(varIndices);
    }

    /** @return true if any parameter of {@code semPm} has type {@link ParamType#COVAR}. */
    private static boolean containsCovarParam(SemPm semPm) {
        List<Parameter> params = semPm.getParameters();
        for (Parameter param : params) {
            if (param.getType() == ParamType.COVAR) {
                return true;
            }
        }
        return false;
    }

    /**
     * Sets variable means on {@code semIm}.
     *
     * <p>With a data set, each column's mean is used, and the column's standard deviation is
     * stored as the mean's standard deviation. Without a data set, every covariance-matrix
     * variable gets a mean of zero.
     */
    private void setMeans(SemIm semIm, DataSet dataSet) {
        if (dataSet != null) {
            int numColumns = dataSet.getNumColumns();
            for (int j = 0; j < numColumns; ++j) {
                double[] column = dataSet.getDoubleData().viewColumn(j).toArray();
                DoubleArrayList list = new DoubleArrayList(column);
                double mean = Descriptive.mean(list);
                Node node = dataSet.getVariable(j);
                Node variableNode = semIm.getVariableNode(node.getName());
                semIm.setMean(variableNode, mean);
                // Variance is computed from the column's size, sum and sum of squares. This is
                // presumably the divide-by-n (population) form per Colt's Descriptive; confirm.
                double standardDeviation = Descriptive.standardDeviation(Descriptive.variance(list.size(), Descriptive.sum(list), Descriptive.sumOfSquares(list)));
                semIm.setMeanStandardDeviation(variableNode, standardDeviation);
            }
        } else if (this.getCovMatrix() != null) {
            List<Node> variables = this.getCovMatrix().getVariables();
            for (Node node : variables) {
                Node variableNode = semIm.getVariableNode(node.getName());
                semIm.setMean(variableNode, 0.0);
            }
        }
    }

    private void setSemOptimizer(SemOptimizer semOptimizer) {
        this.semOptimizer = semOptimizer;
    }

    private void setEstimatedSem(SemIm estimatedSem) {
        this.estimatedSem = estimatedSem;
    }

    private void setSemPm(SemPm semPm) {
        this.semPm = semPm;
    }

    private void setDataSet(DataSet dataSet) {
        this.dataSet = dataSet;
    }

    /**
     * Restores the default serialized state and re-checks the same non-null requirements that
     * the constructors enforce.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.getCovMatrix() == null) {
            throw new NullPointerException("CovarianceMatrix must not be null.");
        }
        if (this.getSemPm() == null) {
            throw new NullPointerException("SemPm must not be null.");
        }
    }

    /**
     * Sets the positive-definiteness flag. The flag is stored but not read by any code in this
     * class.
     *
     * @param checkPositiveDefinite the new flag value.
     */
    public void setCheckPositiveDefinite(boolean checkPositiveDefinite) {
        this.checkPositiveDefinite = checkPositiveDefinite;
    }
}

