/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.bayes;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.BdeMetricCache;
import edu.cmu.tetrad.bayes.EmBayesEstimator;
import edu.cmu.tetrad.bayes.ModelGenerator;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.HashSet;
import java.util.List;

/**
 * Structural EM search over Bayes net structures with a factored score.
 *
 * <p>Starting from an initial {@link BayesPm}, each round:
 * <ol>
 *   <li>estimates parameters of the current model with {@link EmBayesEstimator};</li>
 *   <li>scores the current DAG and every candidate DAG returned by
 *       {@link ModelGenerator#generate(Graph)}, summing per-node scores from
 *       {@link BdeMetricCache#scoreLnGam};</li>
 *   <li>moves to the best-scoring candidate if it strictly beats the current one.</li>
 * </ol>
 * <p>The search stops when a round does not change the model. The final model's
 * parameters are then estimated once more with EM.
 */
public final class FactoredBayesStructuralEM {
    private final BayesPm bayesPmM0;
    private final DataSet dataSet;
    /** Number of categories of each dataset variable, indexed like {@code dataSet.getVariables()}. */
    private final int[] ncategories;
    /** EM convergence tolerance; set by {@link #maximization(double)}, 0.0 until then. */
    private double tolerance;

    /**
     * Creates a structural EM search.
     *
     * @param dataSet   the data; every variable is cast to {@link DiscreteVariable},
     *                  so a non-discrete variable causes a {@code ClassCastException}
     * @param bayesPmM0 the initial model M0 from which the search starts
     */
    public FactoredBayesStructuralEM(DataSet dataSet, BayesPm bayesPmM0) {
        this.dataSet = dataSet;
        this.bayesPmM0 = bayesPmM0;
        List<Node> datasetVars = dataSet.getVariables();
        this.ncategories = new int[datasetVars.size()];
        for (int i = 0; i < this.ncategories.length; ++i) {
            this.ncategories[i] = ((DiscreteVariable) datasetVars.get(i)).getNumCategories();
        }
    }

    /**
     * Sums the per-node {@code scoreLnGam} values of {@code dag}. Each node's score is
     * computed from the node and its parent set.
     *
     * @return the total score; higher is considered better by the search
     */
    private static double factorScoreMD(Dag dag, BdeMetricCache bdeMetricCache, BayesPm bayesPm, BayesIm bayesIm) {
        double score = 0.0;
        for (Node node : dag.getNodes()) {
            HashSet<Node> parentsSet = new HashSet<Node>(dag.getParents(node));
            double fScore = bdeMetricCache.scoreLnGam(node, parentsSet, bayesPm, bayesIm);
            TetradLogger.getInstance().log("details", "Score for factor " + node.getName() + " = " + fScore);
            score += fScore;
        }
        return score;
    }

    /** Writes the observed-count tables of each node in {@code nodes} to the "details" log. */
    private static void logObservedCounts(BdeMetricCache cache, List<Node> nodes, BayesPm bayesPm, BayesIm bayesIm) {
        TetradLogger logger = TetradLogger.getInstance();
        for (Node node : nodes) {
            double[][] counts = cache.getObservedCounts(node, bayesPm, bayesIm);
            for (double[] row : counts) {
                for (double count : row) {
                    logger.log("details", " " + count);
                }
                logger.log("details", "\n");
            }
            logger.log("details", "\n");
        }
    }

    /**
     * Runs the structural search using the given EM tolerance.
     *
     * @param tolerance convergence tolerance passed to every {@link EmBayesEstimator} call
     * @return the EM-estimated IM of the final model
     */
    public BayesIm maximization(double tolerance) {
        TetradLogger.getInstance().log("details", "FactoredBayesStructuralEM.maximization()");
        this.tolerance = tolerance;
        return this.iterate();
    }

