/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Comparator;
import org.jetbrains.annotations.NotNull;

public class SemOptimizerRegression
implements SemOptimizer {
    private static final long serialVersionUID = 23L;
    private int numRestarts = 1;

    public static SemOptimizerRegression serializableInstance() {
        return new SemOptimizerRegression();
    }

    private static int[] indexedParents(int[] parents) {
        int[] pp = new int[parents.length];
        for (int j = 0; j < pp.length; ++j) {
            pp[j] = j + 1;
        }
        return pp;
    }

    @NotNull
    private static Matrix bStar(Matrix b) {
        Matrix byx = new Matrix(b.getNumRows() + 1, 1);
        byx.set(0, 0, 1.0);
        for (int j = 0; j < b.getNumRows(); ++j) {
            byx.set(j + 1, 0, -b.get(j, 0));
        }
        return byx;
    }

    private static int[] concat(int i, int[] parents) {
        int[] all = new int[parents.length + 1];
        all[0] = i;
        System.arraycopy(parents, 0, all, 1, parents.length);
        return all;
    }

    private static Matrix getCov(int[] _rows, int[] cols, Matrix covarianceMatrix) {
        return covarianceMatrix.getSelection(_rows, cols);
    }

    @Override
    public void optimize(SemIm semIm) {
        if (this.numRestarts != 1) {
            throw new IllegalArgumentException("Number of restarts must be 1 for this method.");
        }
        Matrix covar = semIm.getSampleCovar();
        if (covar == null) {
            throw new NullPointerException("Sample covar has not been set.");
        }
        SemGraph graph = semIm.getSemPm().getGraph();
        graph.setShowErrorTerms(false);
        ArrayList<Node> nodes = new ArrayList<Node>(semIm.getVariableNodes());
        nodes.removeIf(node -> node.getNodeType() == NodeType.ERROR);
        TetradLogger.getInstance().forceLogMessage("FML = " + semIm.getScore());
        for (Node n : nodes) {
            int i = nodes.indexOf(n);
            ArrayList<Node> parents = new ArrayList<Node>(graph.getParents(n));
            parents.removeIf(parent -> parent.getNodeType() == NodeType.ERROR);
            parents.sort(Comparator.comparingInt(nodes::indexOf));
            int[] _parents = new int[parents.size()];
            for (int j = 0; j < parents.size(); ++j) {
                _parents[j] = nodes.indexOf(parents.get(j));
            }
            int[] all = SemOptimizerRegression.concat(i, _parents);
            Matrix cov = SemOptimizerRegression.getCov(all, all, covar);
            int[] pp = SemOptimizerRegression.indexedParents(_parents);
            Matrix covxx = cov.getSelection(pp, pp);
            Matrix covxy = cov.getSelection(pp, new int[]{0});
            Matrix b = covxx.inverse().times(covxy);
            for (int j = 0; j < b.getNumRows(); ++j) {
                semIm.setParamValue((Node)parents.get(j), n, b.get(j, 0));
            }
            Matrix bStar = SemOptimizerRegression.bStar(b);
            double varry = bStar.transpose().times(cov).times(bStar).get(0, 0);
            semIm.setParamValue(n, n, varry);
        }
        TetradLogger.getInstance().log("optimization", "FML = " + semIm.getScore());
    }

    @Override
    public int getNumRestarts() {
        return this.numRestarts;
    }

    @Override
    public void setNumRestarts(int numRestarts) {
        this.numRestarts = numRestarts;
    }

    public String toString() {
        return "Sem Optimizer Regression";
    }
}

