/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.simulation;

import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.DirichletBayesIm;
import edu.cmu.tetrad.bayes.DirichletEstimator;
import edu.cmu.tetrad.bayes.Evidence;
import edu.cmu.tetrad.bayes.RowSummingExactUpdater;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;

/**
 * Hybrid simulation ("Hsim") for discrete data.
 *
 * <p>Given a DAG, a set of nodes to resimulate, and a discrete data set, this class:
 * <ol>
 *   <li>collects the Markov blankets of the resimulated nodes;</li>
 *   <li>fits a Dirichlet-estimated Bayes IM (symmetric prior, pseudocount 1.0) over the subgraph
 *       spanned by those blankets plus the resimulated nodes;</li>
 *   <li>for each row, redraws every resimulated node from its conditional distribution given
 *       the row's observed blanket values. Resimulated nodes redrawn earlier in the same row are
 *       conditioned on at their new values.</li>
 * </ol>
 *
 * <p>{@link #hybridsimulate()} overwrites the resimulated columns of the supplied data set in
 * place and returns that same object.
 */
public class Hsim {
    private boolean verbose;
    private final Dag mydag;
    private final Set<Node> simnodes;
    private final DataSet data;

    /**
     * Creates a hybrid simulator.
     *
     * @param thedag      the DAG whose structure defines the Markov blankets.
     * @param thesimnodes the nodes whose values will be redrawn. The set is stored by reference;
     *                    its iteration order is the order in which nodes are redrawn in each row.
     * @param thedata     the (non-continuous) data set to fit and to resimulate in place.
     * @throws IllegalArgumentException if any argument is null or the data set is continuous.
     */
    public Hsim(Dag thedag, Set<Node> thesimnodes, DataSet thedata) {
        // Check for null before dereferencing, so callers get an IAE rather than an NPE.
        if (thedata == null) {
            throw new IllegalArgumentException("Hsim needs a data set.");
        }
        if (thedata.isContinuous()) {
            throw new IllegalArgumentException("Hsim currently only accepts discrete data.");
        }
        if (thedag == null) {
            throw new IllegalArgumentException("Hsim needs a Dag.");
        }
        if (thesimnodes == null) {
            throw new IllegalArgumentException("Please specify the nodes Hsim will resimulate.");
        }
        this.verbose = false;
        this.mydag = thedag;
        this.simnodes = thesimnodes;
        this.data = thedata;
    }

    /**
     * Enables or disables progress messages printed to standard output by
     * {@link #hybridsimulate()}. Output is disabled by default.
     *
     * @param verbose true to print progress messages.
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Returns the Markov blanket of {@code z} in {@code graph}: its adjacent nodes plus the other
     * parents of its children. {@code z} itself is excluded.
     */
    private static Set<Node> mb(Graph graph, Node z) {
        Set<Node> blanket = new HashSet<>(graph.getAdjacentNodes(z));
        for (Node child : graph.getChildren(z)) {
            blanket.addAll(graph.getParents(child));
        }
        blanket.remove(z);
        return blanket;
    }