    /**
     * Runs the structural search with the currently stored tolerance. That tolerance is
     * 0.0 unless {@link #maximization(double)} was called first.
     *
     * @return the EM-estimated IM of the final model
     */
    public BayesIm iterate() {
        double start = MillisecondTimes.timeMillis();
        BdeMetricCache bdeMetricCache = new BdeMetricCache(this.dataSet, this.bayesPmM0);
        TimedIterate search = new TimedIterate(bdeMetricCache, this.bayesPmM0, Double.NEGATIVE_INFINITY, 0, start);
        // Run synchronously. The previous version started a Thread and joined it at once,
        // which gave no concurrency. It also swallowed InterruptedException and then read
        // the worker's state while the worker could still be running.
        search.run();
        EmBayesEstimator emBayesEst = new EmBayesEstimator(search.bayesPmMnplus1, this.dataSet);
        return emBayesEst.maximization(this.tolerance);
    }

    /**
     * Diagnostic routine. It compares the score of the initial DAG with the score of the
     * same DAG plus an edge X1 --> L1, using parameters estimated on the initial DAG.
     *
     * @throws IllegalStateException if the initial DAG has no node named "L1" or "X1"
     */
    public void scoreTest() {
        TetradLogger.getInstance().log("details", "scoreTest");
        BayesPm bayesPmMn = this.bayesPmM0;
        // Result is not used here; the call is kept as in the original.
        EmBayesEstimator emBayesEst = new EmBayesEstimator(bayesPmMn, this.dataSet);
        emBayesEst.maximization(1.0E-4);

        Dag dag0 = new Dag(bayesPmMn.getDag());
        Node latentL1 = dag0.getNode("L1");
        Node observedX1 = dag0.getNode("X1");
        if (latentL1 == null || observedX1 == null) {
            throw new IllegalStateException("scoreTest requires nodes named L1 and X1 in the initial DAG");
        }
        Dag dag1 = new Dag(dag0);
        dag1.addDirectedEdge(observedX1, latentL1);

        BayesPm bayesPm0 = new BayesPm(dag0);
        EmBayesEstimator emBayesEst0 = new EmBayesEstimator(bayesPm0, this.dataSet);
        BayesIm bayesImMn0 = emBayesEst0.maximization(1.0E-4);

        BayesPm bayesPmTest0 = new BayesPm(dag0);
        TetradLogger.getInstance().log("details", "Observed conts for nodes of L1,X1,X2,X3 (no edges) using the MAP parameters based on that same graph");
        TetradLogger.getInstance().log("details", "Graph of PM:  ");
        TetradLogger.getInstance().log("details", "" + bayesPmTest0.getDag());
        TetradLogger.getInstance().log("details", "Graph of IM:  ");
        TetradLogger.getInstance().log("details", "" + bayesImMn0.getBayesPm().getDag());
        BdeMetricCache bdeMetricCache = new BdeMetricCache(this.dataSet, bayesPmTest0);
        logObservedCounts(bdeMetricCache, dag0.getNodes(), bayesPmTest0, bayesImMn0);
        double score0 = FactoredBayesStructuralEM.factorScoreMD(dag0, bdeMetricCache, bayesPmTest0, bayesImMn0);
        TetradLogger.getInstance().log("details", "Score of L1,X1,X2,X3 (no edges) for itself = " + score0);
        TetradLogger.getInstance().log("details", "===============\n\n");

        BayesPm bayesPmTest1 = new BayesPm(dag1);
        TetradLogger.getInstance().log("details", "Observed counts for nodes of X1-->L1 for L1,X1,X2,X3 (no edges)");
        TetradLogger.getInstance().log("details", "Graph of PM :  ");
        TetradLogger.getInstance().log("details", "" + bayesPmTest1.getDag());
        TetradLogger.getInstance().log("details", "Graph of IM:  ");
        TetradLogger.getInstance().log("details", "" + bayesImMn0.getBayesPm().getDag());
        bdeMetricCache = new BdeMetricCache(this.dataSet, bayesPmTest1);
        // Use dag1's nodes so the nodes queried belong to the graph that bayesPmTest1 was built from.
        logObservedCounts(bdeMetricCache, dag1.getNodes(), bayesPmTest1, bayesImMn0);
        double score1 = FactoredBayesStructuralEM.factorScoreMD(dag1, bdeMetricCache, bayesPmTest1, bayesImMn0);
        TetradLogger.getInstance().log("details", "Score of X1-->L1 for L1,X1,X2,X3 (no edges) = " + score1);
    }

