/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.LayoutUtil;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.ScoreType;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.sem.SemOptimizerEm;
import edu.cmu.tetrad.sem.SemOptimizerPowell;
import edu.cmu.tetrad.sem.SemOptimizerRegression;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.TetradSerializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.List;

public final class SemEstimator
implements TetradSerializable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = 23L;
    /** Parameterized model to estimate; the constructors and readObject reject null. */
    private SemPm semPm;
    /** Covariance matrix restricted to the measured variables of {@link #semPm}. */
    private ICovarianceMatrix covMatrix;
    /** Optimizer used by estimate(); when null, estimate() installs a default one. */
    private SemOptimizer semOptimizer;
    /** Result of the most recent estimate() call, or null if none has completed. */
    private SemIm estimatedSem;
    /** Data set assigned only by the DataSet-based constructor; null otherwise. */
    private DataSet dataSet;
    /** Score type handed to the SemIm before optimization; defaults to Fgls. */
    private ScoreType scoreType = ScoreType.Fgls;
    /** Number of restarts handed to the optimizer; defaults to 1. */
    private int numRestarts = 1;

    /**
     * Creates an estimator from raw data without choosing an optimizer; a default
     * optimizer is selected when {@link #estimate()} runs.
     *
     * @param dataSet the data to estimate from
     * @param semPm   the parameterized model to estimate
     */
    public SemEstimator(DataSet dataSet, SemPm semPm) {
        this(dataSet, semPm, (SemOptimizer) null);
    }

    /**
     * Creates an estimator from a covariance matrix without choosing an optimizer; a
     * default optimizer is selected when {@link #estimate()} runs.
     *
     * @param covMatrix the covariance matrix to estimate from
     * @param semPm     the parameterized model to estimate
     */
    public SemEstimator(ICovarianceMatrix covMatrix, SemPm semPm) {
        this(covMatrix, semPm, (SemOptimizer) null);
    }

    /**
     * Creates an estimator from raw data. The data set is validated before its
     * covariance matrix is computed, so a data set with missing values is reported
     * as such rather than surfacing later as a problem with the derived matrix.
     *
     * @param dataSet      the data to estimate from; must be non-null and complete
     * @param semPm        the parameterized model to estimate; must be non-null
     * @param semOptimizer the optimizer to use, or null to pick a default in estimate()
     * @throws NullPointerException     if {@code dataSet} or {@code semPm} is null
     * @throws IllegalArgumentException if {@code dataSet} contains missing values
     */
    public SemEstimator(DataSet dataSet, SemPm semPm, SemOptimizer semOptimizer) {
        // Validation must happen inside the this(...) arguments: Java runs the
        // delegated constructor before any statement in this body.
        this(new CovarianceMatrix(SemEstimator.requireCompleteData(dataSet)), semPm, semOptimizer);
        this.setDataSet(this.subset(dataSet, semPm));
    }

    /** Returns {@code dataSet} unchanged after checking it is non-null and has no missing values. */
    private static DataSet requireCompleteData(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("DataSet must not be null.");
        }
        if (DataUtils.containsMissingValue(dataSet)) {
            throw new IllegalArgumentException("Expecting a data set with no missing values.");
        }
        return dataSet;
    }

    public SemEstimator(ICovarianceMatrix covMatrix, SemPm semPm, SemOptimizer semOptimizer) {
        if (covMatrix == null) {
            throw new NullPointerException("CovarianceMatrix must not be null.");
        }
        if (semPm == null) {
            throw new NullPointerException("SemPm must not be null.");
        }
        if (DataUtils.containsMissingValue(covMatrix.getMatrix())) {
            throw new IllegalArgumentException("Expecting a covariance matrix with no missing values.");
        }
        semPm.getGraph().setShowErrorTerms(false);
        this.setCovMatrix(this.submatrix(covMatrix, semPm));
        this.setSemPm(semPm);
        this.setSemOptimizer(semOptimizer);
    }

    /**
     * Builds a simple exemplar of this class from the exemplars of
     * {@code CovarianceMatrix} and {@code SemPm}.
     *
     * @return a new estimator built from those exemplars
     */
    public static SemEstimator serializableInstance() {
        ICovarianceMatrix exemplarCov = CovarianceMatrix.serializableInstance();
        SemPm exemplarPm = SemPm.serializableInstance();
        return new SemEstimator(exemplarCov, exemplarPm);
    }

    /**
     * Estimates the model and returns the resulting instantiated model.
     *
     * <p>A fresh {@code SemIm} is built from the parameterized model and the covariance
     * matrix, optimized with parameter bounds switched off, and then given means (and
     * mean standard deviations, when a data set is available). If no optimizer was
     * supplied, a default one is chosen and retained for later calls. Summary
     * statistics are written to the "stats" log.
     *
     * @return the estimated model, also available afterwards via getEstimatedSem()
     */
    public SemIm estimate() {
        // Drop any previous result first so a failed run cannot leave a stale model behind.
        this.setEstimatedSem(null);

        SemIm semIm = new SemIm(this.getSemPm(), this.getCovMatrix());
        LayoutUtil.arrangeBySourceGraph(semIm.getSemPm().getGraph(), this.getSemPm().getGraph());
        semIm.setParameterBoundsEnforced(false);
        semIm.setScoreType(this.getScoreType());

        // Only work out a default optimizer when the caller did not provide one.
        if (this.semOptimizer == null) {
            this.semOptimizer = this.getDefaultOptimization(semIm);
        }
        this.semOptimizer.setNumRestarts(this.numRestarts);
        this.semOptimizer.optimize(semIm);

        semIm.setParameterBoundsEnforced(true);
        this.setMeans(semIm, this.getDataSet());
        semIm.setEstimated(true);
        this.setEstimatedSem(semIm);

        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        TetradLogger logger = TetradLogger.getInstance();
        logger.log("stats", "Sample Size = " + semIm.getSampleSize());
        logger.log("stats", "Model Chi Square = " + nf.format(semIm.getChiSquare()));
        logger.log("stats", "Model DOF = " + nf.format(this.semPm.getDof()));
        logger.log("stats", "Model P Value = " + nf.format(semIm.getPValue()));
        logger.log("stats", "Model BIC = " + nf.format(semIm.getBicScore()));
        return semIm;
    }

    /**
     * Stores the covariance matrix used for estimation.
     *
     * @param covMatrix the matrix to store; the constructor passes the submatrix over
     *                  the model's measured variables
     */
    private void setCovMatrix(ICovarianceMatrix covMatrix) {
        this.covMatrix = covMatrix;
    }

    /**
     * @return the model produced by the most recent estimate() call, or null if no
     *         estimation has completed
     */
    public SemIm getEstimatedSem() {
        return estimatedSem;
    }

    /**
     * @return the stored data set, or null when this estimator was built from a
     *         covariance matrix
     */
    public DataSet getDataSet() {
        return dataSet;
    }

    /**
     * @return the parameterized model this estimator works on
     */
    public SemPm getSemPm() {
        return semPm;
    }

    /**
     * @return the covariance matrix used for estimation, restricted to the model's
     *         measured variables
     */
    public ICovarianceMatrix getCovMatrix() {
        return covMatrix;
    }

    /**
     * @return the optimizer currently configured, or null if none has been set or
     *         chosen yet
     */
    private SemOptimizer getSemOptimizer() {
        return semOptimizer;
    }

    /**
     * Renders a short textual summary: either a notice that nothing has been
     * estimated yet, or the measured nodes, edge coefficient matrix and error
     * covariance matrix of the estimated model.
     *
     * <p>NOTE(review): the "fml = " label is emitted with no value after it; this
     * mirrors the existing output and may be a leftover — confirm before relying on it.
     *
     * @return the summary text
     */
    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("\nSemEstimator");

        SemIm sem = getEstimatedSem();
        if (sem == null) {
            return text.append("\n\t...SemIm has not been estimated yet.").toString();
        }

        text.append("\n\n\tfml = ")
            .append("\n\n\tmeasuredNodes:\n")
            .append("\t")
            .append(sem.getMeasuredNodes())
            .append("\n\n\tedgeCoef:\n")
            .append(MatrixUtils.toString(sem.getEdgeCoef().toArray()))
            .append("\n\n\terrCovar:\n")
            .append(MatrixUtils.toString(sem.getErrCovar().toArray()));
        return text.toString();
    }

    /**
     * Chooses an optimizer from the structure of the parameterized model:
     * Powell when the model has fixed parameters, a directed cycle, or covariance
     * parameters; otherwise EM when the graph has a latent node; otherwise regression.
     * The chosen optimizer is given the configured number of restarts.
     *
     * @param semIm the model about to be optimized; only checked for null
     * @return a newly created optimizer
     * @throws NullPointerException if {@code semIm} is null
     */
    private SemOptimizer getDefaultOptimization(SemIm semIm) {
        if (semIm == null) {
            throw new NullPointerException();
        }

        boolean hasLatent = false;
        for (Node candidate : getSemPm().getGraph().getNodes()) {
            if (candidate.getNodeType() == NodeType.LATENT) {
                hasLatent = true;
                break;
            }
        }

        SemOptimizer chosen;
        if (containsFixedParam()
                || getSemPm().getGraph().paths().existsDirectedCycle()
                || SemEstimator.containsCovarParam(getSemPm())) {
            chosen = new SemOptimizerPowell();
        } else if (hasLatent) {
            chosen = new SemOptimizerEm();
        } else {
            chosen = new SemOptimizerRegression();
        }

        chosen.setNumRestarts(numRestarts);
        return chosen;
    }

    /**
     * Reports whether a freshly instantiated model built from the parameterized
     * model has at least one fixed parameter.
     *
     * @return true if the number of fixed parameters is positive
     */
    private boolean containsFixedParam() {
        SemIm probe = new SemIm(getSemPm());
        int fixedCount = probe.getNumFixedParams();
        return fixedCount > 0;
    }

    /**
     * Restricts a covariance matrix to the measured variables of the model.
     *
     * @param covMatrix the full covariance matrix
     * @param semPm     the model whose measured variable names select the submatrix
     * @return the submatrix over the model's measured variables
     * @throws RuntimeException if the submatrix cannot be formed because
     *                          {@code getSubmatrix} rejects the names; the original
     *                          exception is attached as the cause
     */
    private ICovarianceMatrix submatrix(ICovarianceMatrix covMatrix, SemPm semPm) {
        String[] measuredVarNames = semPm.getMeasuredVarNames();
        try {
            return covMatrix.getSubmatrix(measuredVarNames);
        }
        catch (IllegalArgumentException e) {
            // The cause is chained, so the stack trace is not printed separately here.
            throw new RuntimeException("All of the variables from the SEM parameterized model must be in the data set.", e);
        }
    }

    /**
     * Restricts a data set to the measured variables of the model, with columns in
     * the order of the model's measured variable names.
     *
     * @param dataSet the full data set
     * @param semPm   the model whose measured variable names select the columns
     * @return a data set containing only the selected columns
     * @throws IllegalArgumentException if a measured variable cannot be located among
     *                                  the data set's variables
     */
    private DataSet subset(DataSet dataSet, SemPm semPm) {
        String[] measuredVarNames = semPm.getMeasuredVarNames();
        int[] varIndices = new int[measuredVarNames.length];
        List<Node> dataVars = dataSet.getVariables();
        for (int i = 0; i < measuredVarNames.length; ++i) {
            Node variable = dataSet.getVariable(measuredVarNames[i]);
            int index = dataVars.indexOf(variable);
            if (index < 0) {
                // Fail here with the offending name rather than handing -1 to subsetColumns.
                throw new IllegalArgumentException("Variable '" + measuredVarNames[i]
                        + "' from the SEM parameterized model is not in the data set.");
            }
            varIndices[i] = index;
        }
        return dataSet.subsetColumns(varIndices);
    }

    /**
     * Reports whether the model declares at least one covariance parameter.
     *
     * @param semPm the parameterized model to inspect
     * @return true if any parameter has type {@code ParamType.COVAR}
     */
    private static boolean containsCovarParam(SemPm semPm) {
        for (Parameter candidate : semPm.getParameters()) {
            if (candidate.getType() == ParamType.COVAR) {
                return true;
            }
        }
        return false;
    }

    /**
     * Assigns variable means on the given model.
     *
     * <p>With a data set, each column's mean and standard deviation are computed and
     * stored on the model node with the same name as the column's variable. Without a
     * data set, every variable of the covariance matrix gets a mean of 0.0 (and no
     * standard deviation is set). If neither is available nothing happens.
     *
     * @param semIm   the model to update
     * @param dataSet the data to take means from, or null to fall back to zeros
     */
    private void setMeans(SemIm semIm, DataSet dataSet) {
        if (dataSet != null) {
            int columnCount = dataSet.getNumColumns();
            for (int col = 0; col < columnCount; col++) {
                double[] values = dataSet.getDoubleData().getColumn(col).toArray();
                double columnMean = StatUtils.mean(values);
                Node dataVariable = dataSet.getVariable(col);
                Node modelNode = semIm.getVariableNode(dataVariable.getName());
                semIm.setMean(modelNode, columnMean);
                double columnSd = StatUtils.sd(values);
                semIm.setMeanStandardDeviation(modelNode, columnSd);
            }
            return;
        }

        if (getCovMatrix() == null) {
            return;
        }
        List<Node> covVariables = getCovMatrix().getVariables();
        for (Node covVariable : covVariables) {
            Node modelNode = semIm.getVariableNode(covVariable.getName());
            semIm.setMean(modelNode, 0.0);
        }
    }

    /**
     * Sets the optimizer to use for estimation.
     *
     * @param semOptimizer the optimizer, or null to have estimate() choose a default
     */
    public void setSemOptimizer(SemOptimizer semOptimizer) {
        this.semOptimizer = semOptimizer;
    }

    /**
     * Stores the result of an estimation run.
     *
     * @param estimatedSem the estimated model; estimate() passes null to clear the
     *                     previous result before starting a new run
     */
    private void setEstimatedSem(SemIm estimatedSem) {
        this.estimatedSem = estimatedSem;
    }

    /**
     * Stores the parameterized model to estimate.
     *
     * @param semPm the model; the constructor has already rejected null before calling this
     */
    private void setSemPm(SemPm semPm) {
        this.semPm = semPm;
    }

    /**
     * Stores a copy of the given data, relabelled with the data set's own variables
     * looked up by the names of the model's measured nodes.
     *
     * <p>NOTE(review): the variable list follows the order of the model's measured
     * nodes while the values keep the column order of {@code dataSet}; this assumes
     * the two orders agree — confirm against the caller.
     *
     * @param dataSet the data to copy; its name is carried over to the copy
     */
    private void setDataSet(DataSet dataSet) {
        List<Node> measured = this.semPm.getMeasuredNodes();
        List<Node> dataVariables = new ArrayList<Node>(measured.size());
        for (Node measuredNode : measured) {
            dataVariables.add(dataSet.getVariable(measuredNode.getName()));
        }

        // The data is transposed before being handed to VerticalDoubleDataBox.
        VerticalDoubleDataBox box =
                new VerticalDoubleDataBox(dataSet.getDoubleData().transpose().toArray());
        BoxDataSet copy = new BoxDataSet(box, dataVariables);
        copy.setName(dataSet.getName());
        this.dataSet = copy;
    }

    /**
     * Restores this object from a stream and verifies that the two required fields,
     * the covariance matrix and the parameterized model, were present.
     *
     * @param s the stream to read from
     * @throws IOException            if reading from the stream fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     * @throws NullPointerException   if either required field is null after reading
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();

        boolean missingRequiredField = getCovMatrix() == null || getSemPm() == null;
        if (missingRequiredField) {
            throw new NullPointerException();
        }
    }

    /**
     * Sets the score type that estimate() passes to the model before optimizing.
     *
     * @param scoreType the score type to use
     */
    public void setScoreType(ScoreType scoreType) {
        this.scoreType = scoreType;
    }

    /**
     * @return the score type that will be applied to the model during estimation
     */
    private ScoreType getScoreType() {
        return scoreType;
    }

    /**
     * Sets the number of restarts that estimate() passes to the optimizer.
     *
     * @param numRestarts the restart count to hand to the optimizer
     */
    public void setNumRestarts(int numRestarts) {
        this.numRestarts = numRestarts;
    }
}

