/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.sem.Scorer;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemOptimizer;
import edu.cmu.tetrad.sem.SemOptimizerRegression;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.ProbUtils;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.Vector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.TreeSet;
import org.apache.commons.math3.util.FastMath;

/**
 * Scores a DAG against a sample covariance matrix.
 *
 * <p>For every measured node the scorer regresses the node on its (non-error)
 * parents using entries of the sample covariance matrix, storing the resulting
 * coefficients in {@code edgeCoef} and the residual variance on the diagonal of
 * {@code errorCovar}. From those two matrices it derives an implied covariance
 * matrix and the fit statistics exposed through {@link #getFml()},
 * {@link #getChiSquare()}, {@link #getBicScore()} and {@link #getPValue()}.
 *
 * <p>Successive calls to {@link #score(Graph)} only re-fit the nodes whose
 * parent sets differ from the previously scored graph.
 *
 * <p>Instances are mutable: {@code score} updates the coefficient and error
 * matrices in place and remembers the graph it was given.
 */
public final class DagScorer implements TetradSerializable, Scorer {
    private static final long serialVersionUID = 23L;

    /** Covariance matrix supplied at construction (or derived from the data set). */
    private final ICovarianceMatrix covMatrix;
    /** Edge coefficients; entry (parent, child) holds the coefficient of parent -> child. */
    private final Matrix edgeCoef;
    /** Error covariance; only the diagonal is written by this class. */
    private final Matrix errorCovar;
    /** Variables of the covariance matrix; their order defines matrix indices. */
    private final List<Node> variables;
    /** The raw matrix obtained from {@code covMatrix.getMatrix()}. */
    private final Matrix sampleCovar;
    /** The data set, when the scorer was built from one; otherwise {@code null}. */
    private DataSet dataSet;
    /** The graph passed to the most recent {@link #score(Graph)} call, or {@code null}. */
    private Graph dag;
    /** Implied covariance over the measured nodes, rebuilt by {@link #computeImpliedCovar()}. */
    private Matrix implCovarMeasC;
    /** Cached log-determinant of the sample covariance; 0.0 means "not computed yet". */
    private double logDetSample;
    /** Cached FML value; {@code NaN} means "not computed yet". */
    private double fml = Double.NaN;

    /**
     * Builds a scorer from a data set by first computing its covariance matrix.
     *
     * @param dataSet the data to score graphs against
     */
    public DagScorer(DataSet dataSet) {
        this(new CovarianceMatrix(dataSet));
        this.dataSet = dataSet;
    }

    /**
     * Builds a scorer from a covariance matrix.
     *
     * @param covMatrix the covariance matrix to score graphs against
     * @throws NullPointerException if {@code covMatrix} is {@code null}
     */
    public DagScorer(ICovarianceMatrix covMatrix) {
        if (covMatrix == null) {
            throw new NullPointerException("CovarianceMatrix must not be null.");
        }
        this.variables = covMatrix.getVariables();
        this.covMatrix = covMatrix;
        int m = this.getVariables().size();
        this.edgeCoef = new Matrix(m, m);
        this.errorCovar = new Matrix(m, m);
        this.sampleCovar = covMatrix.getMatrix();
    }

    /**
     * Creates a simple exemplar of this class built from
     * {@code CovarianceMatrix.serializableInstance()}.
     *
     * @return a new scorer instance
     */
    public static Scorer serializableInstance() {
        return new DagScorer(CovarianceMatrix.serializableInstance());
    }

    /**
     * Re-fits every node whose parent set changed since the previous call,
     * remembers {@code dag}, and returns the freshly computed FML value.
     *
     * <p>The graph is stored by reference, not copied.
     * NOTE(review): if a caller mutates and re-scores the very same graph
     * instance, the change detection compares the graph with itself and finds
     * no differences -- confirm callers always pass a new graph object.
     *
     * @param dag the graph to score; its nodes are matched to the data by name
     * @return the value of {@link #getFml()} for {@code dag}
     * @throws IllegalArgumentException if a node of {@code dag} (or one of the
     *         parents used in a regression) has no same-named variable in the data
     */
    @Override
    public double score(Graph dag) {
        for (Node node : this.getChangedNodes(dag)) {
            this.refitNode(dag, node);
        }
        this.dag = dag;
        // Invalidate the cache so getFml() recomputes from the updated matrices.
        this.fml = Double.NaN;
        return this.getFml();
    }