    /**
     * Redraws the resimulated nodes in every row of the data set.
     *
     * <p>All lookups (evidence indices and data columns) are resolved and validated before any
     * row is modified. An unknown variable therefore fails fast instead of leaving the data
     * partially resimulated.
     *
     * @return the same data set passed to the constructor, modified in place.
     * @throws IllegalArgumentException if a needed variable is missing from the model or the data.
     * @throws IllegalStateException    if a conditional distribution cannot be sampled.
     */
    public DataSet hybridsimulate() {
        log("Finding a Markov blanket for resimulated nodes");
        Set<Node> mbAll = new HashSet<>();
        for (Node node : this.simnodes) {
            mbAll.addAll(mb(this.mydag, node));
        }
        mbAll.addAll(this.simnodes);
        log("The Markov Blanket is " + mbAll);

        log("Finding a subgraph over the Markov Blanket and Resimulated Nodes");
        Graph subgraph = this.mydag.subgraph(new ArrayList<>(mbAll));

        log("Learning an instantiated model for the subgraph");
        // NOTE(review): the raw data integers below are used directly as evidence category
        // indices, which assumes the categories of new BayesPm(subgraph) line up with the data's
        // categories — confirm against BayesPm / DirichletEstimator.
        BayesPm subgraphPM = new BayesPm(subgraph);
        DirichletBayesIm subgraphIM = DirichletBayesIm.symmetricDirichletIm(subgraphPM, 1.0);
        DirichletBayesIm fittedsubgraphIM = DirichletEstimator.estimate(subgraphIM, this.data);

        // Evidence indices depend only on the fitted IM, so resolve them once from a template.
        Evidence template = Evidence.tautology(fittedsubgraphIM);

        // Blanket nodes that are NOT resimulated keep their observed values and are conditioned on.
        Set<Node> conditioningSet = new HashSet<>(mbAll);
        for (Node node : this.simnodes) {
            conditioningSet.remove(node);
        }
        int[] condIndex = new int[conditioningSet.size()];
        int[] condColumn = new int[conditioningSet.size()];
        int k = 0;
        for (Node node : conditioningSet) {
            condIndex[k] = evidenceIndexOf(template, node);
            condColumn[k] = columnOf(node);
            k++;
        }

        // toArray preserves the set's iteration order, which is the redraw order within a row.
        Node[] resim = this.simnodes.toArray(new Node[0]);
        int[] resimIndex = new int[resim.length];
        int[] resimColumn = new int[resim.length];
        for (int s = 0; s < resim.length; ++s) {
            resimIndex[s] = evidenceIndexOf(template, resim[s]);
            resimColumn[s] = columnOf(resim[s]);
        }

        log("Starting resimulation loop");
        RandomUtil random = RandomUtil.getInstance();
        for (int row = 0; row < this.data.getNumRows(); ++row) {
            Evidence evidence = Evidence.tautology(fittedsubgraphIM);
            for (int c = 0; c < condIndex.length; ++c) {
                evidence.getProposition().setCategory(condIndex[c], this.data.getInt(row, condColumn[c]));
            }
            for (int s = 0; s < resim.length; ++s) {
                RowSummingExactUpdater conditionUpdate =
                        new RowSummingExactUpdater(fittedsubgraphIM, evidence);
                int numCat = evidence.getNumCategories(resimIndex[s]);
                int newValue = drawCategory(conditionUpdate, resimIndex[s], numCat,
                        random.nextDouble(), resim[s]);
                this.data.setInt(row, resimColumn[s], newValue);
                // Later resimulated nodes in this row are conditioned on this freshly drawn value.
                evidence.getProposition().setCategory(resimIndex[s], newValue);
            }
        }
        return this.data;
    }

    /**
     * Samples a category index by inverse-CDF sampling from the updater's marginal for one node.
     *
     * <p>Returns the first category whose cumulative probability exceeds {@code cutoff}. If
     * rounding leaves the total mass just below {@code cutoff}, the last category is returned
     * rather than an out-of-range sentinel.
     *
     * @throws IllegalStateException if the marginal has no positive mass (including NaN).
     */
    private static int drawCategory(RowSummingExactUpdater updater, int nodeIndex, int numCat,
                                    double cutoff, Node node) {
        double sum = 0.0;
        for (int category = 0; category < numCat; ++category) {
            sum += updater.getMarginal(nodeIndex, category);
            if (cutoff < sum) {
                return category;
            }
        }
        // The negated comparison also rejects NaN.
        if (!(sum > 0.0)) {
            throw new IllegalStateException(
                    "Could not compute a conditional distribution for variable " + node.getName() + ".");
        }
        return numCat - 1;
    }

    /** Returns the evidence index for {@code node}, or throws if the fitted model lacks it. */
    private static int evidenceIndexOf(Evidence evidence, Node node) {
        int index = evidence.getNodeIndex(node.getName());
        if (index == -1) {
            throw new IllegalArgumentException("Variable " + node.getName() + " was not found.");
        }
        return index;
    }

    /**
     * Returns the data column for {@code node}. This assumes {@code DataSet.getColumn} reports an
     * absent variable with a negative index; verify against the DataSet implementation.
     */
    private int columnOf(Node node) {
        int column = this.data.getColumn(node);
        if (column < 0) {
            throw new IllegalArgumentException(
                    "Variable " + node.getName() + " was not found in the data.");
        }
        return column;
    }

    /** Prints {@code message} to standard output when verbose mode is on. */
    private void log(String message) {
        if (this.verbose) {
            System.out.println(message);
        }
    }
}

