/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesImProbs;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.EmBayesEstimator;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphNode;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.ProbUtils;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.commons.math3.util.FastMath;

/**
 * Scores a discrete Bayes net structure against a data set. Parameters are estimated with a
 * pluggable {@link Estimator}; by default this runs EM via {@link EmBayesEstimator}.
 *
 * <p>Provides a BIC-style score ({@link #getBic()}). It also provides a likelihood-ratio test
 * of the graph against the edgeless graph over the same nodes ({@link #getLikelihoodRatioP()}).
 *
 * <p>Note: each time the default estimator runs, it replaces this object's data set with the
 * estimator's mixed data set ({@code EmBayesEstimator#getMixedDataSet()}). Later computations
 * on this object use that replacement.
 */
public final class EmBayesProperties {

    /** Convergence tolerance passed to {@code EmBayesEstimator#maximization(double)}. */
    private static final double EM_TOLERANCE = 1.0E-4;

    private DataSet dataSet;
    private BayesPm bayesPm;
    private Graph graph;
    private MlBayesIm blankBayesIm;
    private int pValueDf;
    private double chisq;

    /*
     * Default estimator: runs EM on the given PM and data and returns the estimated IM.
     * Side effect: stores the estimator's mixed data set into this.dataSet before running EM.
     */
    private Estimator estimator = (pm, data) -> {
        EmBayesEstimator em = new EmBayesEstimator(pm, data);
        this.dataSet = em.getMixedDataSet();
        try {
            em.maximization(EM_TOLERANCE);
            return em.getEstimatedIm();
        } catch (IllegalArgumentException e) {
            // Keep the original failure as the cause instead of printing and discarding it.
            throw new RuntimeException("Please specify the search tolerance first.", e);
        }
    };

    /**
     * Creates a scorer for {@code graph} over {@code dataSet}.
     *
     * @param dataSet the data to score against; every column is cast to {@link DiscreteVariable}
     * @param graph   the structure to score
     * @throws NullPointerException if either argument is null
     */
    public EmBayesProperties(DataSet dataSet, Graph graph) {
        setDataSet(dataSet);
        setGraph(graph);
    }

    /**
     * Sets the graph to score and rebuilds the Bayes PM and blank ML IM for it.
     *
     * <p>Each graph node whose name matches a data column gets that column's categories. Nodes
     * with no matching column keep the PM's default categories.
     *
     * @param graph the structure to score; it is wrapped in a {@link Dag}
     * @throws NullPointerException if {@code graph} is null
     * @throws ClassCastException   if a data column is not a {@link DiscreteVariable}
     */
    public void setGraph(Graph graph) {
        Objects.requireNonNull(graph, "graph");

        List<Node> vars = this.dataSet.getVariables();
        Map<String, DiscreteVariable> varsByName = new HashMap<>();
        for (int i = 0; i < this.dataSet.getNumColumns(); i++) {
            DiscreteVariable var = (DiscreteVariable) vars.get(i);
            varsByName.put(var.getName(), var);
        }

        BayesPm pm = new BayesPm(new Dag(graph));
        for (Node node : pm.getDag().getNodes()) {
            DiscreteVariable var = varsByName.get(node.getName());
            if (var != null) {
                pm.setCategories(node, var.getCategories());
            }
        }

        this.graph = graph;
        this.bayesPm = pm;
        this.blankBayesIm = new MlBayesIm(pm);
    }

    /**
     * Returns a BIC-style score: the log-likelihood of the data under the estimated model
     * minus {@code k * ln(N) / 2}. Here {@code k} is the number of free parameters and
     * {@code N} is the number of data rows.
     *
     * @return the penalized log-likelihood (higher is better)
     */
    public double getBic() {
        // Evaluation order matters: the likelihood step may replace this.dataSet
        // (default estimator), and the penalty then reads the row count from it.
        return logProbDataGivenStructure() - parameterPenalty();
    }