    /**
     * Clears and recomputes the edge-coefficient column and the error variance
     * for a single node. Non-measured nodes are only cleared.
     */
    private void refitNode(Graph dag, Node node) {
        int nodeIdx = this.indexOf(node);
        int numVars = this.getVariables().size();

        // Reset whatever an earlier graph left behind for this node.
        this.errorCovar.set(nodeIdx, nodeIdx, 0.0);
        for (int row = 0; row < numVars; ++row) {
            this.edgeCoef.set(row, nodeIdx, 0.0);
        }
        if (node.getNodeType() != NodeType.MEASURED) {
            return;
        }

        List<Node> parents = this.parentsWithoutFirstErrorNode(dag, node);
        double variance = this.sampleCovar.get(nodeIdx, nodeIdx);
        if (!parents.isEmpty()) {
            variance -= this.regressOnParents(nodeIdx, parents);
        }
        this.errorCovar.set(nodeIdx, nodeIdx, variance);
    }

    /**
     * Returns a mutable copy of the node's parents with the first parent of
     * type {@code ERROR} (if any) removed.
     */
    private List<Node> parentsWithoutFirstErrorNode(Graph dag, Node node) {
        List<Node> parents = new ArrayList<Node>(dag.getParents(node));
        for (Node candidate : parents) {
            if (candidate.getNodeType() == NodeType.ERROR) {
                // Safe despite the for-each: we leave the loop right after removing.
                parents.remove(candidate);
                break;
            }
        }
        return parents;
    }

    /**
     * Regresses the node at {@code nodeIdx} on {@code parents} using sample
     * covariances, writes the coefficients into {@code edgeCoef}, and returns
     * the amount to subtract from the node's sample variance (the dot product
     * of the node/parent covariances with the coefficients).
     */
    private double regressOnParents(int nodeIdx, List<Node> parents) {
        int numParents = parents.size();

        // Resolve each parent's matrix index once; indexOf is a linear scan.
        int[] parentIdx = new int[numParents];
        for (int i = 0; i < numParents; ++i) {
            parentIdx[i] = this.indexOf(parents.get(i));
        }

        Vector nodeParentsCov = new Vector(numParents);
        Matrix parentsCov = new Matrix(numParents, numParents);
        for (int i = 0; i < numParents; ++i) {
            nodeParentsCov.set(i, this.sampleCovar.get(nodeIdx, parentIdx[i]));
            for (int j = i; j < numParents; ++j) {
                parentsCov.set(i, j, this.sampleCovar.get(parentIdx[i], parentIdx[j]));
                parentsCov.set(j, i, this.sampleCovar.get(parentIdx[j], parentIdx[i]));
            }
        }

        Vector edges = parentsCov.inverse().times(nodeParentsCov);
        for (int i = 0; i < edges.size(); ++i) {
            this.edgeCoef.set(parentIdx[i], nodeIdx, edges.get(i));
        }
        return nodeParentsCov.dotProduct(edges);
    }

    /**
     * Finds the position of the variable that has the same name as {@code node}.
     *
     * @throws IllegalArgumentException if no variable has that name
     */
    private int indexOf(Node node) {
        List<Node> vars = this.getVariables();
        String wanted = node.getName();
        for (int i = 0; i < vars.size(); ++i) {
            if (wanted.equals(vars.get(i).getName())) {
                return i;
            }
        }
        throw new IllegalArgumentException("Dag must have the same nodes as the data.");
    }

