/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.test;

import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import edu.pitt.dbmi.algo.bayesian.constraint.inference.BCInference;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * An {@link IndependenceTest} for discrete data that scores each conditional independence
 * constraint with {@link BCInference} and treats the result as the posterior probability that
 * the constraint holds.
 * <p>
 * Rows with a missing value (coded {@code -99}) in any of the tested variables are dropped
 * before scoring. Probabilities are cached per {@link IndependenceFact}. The independence
 * decision is either deterministic ({@code p >= cutoff}, see {@link #setThreshold(boolean)})
 * or, by default, a random draw that judges independence with probability {@code p}.
 * When verbose, judgments of independence are logged.
 */
public class IndTestProbabilistic
implements IndependenceTest {
    /** Value returned by {@code DataSet.getInt} in this data for a missing discrete value. */
    private static final int MISSING_VALUE = -99;

    private final DataSet data;
    private final List<Node> nodes;
    /** Column index in {@link #data} of each variable. */
    private final Map<Node, Integer> indices;
    /** Cache of computed independence probabilities, keyed by the tested fact. */
    private final Map<IndependenceFact, Double> H;
    /** Scorer over all rows of {@link #data}; rebuilt when the prior sample size changes. */
    private BCInference bci;
    /** If true, decide by {@link #cutoff}; otherwise decide by a random draw. */
    private boolean threshold;
    /** Probability computed by the most recent test that had at least one complete row. */
    private double posterior;
    private boolean verbose;
    private double cutoff = 0.5;
    private double priorEquivalentSampleSize = 10.0;

    /**
     * Creates a test over the given discrete data set.
     *
     * @param dataSet the data; every variable must be discrete.
     * @throws IllegalArgumentException if the data set is not discrete.
     */
    public IndTestProbabilistic(DataSet dataSet) {
        if (!dataSet.isDiscrete()) {
            throw new IllegalArgumentException("Not a discrete data set.");
        }
        this.data = dataSet;
        this.nodes = dataSet.getVariables();
        this.indices = new HashMap<Node, Integer>();
        for (int i = 0; i < this.nodes.size(); ++i) {
            this.indices.put(this.nodes.get(i), i);
        }
        this.H = new HashMap<IndependenceFact, Double>();
        // The full-data scorer uses every row and column in their original order, so the
        // column positions in `indices` line up with the scorer's variables without copying.
        this.bci = this.setup(dataSet);
    }

    /**
     * Builds a {@link BCInference} scorer for the given discrete data, using the current
     * prior equivalent sample size.
     * <p>
     * The arrays are filled 1-based: data column j is stored at position j + 1, data row i at
     * row i + 1, and each category value c as c + 1. Row 0, column 0 and the trailing extra
     * column are left at zero. NOTE(review): presumably this is the layout BCInference
     * expects; confirm against its documentation.
     *
     * @param dataSet discrete data to score.
     * @return a configured scorer.
     */
    private BCInference setup(DataSet dataSet) {
        int numRows = dataSet.getNumRows();
        int numColumns = dataSet.getNumColumns();
        int[] nodeDimensions = new int[numColumns + 2];
        for (int j = 0; j < numColumns; ++j) {
            DiscreteVariable variable = (DiscreteVariable) dataSet.getVariable(j);
            nodeDimensions[j + 1] = variable.getNumCategories();
        }
        int[][] cases = new int[numRows + 1][numColumns + 2];
        for (int i = 0; i < numRows; ++i) {
            for (int j = 0; j < numColumns; ++j) {
                cases[i + 1][j + 1] = dataSet.getInt(i, j) + 1;
            }
        }
        BCInference bci = new BCInference(cases, nodeDimensions);
        bci.setPriorEqivalentSampleSize(this.priorEquivalentSampleSize);
        return bci;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        throw new UnsupportedOperationException();
    }

    /**
     * Tests x _||_ y | z. The conditioning set is sorted first so that equal sets are always
     * passed on in the same order.
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> _z) {
        List<Node> z = new ArrayList<Node>(_z);
        Collections.sort(z);
        return this.checkIndependence(x, y, z.toArray(new Node[0]));
    }

    /**
     * Tests x _||_ y | z.
     * <p>
     * Only rows with no missing value in x, y or z are used. If there are none, the result is
     * "independent" with NaN p-value, and nothing is cached. Otherwise the probability of
     * independence is computed once per fact and cached. It is recorded as the posterior and
     * turned into a decision either by the cutoff or by a random draw.
     *
     * @throws RuntimeException if the computed probability is NaN.
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Node ... z) {
        IndependenceFact key = new IndependenceFact(x, y, z);
        List<Node> allVars = new ArrayList<Node>();
        allVars.add(x);
        allVars.add(y);
        Collections.addAll(allVars, z);
        List<Integer> rows = this.getRows(this.data, allVars, this.indices);
        if (rows.isEmpty()) {
            return new IndependenceResult(new IndependenceFact(x, y, GraphUtils.asSet(z)), true, Double.NaN, Double.NaN);
        }
        // Consult the cache before building any scorer: building one on a row subset copies
        // the data and is wasted work when the answer is already known.
        Double cached = this.H.get(key);
        double p;
        if (cached != null) {
            p = cached;
        } else {
            p = this.computeIndependenceProbability(x, y, z, allVars, rows);
            this.H.put(key, p);
        }
        if (Double.isNaN(p)) {
            throw new RuntimeException("Undefined p-value encountered when testing " + LogUtilsSearch.independenceFact(x, y, GraphUtils.asSet(z)));
        }
        this.posterior = p;
        boolean ind;
        if (this.threshold) {
            ind = p >= this.cutoff;
        } else {
            ind = RandomUtil.getInstance().nextDouble() < p;
        }
        if (this.verbose && ind) {
            TetradLogger.getInstance().forceLogMessage(LogUtilsSearch.independenceFactMsg(x, y, GraphUtils.asSet(z), p));
        }
        return new IndependenceResult(key, ind, p, Double.NaN);
    }

    /**
     * Computes the probability of x _||_ y | z on the given complete rows. The full-data
     * scorer is reused when no row was dropped. Otherwise a scorer is built over just those
     * rows and the tested columns.
     */
    private double computeIndependenceProbability(Node x, Node y, Node[] z, List<Node> allVars, List<Integer> rows) {
        if (rows.size() == this.data.getNumRows()) {
            return this.probConstraint(this.bci, BCInference.OP.independent, x, y, z, this.indices);
        }
        int[] cols = new int[allVars.size()];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = this.indices.get(allVars.get(i));
        }
        int[] rowArray = new int[rows.size()];
        for (int i = 0; i < rowArray.length; ++i) {
            rowArray[i] = rows.get(i);
        }
        DataSet subset = this.data.subsetRowsColumns(rowArray, cols);
        List<Node> subsetNodes = subset.getVariables();
        Map<Node, Integer> subsetIndices = new HashMap<Node, Integer>();
        for (int i = 0; i < subsetNodes.size(); ++i) {
            subsetIndices.put(subsetNodes.get(i), i);
        }
        return this.probConstraint(this.setup(subset), BCInference.OP.independent, x, y, z, subsetIndices);
    }

    /**
     * Asks the scorer for the probability of the given constraint between x and y given z.
     * Variables are translated to the scorer's 1-based positions. The conditioning array
     * holds the number of conditioning variables at index 0, followed by their positions.
     *
     * @param bci     scorer to query.
     * @param op      constraint to evaluate.
     * @param x       first variable.
     * @param y       second variable.
     * @param z       conditioning variables.
     * @param indices 0-based column of each variable in the data the scorer was built from.
     * @return the value returned by {@code bci.probConstraint}.
     */
    public double probConstraint(BCInference bci, BCInference.OP op, Node x, Node y, Node[] z, Map<Node, Integer> indices) {
        int _x = indices.get(x) + 1;
        int _y = indices.get(y) + 1;
        int[] _z = new int[z.length + 1];
        _z[0] = z.length;
        for (int i = 0; i < z.length; ++i) {
            _z[i + 1] = indices.get(z[i]) + 1;
        }
        return bci.probConstraint(op, _x, _y, _z);
    }

    /** Returns the variables of the data set. */
    @Override
    public List<Node> getVariables() {
        return this.nodes;
    }

    /** Returns the variable with the given name, or {@code null} if there is none. */
    @Override
    public Node getVariable(String name) {
        for (Node node : this.nodes) {
            if (name.equals(node.getName())) {
                return node;
            }
        }
        return null;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public boolean determines(Set<Node> z, Node y) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not supported; this test has no alpha level.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public double getAlpha() {
        throw new UnsupportedOperationException("The Probabilistic Test doesn't use an alpha parameter");
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public void setAlpha(double alpha) {
        throw new UnsupportedOperationException();
    }

    /** Returns the data set this test was built on. */
    @Override
    public DataModel getData() {
        return this.data;
    }

    /** Returns a copy of the cache of computed independence probabilities. */
    public Map<IndependenceFact, Double> getH() {
        return new HashMap<IndependenceFact, Double>(this.H);
    }

    /**
     * Returns the probability from the most recent test that had at least one complete row.
     */
    public double getPosterior() {
        return this.posterior;
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Chooses the decision rule.
     *
     * @param noRandomizedGeneratingConstraints true to judge independence iff
     *                                          {@code p >= cutoff}; false to judge it with
     *                                          probability {@code p}.
     */
    public void setThreshold(boolean noRandomizedGeneratingConstraints) {
        this.threshold = noRandomizedGeneratingConstraints;
    }

    /** Sets the cutoff used when the threshold decision rule is active (default 0.5). */
    public void setCutoff(double cutoff) {
        this.cutoff = cutoff;
    }

    /**
     * Sets the prior equivalent sample size used by the BCInference scorers (default 10).
     * <p>
     * The full-data scorer is rebuilt and cached probabilities are discarded. Later tests,
     * whether on all rows or on a subset, therefore all use the new value.
     */
    public void setPriorEquivalentSampleSize(double priorEquivalentSampleSize) {
        this.priorEquivalentSampleSize = priorEquivalentSampleSize;
        this.bci = this.setup(this.data);
        this.H.clear();
    }

    /**
     * Returns the indices, in increasing order, of rows with no missing value in any of the
     * given variables.
     */
    private List<Integer> getRows(DataSet dataSet, List<Node> allVars, Map<Node, Integer> nodesHash) {
        // Resolve column indices once instead of once per row.
        int[] cols = new int[allVars.size()];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = nodesHash.get(allVars.get(i));
        }
        List<Integer> rows = new ArrayList<Integer>();
        rowLoop:
        for (int k = 0; k < dataSet.getNumRows(); ++k) {
            for (int col : cols) {
                if (dataSet.getInt(k, col) == MISSING_VALUE) {
                    continue rowLoop;
                }
            }
            rows.add(k);
        }
        return rows;
    }
}