    /** @return the dataset this search was constructed with */
    public DataSet getDataSet() {
        return this.dataSet;
    }

    /**
     * Holds the mutable state of the structural search loop. It is an inner (non-static)
     * class because it reads the enclosing instance's dataset, tolerance and category counts.
     */
    private class TimedIterate
    implements Runnable {
        final BdeMetricCache bdeMetricCache;
        /** Start time recorded by the caller; not read by the loop. */
        final double start;
        /** Best model found so far; this is the search result once {@link #run()} returns. */
        BayesPm bayesPmMnplus1;
        BayesPm bayesPmMn;
        double oldBestScore;
        int iteration;

        public TimedIterate(BdeMetricCache bdeMetricCache, BayesPm bayesPmMnplus1, double oldBestScore, int iteration, double start) {
            this.bdeMetricCache = bdeMetricCache;
            this.bayesPmMnplus1 = bayesPmMnplus1;
            this.bayesPmMn = null;
            this.oldBestScore = oldBestScore;
            this.iteration = iteration;
            this.start = start;
        }

        /** Repeats rounds of estimate, score and pick until a round leaves the model unchanged. */
        @Override
        public void run() {
            DataSet data = FactoredBayesStructuralEM.this.dataSet;
            // The variable list does not change during the search, so fetch it once.
            int numVars = data.getVariables().size();
            List<String> varNames = data.getVariableNames();
            while (!this.bayesPmMnplus1.equals(this.bayesPmMn)) {
                ++this.iteration;
                this.bayesPmMn = this.bayesPmMnplus1;
                TetradLogger.getInstance().log("details", "In Factored Bayes Struct EM Iteration number " + this.iteration);
                TetradLogger.getInstance().log("details", "Starting EM Bayes estimator to get MAP parameters of Mn");
                EmBayesEstimator emBayesEst = new EmBayesEstimator(this.bayesPmMn, data);
                BayesIm bayesImMn = emBayesEst.maximization(FactoredBayesStructuralEM.this.tolerance);
                TetradLogger.getInstance().log("details", "Estimation of MAP parameters of Mn complete. \n\n");
                Graph graphMn = this.bayesPmMn.getDag();
                Dag dagMn = new Dag(graphMn);
                List<Graph> models = ModelGenerator.generate(graphMn);
                double bestScore = FactoredBayesStructuralEM.factorScoreMD(dagMn, this.bdeMetricCache, this.bayesPmMn, bayesImMn);
                TetradLogger.getInstance().log("details", "Initial graph Mn = ");
                TetradLogger.getInstance().log("details", new EdgeListGraph(dagMn).toString());
                TetradLogger.getInstance().log("details", "Its score = " + bestScore);
                for (Graph model : models) {
                    Dag dag = new Dag(model);
                    BayesPm bayesPmTest = new BayesPm(dag);
                    // Give the candidate PM the same category counts as the data.
                    for (int i = 0; i < numVars; ++i) {
                        Node node = dag.getNode(varNames.get(i));
                        bayesPmTest.setNumCategories(node, FactoredBayesStructuralEM.this.ncategories[i]);
                    }
                    double score = FactoredBayesStructuralEM.factorScoreMD(dag, this.bdeMetricCache, bayesPmTest, bayesImMn);
                    TetradLogger.getInstance().log("details", "For the model with graph \n" + new EdgeListGraph(dag));
                    TetradLogger.getInstance().log("details", "Model Score = " + score);
                    // Only a strict improvement replaces the current best, so ties keep Mn.
                    if (score > bestScore) {
                        bestScore = score;
                        this.bayesPmMnplus1 = bayesPmTest;
                    }
                }
                TetradLogger.getInstance().log("details", "In iteration:  " + this.iteration);
                TetradLogger.getInstance().log("details", "bestScore, oldBestScore " + bestScore + " " + this.oldBestScore);
                EdgeListGraph edgesBest = new EdgeListGraph(this.bayesPmMnplus1.getDag());
                TetradLogger.getInstance().log("details", "Graph of model:  \n" + edgesBest);
                TetradLogger.getInstance().log("details", "====================================");
                this.oldBestScore = bestScore;
            }
        }
    }
}

