/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.data;

import cern.colt.list.DoubleArrayList;
import edu.cmu.tetrad.data.AndersonDarlingTest;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Discretizer;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.util.ForkJoinPoolInstance;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.Vector;
import java.rmi.MarshalledObject;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.TimeUnit;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.exception.OutOfRangeException;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.commons.math3.util.FastMath;

public final class DataUtils {
    /**
     * Copies the values of one variable from {@code source} into {@code dest}, row by row.
     * Only the first {@code min(source rows, dest rows)} rows are copied; any further rows of
     * {@code dest} are left untouched.
     *
     * @param node   the variable to copy; its column is looked up in both data sets
     * @param source the data set values are read from
     * @param dest   the data set values are written to
     * @throws NullPointerException     if {@code node} is not a column of {@code source} or {@code dest}
     * @throws IllegalArgumentException if {@code node} is neither continuous nor discrete
     */
    public static void copyColumn(Node node, DataSet source, DataSet dest) {
        int sourceColumn = source.getColumn(node);
        int destColumn = dest.getColumn(node);
        // NullPointerException (rather than IllegalArgumentException) is kept for existing callers.
        if (sourceColumn < 0) {
            throw new NullPointerException("The given node was not in the source dataset");
        }
        if (destColumn < 0) {
            throw new NullPointerException("The given node was not in the destination dataset");
        }
        int rowsToCopy = Math.min(source.getNumRows(), dest.getNumRows());
        if (node instanceof ContinuousVariable) {
            for (int i = 0; i < rowsToCopy; ++i) {
                dest.setDouble(i, destColumn, source.getDouble(i, sourceColumn));
            }
        } else if (node instanceof DiscreteVariable) {
            for (int i = 0; i < rowsToCopy; ++i) {
                dest.setInt(i, destColumn, source.getInt(i, sourceColumn));
            }
        } else {
            throw new IllegalArgumentException("The given variable must be discrete or continuous");
        }
    }

    /**
     * Reports whether every value in the given column is 0 or 1.
     * Discrete columns are read with {@code getInt}, continuous columns with {@code getDouble};
     * a NaN value in a continuous column makes the column non-binary.
     *
     * @param data   the data set to inspect
     * @param column index of the column to check
     * @return true if all values are 0 or 1 (vacuously true for zero rows)
     * @throws IllegalArgumentException if the column's variable is neither discrete nor continuous
     */
    public static boolean isBinary(DataSet data, int column) {
        Node variable = data.getVariable(column);
        int numRows = data.getNumRows();

        if (variable instanceof DiscreteVariable) {
            for (int row = 0; row < numRows; row++) {
                int v = data.getInt(row, column);
                if (v != 0 && v != 1) {
                    return false;
                }
            }
            return true;
        }

        if (variable instanceof ContinuousVariable) {
            for (int row = 0; row < numRows; row++) {
                double v = data.getDouble(row, column);
                if (v != 0.0 && v != 1.0) {
                    return false;
                }
            }
            return true;
        }

        throw new IllegalArgumentException("The given column is not discrete or continuous");
    }

    /**
     * Returns the default category name for a category index: its decimal string form.
     *
     * @param index the category index
     * @return the index rendered in base 10, e.g. {@code "3"}
     */
    public static String defaultCategory(int index) {
        return String.valueOf(index);
    }

    /**
     * Returns a copy of {@code inData} in which each cell of column j has been made missing
     * independently with probability {@code probs[j]}. Continuous cells are set to NaN; discrete
     * cells are set to -99, the discrete missing-value code used elsewhere in this class.
     * Columns whose variable is neither continuous nor discrete are left unchanged.
     *
     * @param inData the data set to copy; it is not modified
     * @param probs  one probability in [0, 1] per column
     * @return the copy with missing values inserted
     * @throws IllegalArgumentException if {@code probs} has the wrong length or contains a
     *                                  value outside [0, 1] (including NaN)
     */
    public static DataSet addMissingData(DataSet inData, double[] probs) {
        // Validate before paying for a copy of the whole data set.
        if (probs.length != inData.getNumColumns()) {
            throw new IllegalArgumentException("Wrong number of elements in prob array");
        }
        for (double prob : probs) {
            // Phrased positively so NaN is rejected too; NaN would otherwise mean "never missing".
            if (!(prob >= 0.0 && prob <= 1.0)) {
                throw new IllegalArgumentException("Probability out of range");
            }
        }

        DataSet outData = inData.copy();
        for (int j = 0; j < outData.getNumColumns(); ++j) {
            Node node = outData.getVariable(j);
            if (node instanceof ContinuousVariable) {
                for (int i = 0; i < outData.getNumRows(); ++i) {
                    if (RandomUtil.getInstance().nextDouble() < probs[j]) {
                        outData.setDouble(i, j, Double.NaN);
                    }
                }
            } else if (node instanceof DiscreteVariable) {
                for (int i = 0; i < outData.getNumRows(); ++i) {
                    if (RandomUtil.getInstance().nextDouble() < probs[j]) {
                        outData.setInt(i, j, -99);
                    }
                }
            }
        }
        return outData;
    }

