/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.dbmi.algo.bayesian.constraint.search;

import edu.cmu.tetrad.bayes.BayesIm;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.DirichletBayesIm;
import edu.cmu.tetrad.bayes.DirichletEstimator;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.data.VerticalIntDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.EdgeTypeProbability;
import edu.cmu.tetrad.graph.Endpoint;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.BDeuScore;
import edu.cmu.tetrad.search.Fges;
import edu.cmu.tetrad.search.GraphSearch;
import edu.cmu.tetrad.search.IndTestProbabilistic;
import edu.cmu.tetrad.search.Rfci;
import edu.cmu.tetrad.search.SearchGraphUtils;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.TetradLogger;
import edu.pitt.dbmi.algo.bayesian.constraint.inference.BCInference;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import org.apache.commons.math3.util.FastMath;

public class RfciBsc
implements GraphSearch {
    private final Rfci rfci;
    private Graph graphRBD;
    private Graph graphRBI;
    private double bscD;
    private double bscI;
    private final List<Graph> pAGs = Collections.synchronizedList(new ArrayList());
    private int numRandomizedSearchModels = 10;
    private int numBscBootstrapSamples = 100;
    private double lowerBound = 0.3;
    private double upperBound = 0.7;
    private static final int MININUM_EXPONENT = -1022;
    private boolean outputRBD = true;
    private boolean verbose;
    private final TetradLogger logger = TetradLogger.getInstance();
    private PrintStream out = System.out;
    private boolean thresholdNoRandomDataSearch;
    private double cutoffDataSearch = 0.5;
    private boolean thresholdNoRandomConstrainSearch = true;
    private double cutoffConstrainSearch = 0.5;

    public RfciBsc(Rfci rfci) {
        this.rfci = rfci;
    }

    @Override
    public Graph search() {
        int trial;
        long stop = 0L;
        long start = MillisecondTimes.timeMillis();
        IndTestProbabilistic _test = (IndTestProbabilistic)this.rfci.getIndependenceTest();
        final DataSet dataSet = SimpleDataLoader.getDiscreteDataSet(_test.getData());
        this.pAGs.clear();
        final List<Node> vars = Collections.synchronizedList(new ArrayList());
        final List var_lookup = Collections.synchronizedList(new ArrayList());
        final ConcurrentHashMap h = new ConcurrentHashMap();
        final ConcurrentHashMap hCopy = new ConcurrentHashMap();
        ArrayList<Callable<Boolean>> tasks = new ArrayList<Callable<Boolean>>();
        int numCandidatePagSearchTrial = 1000;
        for (trial = 0; vars.size() == 0 && trial < numCandidatePagSearchTrial; ++trial) {
            tasks.clear();
            for (int i = 0; i < this.numRandomizedSearchModels; ++i) {
                class SearchPagTask
                implements Callable<Boolean> {
                    private final IndTestProbabilistic test;
                    private final Rfci rfci;

                    public SearchPagTask() {
                        this.test = new IndTestProbabilistic(dataSet);
                        this.test.setThreshold(RfciBsc.this.thresholdNoRandomDataSearch);
                        if (RfciBsc.this.thresholdNoRandomDataSearch) {
                            this.test.setCutoff(RfciBsc.this.cutoffDataSearch);
                        }
                        this.rfci = new Rfci(this.test);
                    }

                    @Override
                    public Boolean call() throws Exception {
                        Graph pag = this.rfci.search();
                        pag = GraphUtils.replaceNodes(pag, this.test.getVariables());
                        RfciBsc.this.pAGs.add(pag);
                        Map<IndependenceFact, Double> _h = this.test.getH();
                        for (IndependenceFact f : _h.keySet()) {
                            String indFact = f.toString();
                            if (hCopy.containsKey(f)) continue;
                            h.put(f, _h.get(f));
                            if (!(_h.get(f) > RfciBsc.this.lowerBound) || !(_h.get(f) < RfciBsc.this.upperBound)) continue;
                            hCopy.put(f, _h.get(f));
                            DiscreteVariable var = new DiscreteVariable(indFact);
                            if (vars.contains(var)) continue;
                            vars.add(var);
                            if (var_lookup.contains(indFact)) continue;
                            var_lookup.add(indFact);
                        }
                        return true;
                    }
                }
                tasks.add(new SearchPagTask());
            }
            ExecutorService pool = Executors.newWorkStealingPool(Runtime.getRuntime().availableProcessors());
            try {
                pool.invokeAll(tasks);
            }
            catch (InterruptedException exception) {
                if (this.verbose) {
                    this.logger.log("error", "Task has been interrupted");
                }
                Thread.currentThread().interrupt();
            }
            this.shutdownAndAwaitTermination(pool);
        }
        if (trial == numCandidatePagSearchTrial) {
            return new EdgeListGraph(dataSet.getVariables());
        }
        VerticalIntDataBox dataBox = new VerticalIntDataBox(this.numBscBootstrapSamples, vars.size());
        final BoxDataSet depData = new BoxDataSet(dataBox, vars);
        tasks.clear();
        int rows = dataSet.getNumRows();
        for (int b = 0; b < this.numBscBootstrapSamples; ++b) {
            class BootstrapDepDataTask
            implements Callable<Boolean> {
                private final int row_index;
                private final IndTestProbabilistic bsTest;

                public BootstrapDepDataTask(int row_index, int rows) {
                    this.row_index = row_index;
                    DataSet bsData = DataUtils.getBootstrapSample(dataSet, rows);
                    this.bsTest = new IndTestProbabilistic(bsData);
                    this.bsTest.setThreshold(RfciBsc.this.thresholdNoRandomConstrainSearch);
                    if (RfciBsc.this.thresholdNoRandomConstrainSearch) {
                        this.bsTest.setCutoff(RfciBsc.this.cutoffConstrainSearch);
                    }
                }

                /*
                 * WARNING - Removed try catching itself - possible behaviour change.
                 */
                @Override
                public Boolean call() throws Exception {
                    for (IndependenceFact f : hCopy.keySet()) {
                        boolean ind = this.bsTest.checkIndependence(f.getX(), f.getY(), f.getZ()).independent();
                        int value = ind ? 1 : 0;
                        String indFact = f.toString();
                        int col = var_lookup.indexOf(indFact);
                        DataSet dataSet2 = depData;
                        synchronized (dataSet2) {
                            depData.setInt(this.row_index, col, value);
                        }
                    }
                    return true;
                }
            }
            tasks.add(new BootstrapDepDataTask(b, rows));
        }
        ExecutorService pool = Executors.newWorkStealingPool(Runtime.getRuntime().availableProcessors());
        try {
            pool.invokeAll(tasks);
        }
        catch (InterruptedException exception) {
            if (this.verbose) {
                this.logger.log("error", "Task has been interrupted");
            }
            Thread.currentThread().interrupt();
        }
        this.shutdownAndAwaitTermination(pool);
        BDeuScore sd = new BDeuScore(depData);
        sd.setSamplePrior(1.0);
        sd.setStructurePrior(1.0);
        Fges fges = new Fges(sd);
        fges.setVerbose(false);
        fges.setFaithfulnessAssumed(true);
        Graph depPattern = fges.search();
        depPattern = GraphUtils.replaceNodes(depPattern, depData.getVariables());
        final Graph estDepBN = SearchGraphUtils.dagFromCPDAG(depPattern);
        if (this.verbose) {
            this.out.println("estDepBN:");
            this.out.println(estDepBN);
        }
        BayesPm pmHat = new BayesPm(estDepBN, 2, 2);
        DirichletBayesIm prior = DirichletBayesIm.symmetricDirichletIm(pmHat, 0.5);
        final DirichletBayesIm imHat = DirichletEstimator.estimate(prior, depData);
        final ConcurrentHashMap<Graph, Double> pagLnBSCD = new ConcurrentHashMap<Graph, Double>();
        final ConcurrentHashMap<Graph, Double> pagLnBSCI = new ConcurrentHashMap<Graph, Double>();
        double maxLnDep = -1.0;
        double maxLnInd = -1.0;
        tasks.clear();
        for (Graph pagOrig : this.pAGs) {
            class CalculateBscScoreTask
            implements Callable<Boolean> {
                final Graph pagOrig;

                public CalculateBscScoreTask(Graph pagOrig) {
                    this.pagOrig = pagOrig;
                }

                @Override
                public Boolean call() throws Exception {
                    if (!pagLnBSCD.containsKey(this.pagOrig)) {
                        double lnInd = RfciBsc.getLnProb(this.pagOrig, h);
                        double lnDep = RfciBsc.getLnProbUsingDepFiltering(this.pagOrig, h, imHat, estDepBN);
                        pagLnBSCD.put(this.pagOrig, lnDep);
                        pagLnBSCI.put(this.pagOrig, lnInd);
                    }
                    return true;
                }
            }
            tasks.add(new CalculateBscScoreTask(pagOrig));
        }
        pool = Executors.newWorkStealingPool(Runtime.getRuntime().availableProcessors());
        try {
            pool.invokeAll(tasks);
        }
        catch (InterruptedException exception) {
            if (this.verbose) {
                this.logger.log("error", "Task has been interrupted");
            }
            Thread.currentThread().interrupt();
        }
        this.shutdownAndAwaitTermination(pool);
        for (int i = 0; i < this.pAGs.size(); ++i) {
            Graph pagOrig;
            pagOrig = this.pAGs.get(i);
            double lnDep = (Double)pagLnBSCD.get(pagOrig);
            double lnInd = (Double)pagLnBSCI.get(pagOrig);
            if (lnInd > maxLnInd || i == 0) {
                maxLnInd = lnInd;
                this.graphRBI = pagOrig;
            }
            if (!(lnDep > maxLnDep) && i != 0) continue;
            maxLnDep = lnDep;
            this.graphRBD = pagOrig;
        }
        if (this.verbose) {
            this.out.println("maxLnDep: " + maxLnDep + " maxLnInd: " + maxLnInd);
        }
        double lnQBSCDTotal = RfciBsc.lnQTotal(pagLnBSCD);
        double lnQBSCITotal = RfciBsc.lnQTotal(pagLnBSCI);
        this.bscD = maxLnDep - lnQBSCDTotal;
        this.bscD = FastMath.exp(this.bscD);
        this.graphRBD.addAttribute("bscD", this.bscD);
        double _bscI = (Double)pagLnBSCI.get(this.graphRBD) - lnQBSCITotal;
        _bscI = FastMath.exp(_bscI);
        this.graphRBD.addAttribute("bscI", _bscI);
        double _bscD = (Double)pagLnBSCD.get(this.graphRBI) - lnQBSCDTotal;
        _bscD = FastMath.exp(_bscD);
        this.graphRBI.addAttribute("bscD", _bscD);
        this.bscI = maxLnInd - lnQBSCITotal;
        this.bscI = FastMath.exp(this.bscI);
        this.graphRBI.addAttribute("bscI", this.bscI);
        if (this.verbose) {
            this.out.println("bscD: " + this.bscD + " bscI: " + this.bscI);
            this.out.println("graphRBD:\n" + this.graphRBD);
            this.out.println("graphRBI:\n" + this.graphRBI);
            stop = MillisecondTimes.timeMillis();
            this.out.println("Elapsed " + (stop - start) + " ms");
        }
        Graph output = this.graphRBD;
        if (!this.outputRBD) {
            output = this.graphRBI;
        }
        return this.generateBootstrappingAttributes(output);
    }

    private Graph generateBootstrappingAttributes(Graph graph) {
        for (Edge edge : graph.getEdges()) {
            Node nodeA = edge.getNode1();
            Node nodeB = edge.getNode2();
            List<EdgeTypeProbability> edgeTypeProbabilities = this.getProbability(nodeA, nodeB);
            for (EdgeTypeProbability etp : edgeTypeProbabilities) {
                edge.addEdgeTypeProbability(etp);
            }
        }
        return graph;
    }

    private List<EdgeTypeProbability> getProbability(Node node1, Node node2) {
        HashMap<String, Integer> edgeDist = new HashMap<String, Integer>();
        int no_edge_num = 0;
        for (Graph g : this.pAGs) {
            Edge e = g.getEdge(node1, node2);
            if (e != null) {
                Integer num_edge;
                String edgeString = e.toString();
                if (e.getEndpoint1() == e.getEndpoint2() && node1.compareTo(e.getNode1()) != 0) {
                    Edge edge = new Edge(node1, node2, e.getEndpoint1(), e.getEndpoint2());
                    for (Edge.Property property : e.getProperties()) {
                        edge.addProperty(property);
                    }
                    edgeString = edge.toString();
                }
                if ((num_edge = (Integer)edgeDist.get(edgeString)) == null) {
                    num_edge = 0;
                }
                num_edge = num_edge + 1;
                edgeDist.put(edgeString, num_edge);
                continue;
            }
            ++no_edge_num;
        }
        int n = this.pAGs.size();
        ArrayList<EdgeTypeProbability> edgeTypeProbabilities = edgeDist.size() == 0 ? null : new ArrayList<EdgeTypeProbability>();
        for (String edgeString : edgeDist.keySet()) {
            Endpoint _end2;
            Endpoint _end1;
            int edge_num = (Integer)edgeDist.get(edgeString);
            double probability = (double)edge_num / (double)n;
            String[] token = edgeString.split("\\s+");
            String n1 = token[0];
            String arc = token[1];
            String n2 = token[2];
            char end1 = arc.charAt(0);
            char end2 = arc.charAt(2);
            if (end1 == '<') {
                _end1 = Endpoint.ARROW;
            } else if (end1 == 'o') {
                _end1 = Endpoint.CIRCLE;
            } else if (end1 == '-') {
                _end1 = Endpoint.TAIL;
            } else {
                throw new IllegalArgumentException();
            }
            if (end2 == '>') {
                _end2 = Endpoint.ARROW;
            } else if (end2 == 'o') {
                _end2 = Endpoint.CIRCLE;
            } else if (end2 == '-') {
                _end2 = Endpoint.TAIL;
            } else {
                throw new IllegalArgumentException();
            }
            if (node1.getName().equalsIgnoreCase(n2) && node2.getName().equalsIgnoreCase(n1)) {
                Endpoint tmp = _end1;
                _end1 = _end2;
                _end2 = tmp;
            }
            EdgeTypeProbability.EdgeType edgeType = EdgeTypeProbability.EdgeType.nil;
            if (_end1 == Endpoint.TAIL && _end2 == Endpoint.ARROW) {
                edgeType = EdgeTypeProbability.EdgeType.ta;
            }
            if (_end1 == Endpoint.ARROW && _end2 == Endpoint.TAIL) {
                edgeType = EdgeTypeProbability.EdgeType.at;
            }
            if (_end1 == Endpoint.CIRCLE && _end2 == Endpoint.ARROW) {
                edgeType = EdgeTypeProbability.EdgeType.ca;
            }
            if (_end1 == Endpoint.ARROW && _end2 == Endpoint.CIRCLE) {
                edgeType = EdgeTypeProbability.EdgeType.ac;
            }
            if (_end1 == Endpoint.CIRCLE && _end2 == Endpoint.CIRCLE) {
                edgeType = EdgeTypeProbability.EdgeType.cc;
            }
            if (_end1 == Endpoint.ARROW && _end2 == Endpoint.ARROW) {
                edgeType = EdgeTypeProbability.EdgeType.aa;
            }
            if (_end1 == Endpoint.TAIL && _end2 == Endpoint.TAIL) {
                edgeType = EdgeTypeProbability.EdgeType.tt;
            }
            EdgeTypeProbability etp = new EdgeTypeProbability(edgeType, probability);
            if (token.length > 3) {
                for (int i = 3; i < token.length; ++i) {
                    etp.addProperty(Edge.Property.valueOf(token[i]));
                }
            }
            edgeTypeProbabilities.add(etp);
        }
        if (no_edge_num < n && edgeTypeProbabilities != null) {
            edgeTypeProbabilities.add(new EdgeTypeProbability(EdgeTypeProbability.EdgeType.nil, (double)no_edge_num / (double)n));
        }
        return edgeTypeProbabilities;
    }

    private static double lnXplusY(double lnX, double lnY) {
        double lnYminusLnX;
        if (lnY > lnX) {
            double temp = lnX;
            lnX = lnY;
            lnY = temp;
        }
        if ((lnYminusLnX = lnY - lnX) < -1022.0) {
            return lnX;
        }
        double w = FastMath.log1p(FastMath.exp(lnYminusLnX));
        return w + lnX;
    }

    private static double lnQTotal(Map<Graph, Double> pagLnProb) {
        Set<Graph> pags = pagLnProb.keySet();
        Iterator<Graph> iter = pags.iterator();
        double lnQTotal = pagLnProb.get(iter.next());
        while (iter.hasNext()) {
            Graph pag = iter.next();
            double lnQ = pagLnProb.get(pag);
            lnQTotal = RfciBsc.lnXplusY(lnQTotal, lnQ);
        }
        return lnQTotal;
    }

    private static double getLnProbUsingDepFiltering(Graph pag, Map<IndependenceFact, Double> H, BayesIm im, Graph dep) {
        double lnQ = 0.0;
        for (IndependenceFact fact : H.keySet()) {
            double v;
            double p = 0.0;
            BCInference.OP op = pag.paths().isDSeparatedFrom(fact.getX(), fact.getY(), fact.getZ()) ? BCInference.OP.independent : BCInference.OP.dependent;
            if (im.getNode(fact.toString()) != null) {
                Node node = im.getNode(fact.toString());
                int[] parents = im.getParents(im.getNodeIndex(node));
                if (parents.length > 0) {
                    int[] parentValues = new int[parents.length];
                    for (int parentIndex = 0; parentIndex < parentValues.length; ++parentIndex) {
                        String parentName = im.getNode(parents[parentIndex]).getName();
                        String[] splitParent = parentName.split(Pattern.quote("_||_"));
                        Node _X = pag.getNode(splitParent[0].trim());
                        String[] splitParent2 = splitParent[1].trim().split(Pattern.quote("|"));
                        Node _Y = pag.getNode(splitParent2[0].trim());
                        ArrayList<Node> _Z = new ArrayList<Node>();
                        if (splitParent2.length > 1) {
                            String[] splitParent3;
                            for (String s : splitParent3 = splitParent2[1].trim().split(Pattern.quote(","))) {
                                _Z.add(pag.getNode(s.trim()));
                            }
                        }
                        IndependenceFact parentFact = new IndependenceFact(_X, _Y, _Z);
                        parentValues[parentIndex] = pag.paths().isDSeparatedFrom(parentFact.getX(), parentFact.getY(), parentFact.getZ()) ? 1 : 0;
                    }
                    int rowIndex = im.getRowIndex(im.getNodeIndex(node), parentValues);
                    p = im.getProbability(im.getNodeIndex(node), rowIndex, 1);
                } else {
                    p = im.getProbability(im.getNodeIndex(node), 0, 1);
                }
                if (op == BCInference.OP.dependent) {
                    p = 1.0 - p;
                }
                if (p < -1.0E-4 || p > 1.0001 || Double.isNaN(p) || Double.isInfinite(p)) {
                    throw new IllegalArgumentException("p illegally equals " + p);
                }
                double v2 = lnQ + FastMath.log(p);
                if (Double.isNaN(v2) || Double.isInfinite(v2)) continue;
                lnQ = v2;
                continue;
            }
            p = H.get(fact);
            if (p < -1.0E-4 || p > 1.0001 || Double.isNaN(p) || Double.isInfinite(p)) {
                throw new IllegalArgumentException("p illegally equals " + p);
            }
            if (op == BCInference.OP.dependent) {
                p = 1.0 - p;
            }
            if (Double.isNaN(v = lnQ + FastMath.log(p)) || Double.isInfinite(v)) continue;
            lnQ = v;
        }
        return lnQ;
    }

    private static double getLnProb(Graph pag, Map<IndependenceFact, Double> H) {
        double lnQ = 0.0;
        for (IndependenceFact fact : H.keySet()) {
            double v;
            BCInference.OP op = pag.paths().isDSeparatedFrom(fact.getX(), fact.getY(), fact.getZ()) ? BCInference.OP.independent : BCInference.OP.dependent;
            double p = H.get(fact);
            if (p < -1.0E-4 || p > 1.0001 || Double.isNaN(p) || Double.isInfinite(p)) {
                throw new IllegalArgumentException("p illegally equals " + p);
            }
            if (op == BCInference.OP.dependent) {
                p = 1.0 - p;
            }
            if (Double.isNaN(v = lnQ + FastMath.log(p)) || Double.isInfinite(v)) continue;
            lnQ = v;
        }
        return lnQ;
    }

    public void setNumRandomizedSearchModels(int numRandomizedSearchModels) {
        this.numRandomizedSearchModels = numRandomizedSearchModels;
    }

    public void setNumBscBootstrapSamples(int numBscBootstrapSamples) {
        this.numBscBootstrapSamples = numBscBootstrapSamples;
    }

    public void setLowerBound(double lowerBound) {
        this.lowerBound = lowerBound;
    }

    public void setUpperBound(double upperBound) {
        this.upperBound = upperBound;
    }

    public void setOutputRBD(boolean outputRBD) {
        this.outputRBD = outputRBD;
    }

    public Graph getGraphRBD() {
        return this.graphRBD;
    }

    public Graph getGraphRBI() {
        return this.graphRBI;
    }

    public double getBscD() {
        return this.bscD;
    }

    public double getBscI() {
        return this.bscI;
    }

    private void shutdownAndAwaitTermination(ExecutorService pool) {
        pool.shutdown();
        try {
            if (!pool.awaitTermination(1L, TimeUnit.SECONDS)) {
                pool.shutdownNow();
                if (!pool.awaitTermination(1L, TimeUnit.SECONDS)) {
                    System.err.println("Pool did not terminate");
                }
            }
        }
        catch (InterruptedException ie) {
            pool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public void setOut(PrintStream out) {
        this.out = out;
    }

    public PrintStream getOut() {
        return this.out;
    }

    public void setThresholdNoRandomDataSearch(boolean thresholdNoRandomDataSearch) {
        this.thresholdNoRandomDataSearch = thresholdNoRandomDataSearch;
    }

    public void setCutoffDataSearch(double cutoffDataSearch) {
        this.cutoffDataSearch = cutoffDataSearch;
    }

    public void setThresholdNoRandomConstrainSearch(boolean thresholdNoRandomConstrainSearch) {
        this.thresholdNoRandomConstrainSearch = thresholdNoRandomConstrainSearch;
    }

    public void setCutoffConstrainSearch(double cutoffConstrainSearch) {
        this.cutoffConstrainSearch = cutoffConstrainSearch;
    }
}