    /**
     * Determines which nodes must be re-fitted for {@code dag}.
     *
     * <p>When no graph has been scored yet, every node of {@code dag} is
     * returned and no node-set check is made here. Otherwise the node set of
     * {@code dag} must equal the set of data variables, and the nodes whose
     * parent sets differ from the previously scored graph are returned.
     *
     * @throws IllegalArgumentException if a previous graph exists and the node
     *         set of {@code dag} differs from the data variables
     */
    private List<Node> getChangedNodes(Graph dag) {
        if (this.dag == null) {
            return dag.getNodes();
        }
        boolean sameNodes = new HashSet<Node>(this.getVariables()).equals(new HashSet<Node>(dag.getNodes()));
        if (!sameNodes) {
            // Diagnostic output of both node sets, kept from the original implementation.
            System.out.println(new TreeSet<Node>(dag.getNodes()));
            System.out.println(new TreeSet<Node>(this.variables));
            throw new IllegalArgumentException("Dag must have the same nodes as the data.");
        }
        List<Node> changedNodes = new ArrayList<Node>();
        for (Node node : dag.getNodes()) {
            HashSet<Node> previousParents = new HashSet<Node>(this.dag.getParents(node));
            HashSet<Node> currentParents = new HashSet<Node>(dag.getParents(node));
            if (!previousParents.equals(currentParents)) {
                changedNodes.add(node);
            }
        }
        return changedNodes;
    }

    /**
     * @return the covariance matrix this scorer was built with
     */
    @Override
    public ICovarianceMatrix getCovMatrix() {
        return this.covMatrix;
    }

    /**
     * NOTE(review): the returned literal names "SemEstimator" rather than this
     * class; it is left untouched so existing output does not change.
     *
     * @return a fixed string
     */
    @Override
    public String toString() {
        return "\nSemEstimator";
    }

    /**
     * Returns the cached FML value, computing it first if necessary as
     * {@code log|Sigma| + tr(S * Sigma^-1) - log|S| - p}, where {@code Sigma}
     * is the implied covariance, {@code S} the sample covariance and {@code p}
     * the number of measured nodes.
     *
     * <p>If building the implied covariance throws, the stack trace is printed
     * and {@code NaN} is returned (and nothing is cached).
     *
     * @return the FML value, or {@code NaN} if the implied covariance could not be built
     */
    @Override
    public double getFml() {
        if (!Double.isNaN(this.fml)) {
            return this.fml;
        }
        Matrix implied;
        try {
            implied = this.implCovarMeas();
        } catch (Exception e) {
            // Best effort: report and signal "no score" instead of propagating.
            e.printStackTrace();
            return Double.NaN;
        }
        Matrix sample = this.sampleCovar();
        double logDetSigma = this.logDet(implied);
        double traceSSigmaInv = this.traceABInv(sample, implied);
        double logDetOfSample = this.logDetSample();
        int numMeasured = this.getMeasuredNodes().size();
        // The former "abs(fml) < 0.0" clamp could never fire and has been dropped.
        double value = logDetSigma + traceSSigmaInv - logDetOfSample - (double) numMeasured;
        this.fml = value;
        return value;
    }

    private Matrix sampleCovar() {
        return this.getSampleCovar();
    }

    /** Rebuilds the implied covariance and returns it. */
    private Matrix implCovarMeas() {
        this.computeImpliedCovar();
        return this.implCovarMeasC;
    }

    /**
     * @return the chi-square statistic minus {@code dof * log(sampleSize)}
     */
    @Override
    public double getBicScore() {
        int dof = this.getDof();
        return this.getChiSquare() - (double) dof * FastMath.log(this.getSampleSize());
    }

    /**
     * @return {@code (sampleSize - 1) * FML}
     */
    @Override
    public double getChiSquare() {
        return (double) (this.getSampleSize() - 1) * this.getFml();
    }

    /**
     * @return one minus the chi-square CDF evaluated at {@link #getChiSquare()}
     *         with {@link #getDof()} degrees of freedom
     */
    @Override
    public double getPValue() {
        return 1.0 - ProbUtils.chisqCdf(this.getChiSquare(), this.getDof());
    }

    /**
     * Restores the object with default deserialization and rejects a stream
     * whose covariance matrix is missing.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.getCovMatrix() == null) {
            throw new NullPointerException();
        }
    }

    /**
     * Computes the implied covariance from the transposed edge coefficients and
     * the error covariance, then copies its leading block (one row/column per
     * measured node) into {@code implCovarMeasC}.
     */
    private void computeImpliedCovar() {
        Matrix implCovarC = MatrixUtils.impliedCovar(this.edgeCoef().transpose(), this.errCovar());
        int size = this.getMeasuredNodes().size();
        Matrix measured = new Matrix(size, size);
        for (int row = 0; row < size; ++row) {
            for (int col = 0; col < size; ++col) {
                measured.set(row, col, implCovarC.get(row, col));
            }
        }
        this.implCovarMeasC = measured;
    }