    /**
     * Returns a deep copy of {@code inData} in which missing values are replaced by random ones.
     * <ul>
     *   <li>Discrete columns: each -99 cell is replaced by a value drawn uniformly from the
     *       column's observed (non -99) values, so frequent values are drawn more often.</li>
     *   <li>Other columns: each NaN cell is replaced by a value drawn uniformly from
     *       [min, max) of the column's observed (non-NaN) values.</li>
     * </ul>
     * Observed values are never changed. A column with no observed values is left as is,
     * because there is nothing to sample from.
     *
     * @param inData the data set to copy; it is not modified
     * @return the copy with missing values filled in
     * @throws RuntimeException wrapping any failure of the serialization-based copy
     */
    public static DataSet replaceMissingWithRandom(DataSet inData) {
        DataSet outData;
        try {
            // Deep copy via Java serialization.
            outData = new MarshalledObject<DataSet>(inData).get();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        int numRows = outData.getNumRows();
        for (int j = 0; j < outData.getNumColumns(); ++j) {
            if (outData.getVariable(j) instanceof DiscreteVariable) {
                List<Integer> observed = new ArrayList<>();
                for (int i = 0; i < numRows; ++i) {
                    int value = outData.getInt(i, j);
                    if (value != -99) {
                        observed.add(value);
                    }
                }
                if (observed.isEmpty()) {
                    continue;
                }
                // Sorted so a given random stream maps to the same values regardless of row order.
                Collections.sort(observed);
                for (int i = 0; i < numRows; ++i) {
                    if (outData.getInt(i, j) == -99) {
                        int pick = RandomUtil.getInstance().nextInt(observed.size());
                        outData.setInt(i, j, observed.get(pick));
                    }
                }
                continue;
            }

            double min = Double.POSITIVE_INFINITY;
            double max = Double.NEGATIVE_INFINITY;
            boolean hasMissing = false;
            for (int i = 0; i < numRows; ++i) {
                double value = outData.getDouble(i, j);
                if (Double.isNaN(value)) {
                    hasMissing = true;
                    continue;
                }
                if (value < min) {
                    min = value;
                }
                if (value > max) {
                    max = value;
                }
            }
            // Nothing to replace, or no observed values to define a range.
            if (!hasMissing || min > max) {
                continue;
            }
            for (int i = 0; i < numRows; ++i) {
                if (Double.isNaN(outData.getDouble(i, j))) {
                    double random = RandomUtil.getInstance().nextDouble();
                    outData.setDouble(i, j, min + random * (max - min));
                }
            }
        }
        return outData;
    }

    /**
     * Builds a tiny example data set: one binary discrete variable "X" with two rows
     * holding the values 0 and 1.
     *
     * @return the two-row example data set
     */
    public static DataSet discreteSerializableInstance() {
        List<Node> vars = new LinkedList<>();
        vars.add(new DiscreteVariable("X", 2));
        DataSet instance = new BoxDataSet(new VerticalDoubleDataBox(2, vars.size()), vars);
        for (int row = 0; row < 2; row++) {
            instance.setInt(row, 0, row);
        }
        return instance;
    }

    /**
     * Reports whether any entry of the matrix is NaN.
     *
     * @param data the matrix to scan
     * @return true as soon as a NaN entry is found, false otherwise
     */
    public static boolean containsMissingValue(Matrix data) {
        int numRows = data.rows();
        int numCols = data.columns();
        for (int r = 0; r < numRows; r++) {
            for (int c = 0; c < numCols; c++) {
                if (Double.isNaN(data.get(r, c))) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Reports whether the data set contains a missing value: NaN in a continuous column
     * or -99 in a discrete column. Columns of any other variable type are ignored.
     *
     * @param data the data set to scan
     * @return true if at least one missing value is present
     */
    public static boolean containsMissingValue(DataSet data) {
        for (int col = 0; col < data.getNumColumns(); col++) {
            Node variable = data.getVariable(col);
            boolean continuous = variable instanceof ContinuousVariable;
            boolean discrete = variable instanceof DiscreteVariable;
            if (!continuous && !discrete) {
                continue;
            }
            for (int row = 0; row < data.getNumRows(); row++) {
                if (continuous && Double.isNaN(data.getDouble(row, col))) {
                    return true;
                }
                if (discrete && data.getInt(row, col) == -99) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Log-transforms (or inverts a log transform of) every non-discrete column.
     * <ul>
     *   <li>{@code isUnlog == false}: x -> log(a + x), natural log when {@code base == 0},
     *       otherwise log base {@code base}.</li>
     *   <li>{@code isUnlog == true}: x -> exp(x) - a when {@code base == 0},
     *       otherwise base^x - a.</li>
     * </ul>
     * Discrete columns are copied through unchanged.
     *
     * @param dataSet the input data; not modified
     * @param a       shift added before taking the log (subtracted after exponentiating)
     * @param isUnlog true to apply the inverse transform
     * @param base    log base, or 0 for the natural log
     * @return a new data set with the same variables and transformed values
     */
    public static DataSet logData(DataSet dataSet, double a, boolean isUnlog, int base) {
        Matrix data = dataSet.getDoubleData();
        Matrix transformed = data.like();
        int numRows = data.rows();
        // Only used when base != 0.
        double lnBase = FastMath.log(base);

        for (int col = 0; col < data.columns(); col++) {
            double[] values = Arrays.copyOf(data.getColumn(col).toArray(), numRows);
            if (!(dataSet.getVariable(col) instanceof DiscreteVariable)) {
                // Each entry depends only on its own original value, so transform in place.
                for (int row = 0; row < values.length; row++) {
                    double x = values[row];
                    if (isUnlog) {
                        double raised = base == 0 ? FastMath.exp(x) : FastMath.pow((double) base, x);
                        values[row] = raised - a;
                    } else {
                        double ln = FastMath.log(a + x);
                        values[row] = base == 0 ? ln : ln / lnBase;
                    }
                }
            }
            transformed.assignColumn(col, new Vector(values));
        }
        return new BoxDataSet(new VerticalDoubleDataBox(transformed.transpose().toArray()), dataSet.getVariables());
    }

    /**
     * Returns a copy of {@code data} in which each column is centered on its mean and divided
     * by its sample standard deviation (n - 1 denominator). There is no guard against a
     * zero-variance column or a single row; those divide by zero.
     *
     * @param data the input matrix; not modified
     * @return the standardized copy
     */
    public static Matrix standardizeData(Matrix data) {
        Matrix result = data.copy();
        int n = data.rows();
        for (int col = 0; col < result.columns(); col++) {
            double total = 0.0;
            for (int row = 0; row < n; row++) {
                total += result.get(row, col);
            }
            double mean = total / n;

            double sumSquares = 0.0;
            for (int row = 0; row < n; row++) {
                double centered = data.get(row, col) - mean;
                result.set(row, col, centered);
                sumSquares += centered * centered;
            }

            double sd = FastMath.sqrt(sumSquares / (n - 1));
            for (int row = 0; row < n; row++) {
                result.set(row, col, result.get(row, col) / sd);
            }
        }
        return result;
    }

    /**
     * Returns a new array holding {@code data} centered on its mean and divided by its sample
     * standard deviation (n - 1 denominator). Zero variance or a single element divides by zero.
     *
     * @param data the input values; not modified
     * @return the standardized values
     */
    public static double[] standardizeData(double[] data) {
        int n = data.length;
        double total = 0.0;
        for (int i = 0; i < n; i++) {
            total += data[i];
        }
        double mean = total / n;

        double[] standardized = new double[n];
        double sumSquares = 0.0;
        for (int i = 0; i < n; i++) {
            standardized[i] = data[i] - mean;
            sumSquares += standardized[i] * standardized[i];
        }

        double sd = FastMath.sqrt(sumSquares / (n - 1));
        for (int i = 0; i < n; i++) {
            standardized[i] /= sd;
        }
        return standardized;
    }

    /**
     * Returns a new list holding {@code data} centered on its mean and divided by its sample
     * standard deviation (n - 1 denominator). Zero variance or a single element divides by zero.
     *
     * @param data the input values; not modified
     * @return the standardized values
     */
    public static DoubleArrayList standardizeData(DoubleArrayList data) {
        int n = data.size();
        double total = 0.0;
        for (int k = 0; k < n; k++) {
            total += data.get(k);
        }
        double mean = total / n;

        DoubleArrayList standardized = new DoubleArrayList(n);
        double sumSquares = 0.0;
        for (int k = 0; k < n; k++) {
            double centered = data.get(k) - mean;
            standardized.add(centered);
            sumSquares += centered * centered;
        }

        double sd = FastMath.sqrt(sumSquares / (n - 1));
        for (int k = 0; k < n; k++) {
            standardized.set(k, standardized.get(k) / sd);
        }
        return standardized;
    }

    /**
     * Standardizes every column of every data set; see {@code standardizeData(Matrix)}.
     * Each result is a new data set over the same variable list.
     *
     * @param dataSets the data sets to standardize; all must be continuous
     * @return the standardized data sets, in input order
     * @throws IllegalArgumentException if any data set is not continuous
     */
    public static List<DataSet> standardizeData(List<DataSet> dataSets) {
        List<DataSet> standardized = new ArrayList<>();
        for (DataSet ds : dataSets) {
            if (!ds.isContinuous()) {
                throw new IllegalArgumentException("Not a continuous data set: " + ds.getName());
            }
            Matrix scaled = DataUtils.standardizeData(ds.getDoubleData());
            standardized.add(new BoxDataSet(new VerticalDoubleDataBox(scaled.transpose().toArray()), ds.getVariables()));
        }
        return standardized;
    }

    /**
     * Standardizes a single continuous data set; a convenience wrapper around the list form.
     *
     * @param dataSet the data set to standardize
     * @return the standardized copy
     * @throws IllegalArgumentException if the data set is not continuous
     */
    public static DataSet standardizeData(DataSet dataSet) {
        List<DataSet> single = Collections.singletonList(dataSet);
        return DataUtils.standardizeData(single).get(0);
    }

    /**
     * Returns a new array holding {@code d} minus its mean.
     *
     * @param d the input values; not modified
     * @return the mean-centered values (empty for empty input)
     */
    public static double[] center(double[] d) {
        double total = 0.0;
        for (int i = 0; i < d.length; i++) {
            total += d[i];
        }
        double mean = total / d.length;

        double[] centered = Arrays.copyOf(d, d.length);
        for (int i = 0; i < centered.length; i++) {
            centered[i] -= mean;
        }
        return centered;
    }

    /**
     * Returns a copy of {@code data} with each column's mean subtracted from that column.
     *
     * @param data the input matrix; not modified
     * @return the column-centered copy
     */
    public static Matrix centerData(Matrix data) {
        Matrix centered = data.copy();
        int numRows = data.rows();
        for (int col = 0; col < centered.columns(); col++) {
            double total = 0.0;
            for (int row = 0; row < numRows; row++) {
                total += centered.get(row, col);
            }
            double mean = total / numRows;
            for (int row = 0; row < numRows; row++) {
                centered.set(row, col, data.get(row, col) - mean);
            }
        }
        return centered;
    }

    /**
     * Column-centers each data set; see {@code centerData(Matrix)}. Each result is a new
     * data set over a fresh copy of the input's variable list.
     *
     * @param dataList the data sets to center; all must be non-null and continuous
     * @return the centered data sets, in input order
     * @throws NullPointerException     if an element of {@code dataList} is null
     * @throws IllegalArgumentException if a data set is not continuous
     */
    public static List<DataSet> center(List<DataSet> dataList) {
        List<DataSet> centered = new ArrayList<>();
        // Iterate over a snapshot of the caller's list.
        for (DataSet ds : new ArrayList<>(dataList)) {
            if (ds == null) {
                throw new NullPointerException("Missing dataset.");
            }
            if (!ds.isContinuous()) {
                throw new IllegalArgumentException("Not a continuous data set: " + ds.getName());
            }
            Matrix shifted = DataUtils.centerData(ds.getDoubleData());
            List<Node> vars = new ArrayList<>(ds.getVariables());
            centered.add(new BoxDataSet(new VerticalDoubleDataBox(shifted.transpose().toArray()), vars));
        }
        return centered;
    }

    /**
     * Discretizes every variable of {@code dataSet} using {@code Discretizer.equalIntervals}
     * with the same number of categories.
     *
     * @param dataSet         the data to discretize
     * @param numCategories   number of categories requested per variable
     * @param variablesCopied passed through to {@code Discretizer.setVariablesCopied}
     * @return the discretized data set produced by the discretizer
     */
    public static DataSet discretize(DataSet dataSet, int numCategories, boolean variablesCopied) {
        Discretizer worker = new Discretizer(dataSet);
        worker.setVariablesCopied(variablesCopied);
        List<Node> toDiscretize = dataSet.getVariables();
        for (Node variable : toDiscretize) {
            worker.equalIntervals(variable, numCategories);
        }
        return worker.discretize();
    }

    /**
     * Creates one {@code ContinuousVariable} per name, in the given order.
     *
     * @param varNames the variable names
     * @return a mutable list of the new variables
     */
    public static List<Node> createContinuousVariables(String[] varNames) {
        List<Node> nodes = new LinkedList<>();
        for (int i = 0; i < varNames.length; i++) {
            nodes.add(new ContinuousVariable(varNames[i]));
        }
        return nodes;
    }

    /**
     * Extracts the square submatrix of {@code m} for the variables x, y, z_1, ..., z_k,
     * in that order.
     *
     * @param m the covariance matrix to select from
     * @param x first variable
     * @param y second variable
     * @param z conditioning variables (may be empty)
     * @return the (2 + |z|) x (2 + |z|) selection
     * @throws NullPointerException     if x, y, z, or any element of z is null
     * @throws IllegalArgumentException if a requested variable is not in {@code m}, or the
     *                                  selection contains NaN (missing) entries
     */
    public static Matrix subMatrix(ICovarianceMatrix m, Node x, Node y, List<Node> z) {
        if (x == null) {
            throw new NullPointerException();
        }
        if (y == null) {
            throw new NullPointerException();
        }
        if (z == null) {
            throw new NullPointerException();
        }
        for (Node node : z) {
            if (node == null) {
                throw new NullPointerException();
            }
        }

        List<Node> variables = m.getVariables();
        int[] indices = new int[2 + z.size()];
        indices[0] = variables.indexOf(x);
        indices[1] = variables.indexOf(y);
        for (int i = 0; i < z.size(); ++i) {
            indices[i + 2] = variables.indexOf(z.get(i));
        }

        // indexOf yields -1 for a node m does not contain; report it instead of selecting with -1.
        for (int i = 0; i < indices.length; ++i) {
            if (indices[i] < 0) {
                Node absent = i == 0 ? x : (i == 1 ? y : z.get(i - 2));
                throw new IllegalArgumentException("Variable not in covariance matrix: " + absent);
            }
        }

        Matrix submatrix = m.getSelection(indices, indices);
        if (DataUtils.containsMissingValue(submatrix)) {
            throw new IllegalArgumentException("Please remove or impute missing values first.");
        }
        return submatrix;
    }

    /**
     * Extracts the square submatrix of {@code m} for x, y, z_1, ..., z_k, where each node's
     * row/column index is its position in {@code variables}.
     *
     * @param m         the matrix to select from; rows/columns ordered as {@code variables}
     * @param variables the variable order of {@code m}
     * @param x         first variable
     * @param y         second variable
     * @param z         conditioning variables (may be empty)
     * @return the (2 + |z|) x (2 + |z|) selection
     * @throws NullPointerException     if x, y, z, or any element of z is null
     * @throws IllegalArgumentException if a requested node is not in {@code variables}
     */
    public static Matrix subMatrix(Matrix m, List<Node> variables, Node x, Node y, List<Node> z) {
        if (x == null) {
            throw new NullPointerException();
        }
        if (y == null) {
            throw new NullPointerException();
        }
        if (z == null) {
            throw new NullPointerException();
        }
        for (Node node : z) {
            if (node == null) {
                throw new NullPointerException();
            }
        }

        int[] indices = new int[2 + z.size()];
        indices[0] = variables.indexOf(x);
        indices[1] = variables.indexOf(y);
        for (int i = 0; i < z.size(); ++i) {
            indices[i + 2] = variables.indexOf(z.get(i));
        }

        // indexOf yields -1 for an absent node; report it instead of selecting with -1.
        for (int i = 0; i < indices.length; ++i) {
            if (indices[i] < 0) {
                Node absent = i == 0 ? x : (i == 1 ? y : z.get(i - 2));
                throw new IllegalArgumentException("Variable not in variable list: " + absent);
            }
        }
        return m.getSelection(indices, indices);
    }

    /**
     * Extracts the square submatrix of {@code m} for x, y, z_1, ..., z_k, using
     * {@code indexMap} to find each node's row/column index.
     *
     * @param m        the matrix to select from
     * @param indexMap node to row/column index in {@code m}
     * @param x        first variable
     * @param y        second variable
     * @param z        conditioning variables (may be empty)
     * @return the (2 + |z|) x (2 + |z|) selection
     * @throws NullPointerException if x, y, z, or an element of z is null, or if a node has
     *                              no entry in {@code indexMap} (unboxing a null index)
     */
    public static Matrix subMatrix(Matrix m, Map<Node, Integer> indexMap, Node x, Node y, List<Node> z) {
        if (x == null || y == null || z == null) {
            throw new NullPointerException();
        }
        for (Node conditioner : z) {
            if (conditioner == null) {
                throw new NullPointerException();
            }
        }

        int[] selected = new int[z.size() + 2];
        selected[0] = indexMap.get(x);
        selected[1] = indexMap.get(y);
        int slot = 2;
        for (Node conditioner : z) {
            selected[slot++] = indexMap.get(conditioner);
        }
        return m.getSelection(selected, selected);
    }

    /**
     * Extracts the square submatrix of {@code m} for x, y, z_1, ..., z_k, using
     * {@code indexMap} to find each node's row/column index.
     * Argument checks match the other {@code subMatrix} overloads.
     *
     * @param m        the covariance matrix to select from
     * @param indexMap node to row/column index in {@code m}
     * @param x        first variable
     * @param y        second variable
     * @param z        conditioning variables (may be empty)
     * @return the (2 + |z|) x (2 + |z|) selection
     * @throws NullPointerException if x, y, z, or an element of z is null, or if a node has
     *                              no entry in {@code indexMap} (unboxing a null index)
     */
    public static Matrix subMatrix(ICovarianceMatrix m, Map<Node, Integer> indexMap, Node x, Node y, List<Node> z) {
        if (x == null) {
            throw new NullPointerException();
        }
        if (y == null) {
            throw new NullPointerException();
        }
        if (z == null) {
            throw new NullPointerException();
        }
        for (Node node : z) {
            if (node == null) {
                throw new NullPointerException();
            }
        }

        int[] indices = new int[2 + z.size()];
        indices[0] = indexMap.get(x);
        indices[1] = indexMap.get(y);
        for (int i = 0; i < z.size(); ++i) {
            indices[i + 2] = indexMap.get(z.get(i));
        }
        return m.getSelection(indices, indices);
    }

    /**
     * Converts every discrete column to a continuous one with the same name.
     * If all of a discrete variable's category labels parse as doubles, each cell becomes the
     * parsed label; otherwise it becomes the category index. A cell whose category label is
     * {@code "*"} becomes NaN. Continuous columns are copied unchanged.
     *
     * @param dataSet the data to convert; not modified
     * @return a new all-continuous data set with the same column order
     * @throws NumberFormatException declared for compatibility; parsing is only attempted on
     *                               labels already checked to be numeric
     */
    public static DataSet convertNumericalDiscreteToContinuous(DataSet dataSet) throws NumberFormatException {
        List<Node> newVariables = new ArrayList<>();
        for (Node v : dataSet.getVariables()) {
            newVariables.add(v instanceof ContinuousVariable ? v : new ContinuousVariable(v.getName()));
        }

        int numRows = dataSet.getNumRows();
        DataSet converted = new BoxDataSet(new VerticalDoubleDataBox(numRows, newVariables.size()), newVariables);

        for (int col = 0; col < dataSet.getNumColumns(); col++) {
            Node source = dataSet.getVariable(col);
            if (source instanceof ContinuousVariable) {
                for (int row = 0; row < numRows; row++) {
                    converted.setDouble(row, col, dataSet.getDouble(row, col));
                }
                continue;
            }

            DiscreteVariable discrete = (DiscreteVariable) source;
            boolean numericLabels = true;
            try {
                for (String category : discrete.getCategories()) {
                    Double.parseDouble(category);
                }
            } catch (NumberFormatException notNumeric) {
                numericLabels = false;
            }

            for (int row = 0; row < numRows; row++) {
                int code = dataSet.getInt(row, col);
                String label = discrete.getCategory(code);
                double value;
                if (label.equals("*")) {
                    value = Double.NaN;
                } else if (numericLabels) {
                    value = Double.parseDouble(label);
                } else {
                    value = code;
                }
                converted.setDouble(row, col, value);
            }
        }
        return converted;
    }

    /**
     * Stacks the rows of {@code dataSet2} under those of {@code dataSet1}. Columns are matched
     * by variable name, so {@code dataSet2} may order its columns differently; the result uses
     * {@code dataSet1}'s variables.
     *
     * @param dataSet1 the top block and source of the result's variables
     * @param dataSet2 the bottom block; must contain every variable name of {@code dataSet1}
     * @return the concatenated data set
     * @throws NullPointerException if {@code dataSet2} lacks one of {@code dataSet1}'s names
     */
    public static DataSet concatenate(DataSet dataSet1, DataSet dataSet2) {
        List<Node> vars1 = dataSet1.getVariables();
        List<Node> vars2 = dataSet2.getVariables();
        Map<String, Integer> columnIn2 = new HashMap<>();
        for (int c = 0; c < vars2.size(); c++) {
            columnIn2.put(vars2.get(c).getName(), c);
        }

        int n1 = dataSet1.getNumRows();
        int n2 = dataSet2.getNumRows();
        Matrix combined = new Matrix(n1 + n2, dataSet1.getNumColumns());
        Matrix top = dataSet1.getDoubleData();
        Matrix bottom = dataSet2.getDoubleData();

        for (int col = 0; col < vars1.size(); col++) {
            int col2 = columnIn2.get(vars1.get(col).getName());
            for (int r = 0; r < n1; r++) {
                combined.set(r, col, top.get(r, col));
            }
            for (int r = 0; r < n2; r++) {
                combined.set(n1 + r, col, bottom.get(r, col2));
            }
        }
        return new BoxDataSet(new VerticalDoubleDataBox(combined.transpose().toArray()), vars1);
    }

    /**
     * Varargs convenience form of {@code concatenate(List<DataSet>)}.
     *
     * @param dataSets the data sets to stack, top to bottom
     * @return the concatenated data set
     */
    public static DataSet concatenate(DataSet ... dataSets) {
        List<DataSet> asList = Arrays.asList(dataSets);
        return DataUtils.concatenate(asList);
    }

    /**
     * Stacks matrices vertically. The column count is taken from the first matrix; every matrix
     * is assumed to have at least that many columns.
     *
     * @param dataSets the matrices to stack, top to bottom; must be non-empty
     * @return a new matrix whose row count is the sum of the inputs' row counts
     */
    public static Matrix concatenate(Matrix ... dataSets) {
        int totalRows = 0;
        for (Matrix block : dataSets) {
            totalRows += block.rows();
        }
        int numColumns = dataSets[0].columns();

        Matrix stacked = new Matrix(totalRows, numColumns);
        int offset = 0;
        for (Matrix block : dataSets) {
            int blockRows = block.rows();
            for (int r = 0; r < blockRows; r++) {
                for (int c = 0; c < numColumns; c++) {
                    stacked.set(offset + r, c, block.get(r, c));
                }
            }
            offset += blockRows;
        }
        return stacked;
    }

    /**
     * Stacks data sets vertically by column position (not by name). The column count and the
     * variables of the result come from the first data set.
     *
     * @param dataSets the data sets to stack, top to bottom; must be non-empty
     * @return the concatenated data set
     */
    public static DataSet concatenate(List<DataSet> dataSets) {
        int totalRows = 0;
        for (DataSet ds : dataSets) {
            totalRows += ds.getNumRows();
        }
        DataSet first = dataSets.get(0);
        int numColumns = first.getNumColumns();

        Matrix stacked = new Matrix(totalRows, numColumns);
        int offset = 0;
        for (DataSet ds : dataSets) {
            Matrix block = ds.getDoubleData();
            int blockRows = block.rows();
            for (int r = 0; r < blockRows; r++) {
                for (int c = 0; c < numColumns; c++) {
                    stacked.set(offset + r, c, block.get(r, c));
                }
            }
            offset += blockRows;
        }
        return new BoxDataSet(new VerticalDoubleDataBox(stacked.transpose().toArray()), first.getVariables());
    }

    /**
     * Drops every column whose variable is not {@code NodeType.MEASURED}.
     *
     * @param fullDataSet the data set to filter
     * @return {@code fullDataSet} itself if all variables are measured, otherwise
     *         {@code fullDataSet.subsetColumns} of the measured variables
     */
    public static DataSet restrictToMeasured(DataSet fullDataSet) {
        List<Node> measured = new ArrayList<>();
        boolean anyUnmeasured = false;
        for (Node variable : fullDataSet.getVariables()) {
            if (variable.getNodeType() == NodeType.MEASURED) {
                measured.add(variable);
            } else {
                anyUnmeasured = true;
            }
        }
        return anyUnmeasured ? fullDataSet.subsetColumns(measured) : fullDataSet;
    }

    /**
     * Computes the mean of each column, ignoring NaN entries. A column with no non-NaN
     * entries gets NaN (0.0 / 0).
     *
     * @param data the matrix whose columns are averaged
     * @return a vector with one mean per column
     */
    public static Vector means(Matrix data) {
        Vector result = new Vector(data.columns());
        for (int col = 0; col < result.size(); col++) {
            double total = 0.0;
            int observed = 0;
            for (int row = 0; row < data.rows(); row++) {
                double v = data.get(row, col);
                if (!Double.isNaN(v)) {
                    total += v;
                    observed++;
                }
            }
            result.set(col, total / observed);
        }
        return result;
    }

    /**
     * Computes the mean of each {@code data[j]}, ignoring NaN entries. The number of entries
     * read per {@code data[j]} is {@code data[0].length}, so the array is treated as
     * variable-major (one inner array per variable). An all-NaN inner array gets NaN.
     *
     * @param data one inner array per variable; must be non-empty
     * @return a vector with one mean per inner array
     */
    public static Vector means(double[][] data) {
        Vector result = new Vector(data.length);
        int numRows = data[0].length;
        for (int var = 0; var < result.size(); var++) {
            double[] values = data[var];
            double total = 0.0;
            int observed = 0;
            for (int row = 0; row < numRows; row++) {
                if (!Double.isNaN(values[row])) {
                    total += values[row];
                    observed++;
                }
            }
            result.set(var, total / observed);
        }
        return result;
    }

    /**
     * Computes the sample covariance matrix of the columns of {@code data}
     * (cross-products of centered columns divided by n - 1).
     * <p>
     * {@code data} is not modified; centering is done on a private copy.
     *
     * @param data rows are observations, columns are variables
     * @return a columns x columns covariance matrix
     */
    public static Matrix cov(Matrix data) {
        // Center a copy so the caller's matrix is not mutated as a side effect.
        Matrix centered = data.copy();
        int n = centered.rows();
        for (int j = 0; j < centered.columns(); ++j) {
            double sum = 0.0;
            for (int i = 0; i < n; ++i) {
                sum += centered.get(i, j);
            }
            double mean = sum / (double) n;
            for (int i = 0; i < n; ++i) {
                centered.set(i, j, centered.get(i, j) - mean);
            }
        }

        BlockRealMatrix q = new BlockRealMatrix(centered.toArray());
        RealMatrix qt = MatrixUtils.transposeWithoutCopy(q);
        RealMatrix crossProducts = DataUtils.times(qt, q);
        Matrix prod = new Matrix(crossProducts.getData());

        double factor = 1.0 / (double) (n - 1);
        for (int i = 0; i < prod.rows(); ++i) {
            for (int j = 0; j < prod.columns(); ++j) {
                prod.set(i, j, prod.get(i, j) * factor);
            }
        }
        return prod;
    }

    /**
     * Computes the product m * n, splitting the result's rows into one contiguous chunk per
     * available processor and computing each chunk as a task on the shared
     * {@code ForkJoinPoolInstance} pool.
     * <p>
     * The method waits for exactly the tasks it submitted, via {@code ForkJoinTask.join()}.
     * This also rethrows any exception raised inside a task and establishes visibility of the
     * tasks' writes to {@code out}. It is safe to call from a worker thread of the same pool.
     *
     * @param m left operand
     * @param n right operand
     * @return a new matrix holding m * n
     * @throws IllegalArgumentException if m's column count differs from n's row count
     */
    private static RealMatrix times(RealMatrix m, RealMatrix n) {
        if (m.getColumnDimension() != n.getRowDimension()) {
            throw new IllegalArgumentException("Incompatible matrices.");
        }
        int rowDimension = m.getRowDimension();
        int columnDimension = n.getColumnDimension();
        int commonDimension = m.getColumnDimension();
        BlockRealMatrix out = new BlockRealMatrix(rowDimension, columnDimension);

        int numThreads = Runtime.getRuntime().availableProcessors();
        int chunk = rowDimension / numThreads + 1;
        ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();
        List<ForkJoinTask<?>> tasks = new ArrayList<>(numThreads);

        for (int t = 0; t < numThreads; ++t) {
            int from = t * chunk;
            int to = FastMath.min(from + chunk, rowDimension);
            if (from >= to) {
                // Chunks only move further right, so all remaining ones are empty too.
                break;
            }
            Runnable worker = () -> {
                for (int row = from; row < to; ++row) {
                    for (int col = 0; col < columnDimension; ++col) {
                        double sum = 0.0;
                        for (int i = 0; i < commonDimension; ++i) {
                            sum += m.getEntry(row, i) * n.getEntry(i, col);
                        }
                        // Each task writes a disjoint set of rows.
                        out.setEntry(row, col, sum);
                    }
                }
            };
            tasks.add(pool.submit(worker));
        }

        // Wait for our own tasks only, instead of spinning until the whole shared pool is idle.
        for (ForkJoinTask<?> task : tasks) {
            task.join();
        }
        return out;
    }

    /**
     * Computes the mean of each column with {@code StatUtils.mean}.
     *
     * @param data the matrix whose columns are averaged
     * @return a vector with one mean per column
     */
    public static Vector mean(Matrix data) {
        int numCols = data.columns();
        Vector result = new Vector(numCols);
        for (int col = 0; col < numCols; col++) {
            double[] column = data.getColumn(col).toArray();
            result.set(col, StatUtils.mean(column));
        }
        return result;
    }

    /**
     * Simulates {@code cov.getSampleSize()} rows of data whose population covariance is
     * {@code cov}. Each row is L * e, where L is {@code MatrixUtils.cholesky} of a copy of the
     * covariance matrix and e is a vector of independent N(0, 1) draws from {@code RandomUtil}.
     * Only entries with j <= i of L are read, so L is assumed to be lower-triangular.
     * <p>
     * A message is printed to stdout for any generated value that is NaN or infinite.
     *
     * @param cov the target covariance matrix; not modified
     * @return a new continuous data set over {@code cov}'s variables
     */
    public static DataSet choleskySimulation(CovarianceMatrix cov) {
        int sampleSize = cov.getSampleSize();
        List<Node> variables = cov.getVariables();
        BoxDataSet dataSet = new BoxDataSet(new VerticalDoubleDataBox(sampleSize, variables.size()), variables);
        // Factor a copy so the caller's covariance matrix is untouched.
        Matrix cholesky = MatrixUtils.cholesky(cov.getMatrix().copy());
        int dim = cholesky.rows();

        for (int row = 0; row < sampleSize; ++row) {
            double[] exoData = new double[dim];
            for (int i = 0; i < dim; ++i) {
                exoData[i] = RandomUtil.getInstance().nextNormal(0.0, 1.0);
            }
            double[] point = new double[dim];
            for (int i = 0; i < dim; ++i) {
                double sum = 0.0;
                for (int j = 0; j <= i; ++j) {
                    sum += cholesky.get(i, j) * exoData[j];
                }
                point[i] = sum;
            }
            // Column col of the data set corresponds to variables.get(col), i.e. point[col].
            for (int col = 0; col < variables.size(); ++col) {
                double value = point[col];
                if (Double.isNaN(value) || Double.isInfinite(value)) {
                    System.out.println("Value out of range: " + value);
                }
                dataSet.setDouble(row, col, value);
            }
        }
        return dataSet;
    }

    /**
     * Draws {@code sampleSize} rows of {@code data} uniformly with replacement, using
     * {@code RandomUtil}, and keeps all columns.
     *
     * @param data       the matrix to resample
     * @param sampleSize number of rows in the result
     * @return the resampled rows as a new selection of {@code data}
     */
    public static Matrix getBootstrapSample(Matrix data, int sampleSize) {
        int available = data.rows();
        int[] rowIndices = new int[sampleSize];
        for (int k = 0; k < sampleSize; k++) {
            rowIndices[k] = RandomUtil.getInstance().nextInt(available);
        }
        int[] allColumns = new int[data.columns()];
        Arrays.setAll(allColumns, c -> c);
        return data.getSelection(rowIndices, allColumns);
    }

    /**
     * Draws {@code min(sampleSize, data.getNumRows())} distinct rows of {@code data} without
     * replacement, using {@code RandomUtil}, and keeps all columns. The result carries
     * {@code data}'s knowledge, like {@code getBootstrapSample}.
     *
     * @param data       the data set to resample; not modified
     * @param sampleSize requested number of rows
     * @return a new data set of the selected rows, in draw order
     */
    public static DataSet getResamplingDataset(DataSet data, int sampleSize) {
        int actualSampleSize = data.getNumRows();
        int size = FastMath.min(sampleSize, actualSampleSize);

        ArrayList<Integer> availRows = new ArrayList<Integer>(actualSampleSize);
        for (int i = 0; i < actualSampleSize; ++i) {
            availRows.add(i);
        }
        RandomUtil.shuffle(availRows);

        int[] rows = new int[size];
        for (int i = 0; i < size; ++i) {
            // Removing each drawn row keeps the sample distinct; no extra bookkeeping is needed.
            int index = RandomUtil.getInstance().nextInt(availRows.size());
            rows[i] = availRows.remove(index);
        }

        int[] cols = new int[data.getNumColumns()];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = i;
        }
        BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(data.getDoubleData().getSelection(rows, cols).transpose().toArray()), data.getVariables());
        boxDataSet.setKnowledge(data.getKnowledge());
        return boxDataSet;
    }

    /**
     * Draws {@code min(sampleSize, data.getNumRows())} distinct rows of {@code data} without
     * replacement and keeps all columns. All randomness comes from {@code randomGenerator},
     * so the sample can be reproduced from that generator's state. The result carries
     * {@code data}'s knowledge, like {@code getBootstrapSample}.
     *
     * @param data            the data set to resample; not modified
     * @param sampleSize      requested number of rows
     * @param randomGenerator the only source of randomness for the draw
     * @return a new data set of the selected rows, in draw order
     */
    public static DataSet getResamplingDataset(DataSet data, int sampleSize, RandomGenerator randomGenerator) {
        int actualSampleSize = data.getNumRows();
        int size = FastMath.min(sampleSize, actualSampleSize);

        // No pre-shuffle: each draw is uniform over the remaining rows, so the initial order
        // does not matter. A shuffle via RandomUtil would also tie the result to randomness
        // outside randomGenerator.
        List<Integer> availRows = new ArrayList<>(actualSampleSize);
        for (int i = 0; i < actualSampleSize; ++i) {
            availRows.add(i);
        }

        int[] rows = new int[size];
        for (int i = 0; i < size; ++i) {
            int index = randomGenerator.nextInt(availRows.size());
            rows[i] = availRows.remove(index);
        }

        int[] cols = new int[data.getNumColumns()];
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = i;
        }
        BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(data.getDoubleData().getSelection(rows, cols).transpose().toArray()), data.getVariables());
        boxDataSet.setKnowledge(data.getKnowledge());
        return boxDataSet;
    }

    /**
     * Draws {@code sampleSize} rows of {@code data} uniformly with replacement, using
     * {@code RandomUtil}, keeps all columns, and copies {@code data}'s knowledge to the result.
     *
     * @param data       the data set to resample; not modified
     * @param sampleSize number of rows in the result
     * @return a new data set of the sampled rows
     */
    public static DataSet getBootstrapSample(DataSet data, int sampleSize) {
        int available = data.getNumRows();
        int[] rowIndices = new int[sampleSize];
        for (int k = 0; k < sampleSize; k++) {
            rowIndices[k] = RandomUtil.getInstance().nextInt(available);
        }
        int[] allColumns = new int[data.getNumColumns()];
        Arrays.setAll(allColumns, c -> c);

        Matrix sampled = data.getDoubleData().getSelection(rowIndices, allColumns);
        BoxDataSet resampled = new BoxDataSet(new VerticalDoubleDataBox(sampled.transpose().toArray()), data.getVariables());
        resampled.setKnowledge(data.getKnowledge());
        return resampled;
    }

    /**
     * Draws {@code sampleSize} rows of {@code data} uniformly with replacement, using
     * {@code randomGenerator}, keeps all columns, and copies {@code data}'s knowledge to the
     * result.
     *
     * @param data            the data set to resample; not modified
     * @param sampleSize      number of rows in the result
     * @param randomGenerator source of the row draws
     * @return a new data set of the sampled rows
     */
    public static DataSet getBootstrapSample(DataSet data, int sampleSize, RandomGenerator randomGenerator) {
        int available = data.getNumRows();
        int[] rowIndices = new int[sampleSize];
        for (int k = 0; k < sampleSize; k++) {
            rowIndices[k] = randomGenerator.nextInt(available);
        }
        int[] allColumns = new int[data.getNumColumns()];
        Arrays.setAll(allColumns, c -> c);

        Matrix sampled = data.getDoubleData().getSelection(rowIndices, allColumns);
        BoxDataSet resampled = new BoxDataSet(new VerticalDoubleDataBox(sampled.transpose().toArray()), data.getVariables());
        resampled.setKnowledge(data.getKnowledge());
        return resampled;
    }

    /**
     * Randomly partitions the rows of {@code data} into two new data sets.
     *
     * <p>Row indices are shuffled with {@code RandomUtil.shuffle}. The first
     * {@code (int) (numRows * percentTest)} shuffled rows form the first data set; the remaining
     * rows form the second. Both keep every column and share the source's variables.
     *
     * @param data        the data to split.
     * @param percentTest the fraction of rows placed in the first data set; must lie strictly
     *                    between 0 and 1.
     * @return a two-element list: the first part, then the remainder.
     * @throws IllegalArgumentException if {@code percentTest <= 0.0} or {@code percentTest >= 1.0}.
     */
    public static List<DataSet> split(DataSet data, double percentTest) {
        // Kept in this exact form so a NaN fraction is not rejected here.
        if (percentTest <= 0.0 || percentTest >= 1.0) {
            throw new IllegalArgumentException();
        }

        ArrayList<Integer> shuffled = new ArrayList<Integer>();
        int numRows = data.getNumRows();
        for (int r = 0; r < numRows; ++r) {
            shuffled.add(r);
        }
        RandomUtil.shuffle(shuffled);

        int cut = (int) ((double) shuffled.size() * percentTest);
        int[] firstRows = shuffled.subList(0, cut).stream().mapToInt(Integer::intValue).toArray();
        int[] secondRows = shuffled.subList(cut, shuffled.size()).stream().mapToInt(Integer::intValue).toArray();

        int[] allCols = new int[data.getNumColumns()];
        Arrays.setAll(allCols, c -> c);

        BoxDataSet first = new BoxDataSet(
                new VerticalDoubleDataBox(data.getDoubleData().getSelection(firstRows, allCols).transpose().toArray()),
                data.getVariables());
        BoxDataSet second = new BoxDataSet(
                new VerticalDoubleDataBox(data.getDoubleData().getSelection(secondRows, allCols).transpose().toArray()),
                data.getVariables());

        return new ArrayList<DataSet>(Arrays.asList(first, second));
    }

    /**
     * Returns a copy of {@code data} in which every column has had its mean subtracted.
     *
     * <p>The column mean ignores NaN entries. NaN entries remain NaN after the subtraction. A
     * column with no non-NaN entries gets a mean of {@code 0.0 / 0} (NaN). The input is not
     * modified.
     *
     * @param data the data to center.
     * @return a centered copy.
     */
    public static DataSet center(DataSet data) {
        DataSet centered = data.copy();
        int numRows = centered.getNumRows();

        for (int col = 0; col < centered.getNumColumns(); ++col) {
            double total = 0.0;
            int count = 0;
            for (int row = 0; row < numRows; ++row) {
                double value = centered.getDouble(row, col);
                if (!Double.isNaN(value)) {
                    total += value;
                    count++;
                }
            }

            double mean = total / (double) count;
            for (int row = 0; row < numRows; ++row) {
                centered.setDouble(row, col, centered.getDouble(row, col) - mean);
            }
        }
        return centered;
    }

    /**
     * Returns a data set with the columns of {@code dataModel} in a random order.
     *
     * <p>A permutation of the column indices is produced with {@code RandomUtil.shuffle} and
     * handed to {@code subsetColumns}. The result is given the same name as the input.
     *
     * @param dataModel the data whose columns are permuted.
     * @return the column-permuted data set.
     */
    public static DataSet shuffleColumns(DataSet dataModel) {
        String name = dataModel.getName();

        ArrayList<Integer> order = new ArrayList<Integer>();
        for (int c = 0, total = dataModel.getNumColumns(); c < total; c++) {
            order.add(c);
        }
        RandomUtil.shuffle(order);

        int[] permutation = order.stream().mapToInt(Integer::intValue).toArray();
        DataSet shuffled = dataModel.subsetColumns(permutation);
        shuffled.setName(name);
        return shuffled;
    }

    /**
     * Applies one shared random column order to every data set in {@code dataSets}.
     *
     * <p>The order is a {@code RandomUtil.shuffle} of the first data set's variables. Each data set
     * is then subset to those variables, presumably in that order (depends on
     * {@code subsetColumns} honoring list order). Each result is named
     * {@code <original name> + ".reordered"}.
     *
     * @param dataSets the data sets to reorder; all are expected to contain the first data set's
     *                 variables.
     * @return the reordered data sets, in input order; empty if {@code dataSets} is empty.
     */
    public static List<DataSet> shuffleColumns2(List<DataSet> dataSets) {
        ArrayList<DataSet> ret = new ArrayList<DataSet>();
        if (dataSets.isEmpty()) {
            return ret;
        }

        DataSet first = dataSets.get(0);

        // Shuffle a private copy so the first data set's own variable list can never be
        // reordered in place, whether or not getVariables() returns a defensive copy.
        ArrayList<Node> variables = new ArrayList<Node>(first.getVariables());
        RandomUtil.shuffle(variables);

        ArrayList<Node> vars = new ArrayList<Node>();
        for (Node node : variables) {
            Node _node = first.getVariable(node.getName());
            if (_node != null) {
                vars.add(_node);
            }
        }

        for (DataSet m : dataSets) {
            DataSet data = m.subsetColumns(vars);
            data.setName(m.getName() + ".reordered");
            ret.add(data);
        }
        return ret;
    }

    /**
     * Builds a symmetric matrix of pairwise Kendall's tau values between the columns of
     * {@code dataSet}, packaged as a {@link CovarianceMatrix} over the same variables.
     *
     * <p>Work is split into one task per row {@code i}. Each task computes tau for every
     * {@code j >= i} and writes it to both {@code (i, j)} and {@code (j, i)}. Distinct tasks
     * therefore write disjoint cells. The matrix is returned only after every task has completed
     * ({@code Future.get()} establishes visibility of the writes).
     *
     * <p>NOTE(review): raw tau values are stored without further transformation. Nonparanormal
     * estimators in the literature commonly use {@code sin(pi / 2 * tau)}; confirm which is intended.
     *
     * @param dataSet the data; every column is read as doubles.
     * @return a matrix whose (i, j) entry is Kendall's tau between columns i and j.
     * @throws IllegalStateException if the calling thread is interrupted while waiting (the
     *                               interrupt flag is restored), or if a worker task fails.
     */
    public static ICovarianceMatrix covarianceNonparanormalDrton(DataSet dataSet) {
        CovarianceMatrix covMatrix = new CovarianceMatrix(dataSet);
        Matrix data = dataSet.getDoubleData();
        int numColumns = dataSet.getNumColumns();

        // Extract each column once instead of twice for every pair.
        double[][] columns = new double[numColumns][];
        for (int c = 0; c < numColumns; ++c) {
            columns[c] = data.getColumn(c).toArray();
        }

        // One task per row keeps the task count linear in the number of variables, so no batching
        // is needed. The work is CPU-bound, so the pool is sized to the available cores.
        int numThreads = Math.max(1, Math.min(numColumns, Runtime.getRuntime().availableProcessors()));
        ExecutorService executor = Executors.newFixedThreadPool(numThreads);
        try {
            List<java.util.concurrent.Future<?>> futures = new ArrayList<java.util.concurrent.Future<?>>(numColumns);
            for (int _i = 0; _i < numColumns; ++_i) {
                int i = _i;
                futures.add(executor.submit(() -> {
                    for (int j = i; j < numColumns; ++j) {
                        double tau = StatUtils.kendallsTau(columns[i], columns[j]);
                        covMatrix.setValue(i, j, tau);
                        covMatrix.setValue(j, i, tau);
                    }
                }));
            }
            // Waiting on every future both propagates worker failures and guarantees the
            // results are visible to this thread.
            for (java.util.concurrent.Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while computing Kendall's tau matrix", e);
        } catch (java.util.concurrent.ExecutionException e) {
            throw new IllegalStateException("Failed to compute Kendall's tau matrix", e.getCause());
        } finally {
            executor.shutdownNow();
        }
        return covMatrix;
    }

    /**
     * Returns a copy of {@code dataSet} in which every non-discrete column is replaced by
     * rank-based normal scores, rescaled to that column's original mean and standard deviation.
     *
     * <p>Per continuous column:
     * <ol>
     *   <li>Each value's rank (the count of column values {@code <=} it) is divided by the number
     *       of rows.</li>
     *   <li>The result is mapped through the standard normal quantile function, then multiplied by
     *       the column's standard deviation and shifted by its mean.</li>
     *   <li>The column maximum has rank n and so maps to +infinity. Infinite scores are replaced by
     *       the largest or smallest finite score in that column.</li>
     * </ol>
     * Discrete columns are copied through unchanged. For each continuous column, an
     * Anderson-Darling A^2* statistic for the original and transformed values is printed to
     * standard output.
     *
     * @param dataSet the data to transform.
     * @return a new data set over the same variables, or {@code dataSet} itself if the normal
     *         quantile function throws {@link OutOfRangeException}.
     */
    public static DataSet getNonparanormalTransformed(DataSet dataSet) {
        try {
            Matrix data = dataSet.getDoubleData();
            Matrix X = data.like();
            double n = dataSet.getNumRows();
            NormalDistribution normalDistribution = new NormalDistribution();

            for (int j = 0; j < data.columns(); ++j) {
                double[] x1 = Arrays.copyOf(data.getColumn(j).toArray(), data.rows());

                if (dataSet.getVariable(j) instanceof DiscreteVariable) {
                    X.assignColumn(j, new Vector(x1));
                    continue;
                }

                // The diagnostic gets its own copy so it cannot disturb the column's row order.
                double a2Orig = new AndersonDarlingTest(x1.clone()).getASquaredStar();
                double std1 = StatUtils.sd(x1);
                double mu1 = StatUtils.mean(x1);

                double[] xTransformed = DataUtils.ranks(data, x1);
                for (int i = 0; i < xTransformed.length; ++i) {
                    double z = normalDistribution.inverseCumulativeProbability(xTransformed[i] / n);
                    xTransformed[i] = z * std1 + mu1;
                }

                // Find the finite extremes; infinite scores are excluded from the search.
                double min = Double.POSITIVE_INFINITY;
                double max = Double.NEGATIVE_INFINITY;
                for (double v : xTransformed) {
                    if (Double.isInfinite(v)) {
                        continue;
                    }
                    if (v > max) {
                        max = v;
                    }
                    if (v < min) {
                        min = v;
                    }
                }

                // Clamp infinities. The test uses ==, because no double compares < NEGATIVE_INFINITY.
                // The else-if keeps a just-replaced +infinity from being replaced a second time.
                for (int i = 0; i < xTransformed.length; ++i) {
                    if (xTransformed[i] == Double.POSITIVE_INFINITY) {
                        xTransformed[i] = max;
                    } else if (xTransformed[i] == Double.NEGATIVE_INFINITY) {
                        xTransformed[i] = min;
                    }
                }

                // Measured on the values actually stored, i.e. after clamping.
                double a2Transformed = new AndersonDarlingTest(xTransformed.clone()).getASquaredStar();
                System.out.println(dataSet.getVariable(j) + ": A^2* = " + a2Orig + " transformed A^2* = " + a2Transformed);
                X.assignColumn(j, new Vector(xTransformed));
            }
            return new BoxDataSet(new VerticalDoubleDataBox(X.transpose().toArray()), dataSet.getVariables());
        }
        catch (OutOfRangeException e) {
            e.printStackTrace();
            return dataSet;
        }
    }

    /**
     * Computes, for each of the first {@code data.rows()} entries of {@code x}, the number of
     * entries among those that are {@code <=} it. Ties all receive the largest shared rank.
     *
     * <p>NaN entries never satisfy {@code <=}. They are therefore never counted, and a NaN entry
     * itself gets rank 0. Runs in O(n log n): one sort followed by a binary search per entry.
     *
     * @param data supplies only the row count {@code n = data.rows()}.
     * @param x    the values to rank; expected to have length {@code data.rows()}.
     * @return an array of length {@code x.length} holding the ranks as doubles.
     */
    private static double[] ranks(Matrix data, double[] x) {
        int n = data.rows();
        double[] ranks = new double[x.length];

        // Sort only the non-NaN values; NaN can never be counted by "<=".
        double[] sorted = new double[n];
        int m = 0;
        for (int k = 0; k < n; ++k) {
            if (!Double.isNaN(x[k])) {
                sorted[m++] = x[k];
            }
        }
        Arrays.sort(sorted, 0, m);

        for (int i = 0; i < n; ++i) {
            double d = x[i];
            // Upper bound: the first index whose value is numerically > d. "sorted[mid] <= d" is
            // monotone over the sorted prefix (including -0.0/0.0) and always false when d is NaN.
            int lo = 0;
            int hi = m;
            while (lo < hi) {
                int mid = (lo + hi) >>> 1;
                if (sorted[mid] <= d) {
                    lo = mid + 1;
                } else {
                    hi = mid;
                }
            }
            ranks[i] = lo;
        }
        return ranks;
    }

    /**
     * Returns {@code dataSet} restricted to the columns that are not constant.
     *
     * <p>A column is treated as varying as soon as some row's object is not {@code equals} to the
     * row-0 object. It is also treated as varying when both of those objects are {@code Double}
     * NaN, which {@code Double.equals} would otherwise consider equal. A data set with no rows is
     * returned as-is.
     *
     * @param dataSet the data to filter.
     * @return the data restricted to the varying columns.
     */
    public static DataSet removeConstantColumns(DataSet dataSet) {
        int numCols = dataSet.getNumColumns();
        int numRows = dataSet.getNumRows();
        if (numRows == 0) {
            return dataSet;
        }

        List<Integer> retained = new ArrayList<Integer>();
        for (int col = 0; col < numCols; ++col) {
            Object first = dataSet.getObject(0, col);
            boolean varies = false;
            for (int r = 1; r < numRows && !varies; ++r) {
                Object value = dataSet.getObject(r, col);
                if (!first.equals(value)) {
                    varies = true;
                } else if (first instanceof Double && value instanceof Double
                        && Double.isNaN((Double) first) && Double.isNaN((Double) value)) {
                    varies = true;
                }
            }
            if (varies) {
                retained.add(col);
            }
        }

        return dataSet.subsetColumns(retained.stream().mapToInt(Integer::intValue).toArray());
    }

    /**
     * Computes a summary from the correlation matrix implied by {@code covariances}; presumably
     * an effective sample size (per the name — confirm).
     *
     * <p>With S the sum of all entries of the correlation matrix, m = {@code getSize()} and
     * n = {@code getSampleSize()}:
     * <ul>
     *   <li>rho = (n*S - n*m) / (m*(n*n - n))</li>
     *   <li>the result is n / (1 + (n - 1) * rho)</li>
     * </ul>
     * When n == 1, the rho denominator is zero.
     *
     * @param covariances the covariance matrix to summarize.
     * @return n / (1 + (n - 1) * rho) as defined above.
     */
    public static double getEss(ICovarianceMatrix covariances) {
        Matrix correlations = new CorrelationMatrix(covariances).getMatrix();
        double numVars = covariances.getSize();
        double sampleSize = covariances.getSampleSize();

        int rowCount = correlations.rows();
        int colCount = correlations.columns();
        double total = 0.0;
        for (int r = 0; r < rowCount; ++r) {
            for (int c = 0; c < colCount; ++c) {
                total += correlations.get(r, c);
            }
        }

        double numerator = sampleSize * total - sampleSize * numVars;
        double denominator = numVars * (sampleSize * sampleSize - sampleSize);
        double rho = numerator / denominator;
        return sampleSize / (1.0 + (sampleSize - 1.0) * rho);
    }
}