    /**
     * Computes a likelihood-ratio test of this graph against the edgeless graph over the same
     * nodes. The test statistic and degrees of freedom are stored and can be read back via
     * {@link #getPValueChisq()} and {@link #getPValueDf()}.
     *
     * <p>Both models are scored by fresh instances that use the default estimator.
     *
     * @return {@code 1 - chisqCdf(chisq, df)}, where {@code chisq = -2 (l0 - l1)} and
     *     {@code df} is the difference in free-parameter counts
     */
    public double getLikelihoodRatioP() {
        Graph graph1 = getGraph();

        // Null model: same nodes, no edges.
        Dag graph0 = new Dag();
        for (Node node : graph1.getNodes()) {
            graph0.addNode(node);
        }

        EmBayesProperties scorer1 = new EmBayesProperties(getDataSet(), graph1);
        EmBayesProperties scorer0 = new EmBayesProperties(getDataSet(), graph0);

        double l1 = scorer1.logProbDataGivenStructure();
        double l0 = scorer0.logProbDataGivenStructure();
        double statistic = -2.0 * (l0 - l1);
        int df = scorer1.numNonredundantParams() - scorer0.numNonredundantParams();
        double pValue = 1.0 - ProbUtils.chisqCdf(statistic, df);

        this.pValueDf = df;
        this.chisq = statistic;
        return pValue;
    }

    /** Returns the Bayes PM built for the current graph and data. */
    public BayesPm getBayesPm() {
        return this.bayesPm;
    }

    /**
     * Returns the degrees of freedom from the most recent {@link #getLikelihoodRatioP()} call.
     * Returns -1 if that method has not been called since the data set was set.
     */
    public int getPValueDf() {
        return this.pValueDf;
    }

    /**
     * Returns the chi-square statistic from the most recent {@link #getLikelihoodRatioP()}
     * call. Returns {@code NaN} if that method has not been called since the data set was set.
     */
    public double getPValueChisq() {
        return this.chisq;
    }

    /** Returns the estimator used to fit parameters. */
    public Estimator getEstimator() {
        return this.estimator;
    }

    /**
     * Replaces the estimator used to fit parameters.
     *
     * @param estimator the new estimator
     */
    public void setEstimator(Estimator estimator) {
        this.estimator = estimator;
    }

    /**
     * Estimates an IM for the current PM and sums the natural log of each row's cell
     * probability under that IM.
     */
    private double logProbDataGivenStructure() {
        BayesIm bayesIm = this.estimator.estimate(this.bayesPm, this.dataSet);
        BayesImProbs probs = new BayesImProbs(bayesIm);

        // Read this.dataSet only after estimating, since the default estimator may replace it.
        // Columns are reordered to match the IM's variable order so _case indices line up.
        DataSet reordered = this.dataSet.subsetColumns(bayesIm.getVariables());
        int numRows = reordered.getNumRows();
        int numCols = reordered.getNumColumns();

        double score = 0.0;
        int[] _case = new int[numCols];
        for (int row = 0; row < numRows; row++) {
            for (int col = 0; col < numCols; col++) {
                _case[col] = reordered.getInt(row, col);
            }
            score += FastMath.log(probs.getCellProb(_case));
        }
        return score;
    }

    /**
     * Counts free parameters. Over all nodes with more than one column, this sums
     * {@code (columns - 1) * rows} from the blank ML IM.
     */
    private int numNonredundantParams() {
        // Rebuild the PM/IM from the current data set, which the default estimator may have
        // replaced since the graph was last set.
        setGraph(getGraph());

        int numParams = 0;
        for (int j = 0; j < this.blankBayesIm.getNumNodes(); j++) {
            int numColumns = this.blankBayesIm.getNumColumns(j);
            int numRows = this.blankBayesIm.getNumRows(j);
            if (numColumns > 1) {
                numParams += (numColumns - 1) * numRows;
            }
        }
        return numParams;
    }

    /** Returns the BIC penalty {@code k * ln(N) / 2}. */
    private double parameterPenalty() {
        int numParams = numNonredundantParams();
        double numRows = this.dataSet.getNumRows();
        return numParams * FastMath.log(numRows) / 2.0;
    }

    private Graph getGraph() {
        return this.graph;
    }

    private DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Sets the data set and clears all state derived from any previous data or graph.
     *
     * @throws NullPointerException if {@code dataSet} is null
     */
    private void setDataSet(DataSet dataSet) {
        Objects.requireNonNull(dataSet, "dataSet");
        this.bayesPm = null;
        this.blankBayesIm = null;
        this.graph = null;
        this.pValueDf = -1;
        this.chisq = Double.NaN;
        this.dataSet = dataSet;
    }

    /** Strategy for estimating a Bayes IM from a Bayes PM and a data set. */
    @FunctionalInterface
    public interface Estimator {
        /**
         * Estimates an instantiated model.
         *
         * @param bayesPm the parametric model to fit
         * @param dataSet the data to fit it to
         * @return the estimated IM
         */
        BayesIm estimate(BayesPm bayesPm, DataSet dataSet);
    }
}