    private Matrix errCovar() {
        return this.getErrorCovar();
    }

    private Matrix edgeCoef() {
        return this.getEdgeCoef();
    }

    /** Natural log of the determinant of {@code matrix2D}. */
    private double logDet(Matrix matrix2D) {
        return FastMath.log(matrix2D.det());
    }

    /**
     * Trace of {@code A^-1 * B}. Not called from within this class.
     *
     * @throws IllegalArgumentException if the trace is below {@code -1.0E-8}
     */
    private double traceAInvB(Matrix A, Matrix B) {
        double trace = A.inverse().times(B).trace();
        if (trace < -1.0E-8) {
            throw new IllegalArgumentException("Trace was negative: " + trace);
        }
        return trace;
    }

    /**
     * Trace of {@code A * B^-1}.
     *
     * <p>Any exception raised while computing it -- including the
     * "Trace was negative" check below -- is reported by printing {@code B}
     * and rethrown wrapped in a {@link RuntimeException}.
     */
    private double traceABInv(Matrix A, Matrix B) {
        try {
            double trace = A.times(B.inverse()).trace();
            if (trace < -1.0E-8) {
                throw new IllegalArgumentException("Trace was negative: " + trace);
            }
            return trace;
        } catch (Exception e) {
            System.out.println(B);
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the log-determinant of the sample covariance, computing it on
     * first use. A cached value of exactly 0.0 is treated as "not computed",
     * so in that case the determinant is simply evaluated again.
     */
    private double logDetSample() {
        if (this.logDetSample == 0.0 && this.sampleCovar() != null) {
            this.logDetSample = FastMath.log(this.sampleCovar().det());
        }
        return this.logDetSample;
    }

    /**
     * @return the data set given to the constructor, or {@code null} when the
     *         scorer was built from a covariance matrix
     */
    @Override
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Requires a graph to have been scored already.
     *
     * @return number of edges plus number of nodes of the last scored graph
     */
    @Override
    public int getNumFreeParams() {
        return this.dag.getEdges().size() + this.dag.getNodes().size();
    }

    /**
     * Requires a graph to have been scored already.
     *
     * @return {@code n * (n + 1) / 2} minus {@link #getNumFreeParams()}, where
     *         {@code n} is the node count of the last scored graph
     */
    @Override
    public int getDof() {
        int n = this.dag.getNodes().size();
        return n * (n + 1) / 2 - this.getNumFreeParams();
    }

    /**
     * @return the sample size reported by the covariance matrix
     */
    @Override
    public int getSampleSize() {
        return this.covMatrix.getSampleSize();
    }

    /**
     * @return the same list as {@link #getVariables()}
     */
    @Override
    public List<Node> getMeasuredNodes() {
        return this.getVariables();
    }

    /**
     * @return the sample covariance matrix (the internal instance, not a copy)
     */
    @Override
    public Matrix getSampleCovar() {
        return this.sampleCovar;
    }

    /**
     * @return the edge-coefficient matrix (the internal instance, not a copy)
     */
    @Override
    public Matrix getEdgeCoef() {
        return this.edgeCoef;
    }

    /**
     * @return the error-covariance matrix (the internal instance, not a copy)
     */
    @Override
    public Matrix getErrorCovar() {
        return this.errorCovar;
    }

    /**
     * @return the variables of the covariance matrix (the internal list, not a copy)
     */
    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Estimates a SEM for the last scored graph using a
     * {@link SemOptimizerRegression}, preferring the data set over the
     * covariance matrix when both are available.
     *
     * @return the estimated model
     * @throws IllegalStateException if neither a data set nor a covariance matrix is present
     */
    @Override
    public SemIm getEstSem() {
        SemPm pm = new SemPm(this.dag);
        if (this.dataSet != null) {
            return new SemEstimator(this.dataSet, pm, (SemOptimizer) new SemOptimizerRegression()).estimate();
        }
        if (this.covMatrix != null) {
            return new SemEstimator(this.covMatrix, pm, (SemOptimizer) new SemOptimizerRegression()).estimate();
        }
        throw new IllegalStateException();
    }
}

