/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.util;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/**
 * Static helpers for descriptive statistics, a few special functions and distributions, and
 * partial covariance/correlation over Colt matrices.
 *
 * <p>Overloads taking {@code (array, N)} use only the first {@code N} elements of {@code array};
 * the corresponding single-array overloads use {@code array.length}. Sample variance, standard
 * deviation and covariance use the unbiased {@code N - 1} denominator; the mean uses {@code N}.
 */
public final class StatUtils {

    /**
     * Coefficients c_1..c_26 of the power series 1 / Gamma(z) = sum_k c_k z^k
     * (Abramowitz &amp; Stegun 6.1.34). Kept as a constant so it is not reallocated per call.
     */
    private static final double[] RECIPROCAL_GAMMA_COEFFICIENTS = {
        1.0, 0.5772156649015329, -0.6558780715202538, -0.0420026350340952, 0.1665386113822915,
        -0.0421977345555443, -0.009621971527877, 0.007218943246663, -0.0011651675918591,
        -2.152416741149E-4, 1.280502823882E-4, -2.01348547807E-5, -1.2504934821E-6,
        1.133027232E-6, -2.056338417E-7, 6.116095E-9, 5.0020075E-9, -1.1812746E-9,
        1.043427E-10, 7.7823E-12, -3.6968E-12, 5.1E-13, -2.06E-14, -5.4E-15, 1.4E-15, 1.0E-16
    };

    /** Number of series terms summed by {@link #igamma(double, double)}. */
    private static final int IGAMMA_TERMS = 100;

    // ------------------------------------------------------------------ mean

    /** Returns the arithmetic mean of all elements of {@code array}. */
    public static double mean(long[] array) {
        return StatUtils.mean(array, array.length);
    }

    /** Returns the arithmetic mean of all elements of {@code array}. */
    public static double mean(double[] array) {
        return StatUtils.mean(array, array.length);
    }

    /**
     * Returns the arithmetic mean of the first {@code N} elements of {@code array}.
     *
     * @return {@code sum / N}; NaN when {@code N == 0}
     */
    public static double mean(long[] array, int N) {
        long sum = 0L; // NOTE(review): may overflow for very large values or very large N
        for (int i = 0; i < N; ++i) {
            sum += array[i];
        }
        // Floating-point division by N: the mean's denominator is N (not N - 1), and
        // integer division would truncate the result.
        return (double) sum / N;
    }

    /**
     * Returns the arithmetic mean of the first {@code N} elements of {@code array}.
     *
     * @return {@code sum / N}; NaN when {@code N == 0}
     */
    public static double mean(double[] array, int N) {
        double sum = 0.0;
        for (int i = 0; i < N; ++i) {
            sum += array[i];
        }
        return sum / N;
    }

    /**
     * Returns the arithmetic mean of the first {@code N} entries of {@code data}.
     *
     * @return {@code sum / N}; NaN when {@code N == 0}
     */
    public static double mean(DoubleMatrix1D data, int N) {
        double sum = 0.0;
        for (int i = 0; i < N; ++i) {
            sum += data.get(i);
        }
        return sum / N;
    }

    // ------------------------------------------------------------------ median / quartiles

    /** Returns the median of all elements of {@code array}. */
    public static double median(long[] array) {
        return StatUtils.median(array, array.length);
    }

    /** Returns the median of all elements of {@code array}. */
    public static double median(double[] array) {
        return StatUtils.median(array, array.length);
    }

    /**
     * Returns the median of the first {@code N} elements of {@code array}; the input is not
     * modified. For even {@code N} the two middle values are averaged with integer division,
     * so the result is truncated toward zero.
     *
     * @throws IllegalArgumentException if {@code N < 1}
     */
    public static long median(long[] array, int N) {
        StatUtils.requireNonEmpty(N, "median");
        long[] a = new long[N + 1];
        System.arraycopy(array, 0, a, 0, N);
        int k1 = (N - 1) / 2;
        int k2 = (N - 1) - k1;
        StatUtils.selectInPlace(a, N, k1, k2);
        return (a[k1] + a[k2]) / 2L;
    }

    /**
     * Returns the median of the first {@code N} elements of {@code array}; the input is not
     * modified. For even {@code N} the two middle values are averaged.
     *
     * @throws IllegalArgumentException if {@code N < 1}
     */
    public static double median(double[] array, int N) {
        StatUtils.requireNonEmpty(N, "median");
        double[] a = new double[N + 1];
        System.arraycopy(array, 0, a, 0, N);
        int k1 = (N - 1) / 2;
        int k2 = (N - 1) - k1;
        StatUtils.selectInPlace(a, N, k1, k2);
        return (a[k1] + a[k2]) / 2.0;
    }

    /** Returns quartile {@code quartileNumber} (1, 2 or 3) of all elements of {@code array}. */
    public static double quartile(long[] array, int quartileNumber) {
        return StatUtils.quartile(array, array.length, quartileNumber);
    }

    /** Returns quartile {@code quartileNumber} (1, 2 or 3) of all elements of {@code array}. */
    public static double quartile(double[] array, int quartileNumber) {
        return StatUtils.quartile(array, array.length, quartileNumber);
    }

    /**
     * Returns quartile {@code quartileNumber} of the first {@code N} elements of {@code array},
     * linearly interpolating at zero-based position {@code q/4 * (N + 1) - 1}. The position is
     * clamped to {@code [0, N - 1]}, so very small samples yield the minimum/maximum.
     *
     * @throws IllegalArgumentException if {@code quartileNumber} is not 1, 2 or 3, or {@code N < 1}
     */
    public static double quartile(long[] array, int N, int quartileNumber) {
        StatUtils.checkQuartileNumber(quartileNumber);
        StatUtils.requireNonEmpty(N, "quartile");
        long[] a = new long[N + 1];
        System.arraycopy(array, 0, a, 0, N);
        double position = StatUtils.quartilePosition(N, quartileNumber);
        int k1 = (int) Math.floor(position);
        int k2 = (int) Math.ceil(position);
        double ratio = position - k1;
        StatUtils.selectInPlace(a, N, k1, k2);
        return (double) a[k1] + ratio * (double) (a[k2] - a[k1]);
    }

    /**
     * Returns quartile {@code quartileNumber} of the first {@code N} elements of {@code array},
     * linearly interpolating at zero-based position {@code q/4 * (N + 1) - 1}. The position is
     * clamped to {@code [0, N - 1]}, so very small samples yield the minimum/maximum.
     *
     * @throws IllegalArgumentException if {@code quartileNumber} is not 1, 2 or 3, or {@code N < 1}
     */
    public static double quartile(double[] array, int N, int quartileNumber) {
        StatUtils.checkQuartileNumber(quartileNumber);
        StatUtils.requireNonEmpty(N, "quartile");
        double[] a = new double[N + 1];
        System.arraycopy(array, 0, a, 0, N);
        double position = StatUtils.quartilePosition(N, quartileNumber);
        int k1 = (int) Math.floor(position);
        int k2 = (int) Math.ceil(position);
        double ratio = position - k1;
        StatUtils.selectInPlace(a, N, k1, k2);
        return a[k1] + ratio * (a[k2] - a[k1]);
    }

    // ------------------------------------------------------------------ min / max / range

    /** Returns the smallest element of {@code array}. */
    public static double min(long[] array) {
        return StatUtils.min(array, array.length);
    }

    /** Returns the smallest element of {@code array}. */
    public static double min(double[] array) {
        return StatUtils.min(array, array.length);
    }

    /** Returns the smallest of the first {@code N} elements (reads {@code array[0]} even if N == 0). */
    public static double min(long[] array, int N) {
        double min = array[0];
        for (int i = 1; i < N; ++i) {
            if ((double) array[i] < min) {
                min = array[i];
            }
        }
        return min;
    }

    /** Returns the smallest of the first {@code N} elements (reads {@code array[0]} even if N == 0). */
    public static double min(double[] array, int N) {
        double min = array[0];
        for (int i = 1; i < N; ++i) {
            if (array[i] < min) {
                min = array[i];
            }
        }
        return min;
    }

    /** Returns the largest element of {@code array}. */
    public static double max(long[] array) {
        return StatUtils.max(array, array.length);
    }

    /** Returns the largest element of {@code array}. */
    public static double max(double[] array) {
        return StatUtils.max(array, array.length);
    }

    /** Returns the largest of the first {@code N} elements (reads {@code array[0]} even if N == 0). */
    public static double max(long[] array, int N) {
        double max = array[0];
        for (int i = 1; i < N; ++i) {
            if ((double) array[i] > max) {
                max = array[i];
            }
        }
        return max;
    }

    /** Returns the largest of the first {@code N} elements (reads {@code array[0]} even if N == 0). */
    public static double max(double[] array, int N) {
        double max = array[0];
        for (int i = 1; i < N; ++i) {
            if (array[i] > max) {
                max = array[i];
            }
        }
        return max;
    }

    /** Returns {@code max - min} over all elements of {@code array}. */
    public static double range(long[] array) {
        return StatUtils.max(array, array.length) - StatUtils.min(array, array.length);
    }

    /** Returns {@code max - min} over all elements of {@code array}. */
    public static double range(double[] array) {
        return StatUtils.max(array, array.length) - StatUtils.min(array, array.length);
    }

    /** Returns {@code max - min} over the first {@code N} elements of {@code array}. */
    public static double range(long[] array, int N) {
        return StatUtils.max(array, N) - StatUtils.min(array, N);
    }

    /** Returns {@code max - min} over the first {@code N} elements of {@code array}. */
    public static double range(double[] array, int N) {
        return StatUtils.max(array, N) - StatUtils.min(array, N);
    }

    /** Returns the number of elements in {@code array}. */
    public static int N(long[] array) {
        return array.length;
    }

    /** Returns the number of elements in {@code array}. */
    public static int N(double[] array) {
        return array.length;
    }

    // ------------------------------------------------------------------ sums of squares

    /** Returns the sum of squared deviations from the mean over all of {@code array}. */
    public static double ssx(long[] array) {
        return StatUtils.ssx(array, array.length);
    }

    /** Returns the sum of squared deviations from the mean over all of {@code array}. */
    public static double ssx(double[] array) {
        return StatUtils.ssx(array, array.length);
    }

    /** Returns {@code sum_i (array[i] - mean)^2} over the first {@code N} elements. */
    public static double ssx(long[] array, int N) {
        double meanValue = StatUtils.mean(array, N);
        double sum = 0.0;
        for (int i = 0; i < N; ++i) {
            double difference = (double) array[i] - meanValue;
            sum += difference * difference;
        }
        return sum;
    }

    /** Returns {@code sum_i (array[i] - mean)^2} over the first {@code N} elements. */
    public static double ssx(double[] array, int N) {
        double meanValue = StatUtils.mean(array, N);
        double sum = 0.0;
        for (int i = 0; i < N; ++i) {
            double difference = array[i] - meanValue;
            sum += difference * difference;
        }
        return sum;
    }

    /**
     * Returns the sum of cross-products of deviations from the respective means.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double sxy(long[] array1, long[] array2) {
        int N1 = array1.length;
        int N2 = array2.length;
        if (N1 != N2) {
            throw new IllegalArgumentException("StatUtils.SXY: Arrays passed (or lengths specified) of unequal lengths.");
        }
        return StatUtils.sxy(array1, array2, N1);
    }

    /**
     * Returns the sum of cross-products of deviations from the respective means.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double sxy(double[] array1, double[] array2) {
        int N1 = array1.length;
        int N2 = array2.length;
        if (N1 != N2) {
            throw new IllegalArgumentException("StatUtils.SXY: Arrays passed (or lengths specified) of unequal lengths.");
        }
        return StatUtils.sxy(array1, array2, N1);
    }

    /** Returns {@code sum_i (x_i - meanX)(y_i - meanY)} over the first {@code N} elements. */
    public static double sxy(long[] array1, long[] array2, int N) {
        double sum = 0.0;
        double meanX = StatUtils.mean(array1, N);
        double meanY = StatUtils.mean(array2, N);
        for (int i = 0; i < N; ++i) {
            sum += ((double) array1[i] - meanX) * ((double) array2[i] - meanY);
        }
        return sum;
    }

    /** Returns {@code sum_i (x_i - meanX)(y_i - meanY)} over the first {@code N} elements. */
    public static double sxy(double[] array1, double[] array2, int N) {
        double sum = 0.0;
        double meanX = StatUtils.mean(array1, N);
        double meanY = StatUtils.mean(array2, N);
        for (int i = 0; i < N; ++i) {
            sum += (array1[i] - meanX) * (array2[i] - meanY);
        }
        return sum;
    }

    /** Returns {@code sum_i (x_i - meanX)(y_i - meanY)} over the first {@code N} entries. */
    public static double sxy(DoubleMatrix1D data1, DoubleMatrix1D data2, int N) {
        double sum = 0.0;
        double meanX = StatUtils.mean(data1, N);
        double meanY = StatUtils.mean(data2, N);
        for (int i = 0; i < N; ++i) {
            sum += (data1.get(i) - meanX) * (data2.get(i) - meanY);
        }
        return sum;
    }

    // ------------------------------------------------------------------ variance / sd

    /** Returns the unbiased sample variance of all of {@code array}. */
    public static double variance(long[] array) {
        return StatUtils.variance(array, array.length);
    }

    /** Returns the unbiased sample variance of all of {@code array}. */
    public static double variance(double[] array) {
        return StatUtils.variance(array, array.length);
    }

    /** Returns {@code ssx / (N - 1)} over the first {@code N} elements. */
    public static double variance(long[] array, int N) {
        return StatUtils.ssx(array, N) / (double) (N - 1);
    }

    /** Returns {@code ssx / (N - 1)} over the first {@code N} elements. */
    public static double variance(double[] array, int N) {
        return StatUtils.ssx(array, N) / (double) (N - 1);
    }

    /** Returns the sample standard deviation of all of {@code array}. */
    public static double standardDeviation(long[] array) {
        return StatUtils.standardDeviation(array, array.length);
    }

    /** Returns the sample standard deviation of all of {@code array}. */
    public static double standardDeviation(double[] array) {
        return StatUtils.standardDeviation(array, array.length);
    }

    /** Returns the square root of the sample variance of the first {@code N} elements. */
    public static double standardDeviation(long[] array, int N) {
        return Math.sqrt(StatUtils.variance(array, N));
    }

    /** Returns the square root of the sample variance of the first {@code N} elements. */
    public static double standardDeviation(double[] array, int N) {
        return Math.sqrt(StatUtils.variance(array, N));
    }

    // ------------------------------------------------------------------ covariance / correlation

    /**
     * Returns the sample covariance ({@code sxy / (N - 1)}) of two equal-length arrays.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double covariance(long[] array1, long[] array2) {
        int N1 = array1.length;
        int N2 = array2.length;
        if (N1 != N2) {
            throw new IllegalArgumentException("Arrays passed (or lengths specified) of unequal lengths.");
        }
        return StatUtils.covariance(array1, array2, N1);
    }

    /**
     * Returns the sample covariance ({@code sxy / (N - 1)}) of two equal-length arrays.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double covariance(double[] array1, double[] array2) {
        int N1 = array1.length;
        int N2 = array2.length;
        if (N1 != N2) {
            throw new IllegalArgumentException("Arrays passed (or lengths specified) of unequal lengths.");
        }
        return StatUtils.covariance(array1, array2, N1);
    }

    /** Returns {@code sxy / (N - 1)} over the first {@code N} elements. */
    public static double covariance(long[] array1, long[] array2, int N) {
        return StatUtils.sxy(array1, array2, N) / (double) (N - 1);
    }

    /** Returns {@code sxy / (N - 1)} over the first {@code N} elements. */
    public static double covariance(double[] array1, double[] array2, int N) {
        return StatUtils.sxy(array1, array2, N) / (double) (N - 1);
    }

    /**
     * Returns the Pearson correlation of two equal-length arrays.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double correlation(long[] array1, long[] array2) {
        int N1 = array1.length;
        int N2 = array2.length;
        if (N1 != N2) {
            throw new IllegalArgumentException("Arrays passed (or lengths specified) of unequal lengths.");
        }
        return StatUtils.correlation(array1, array2, N1);
    }

    /**
     * Returns the Pearson correlation of two equal-length arrays.
     *
     * @throws IllegalArgumentException if the arrays differ in length
     */
    public static double correlation(double[] array1, double[] array2) {
        int N1 = array1.length;
        int N2 = array2.length;
        if (N1 != N2) {
            throw new IllegalArgumentException("Arrays passed (or lengths specified) of unequal lengths.");
        }
        return StatUtils.correlation(array1, array2, N1);
    }

    /** Returns the Pearson correlation of the first {@code N} entries of two vectors. */
    public static double correlation(DoubleMatrix1D data1, DoubleMatrix1D data2, int N) {
        double covXY = StatUtils.sxy(data1, data2, N);
        double covXX = StatUtils.sxy(data1, data1, N);
        double covYY = StatUtils.sxy(data2, data2, N);
        return covXY / (Math.sqrt(covXX) * Math.sqrt(covYY));
    }

    /** Returns the correlation of two whole vectors scaled by 10000 and cast to {@code short}. */
    public static short compressedCorrelation(DoubleMatrix1D data1, DoubleMatrix1D data2) {
        return (short) (StatUtils.correlation(data1, data2, data1.size()) * 10000.0);
    }

    /** Returns the Pearson correlation of the first {@code N} elements of two arrays. */
    public static double correlation(long[] array1, long[] array2, int N) {
        double covXY = StatUtils.sxy(array1, array2, N);
        double covXX = StatUtils.sxy(array1, array1, N);
        double covYY = StatUtils.sxy(array2, array2, N);
        return covXY / (Math.sqrt(covXX) * Math.sqrt(covYY));
    }

    /** Returns the Pearson correlation of the first {@code N} elements of two arrays. */
    public static double correlation(double[] array1, double[] array2, int N) {
        double covXY = StatUtils.sxy(array1, array2, N);
        double covXX = StatUtils.sxy(array1, array1, N);
        double covYY = StatUtils.sxy(array2, array2, N);
        return covXY / (Math.sqrt(covXX) * Math.sqrt(covYY));
    }

    // ------------------------------------------------------------------ aliases

    /** Returns the sample variance {@code s^2} of all of {@code array}. */
    public static double sSquare(long[] array) {
        return StatUtils.sSquare(array, array.length);
    }

    /** Returns the sample variance {@code s^2} of all of {@code array}. */
    public static double sSquare(double[] array) {
        // Delegate like the long[] overload; previously this returned the raw ssx without
        // dividing by N - 1.
        return StatUtils.sSquare(array, array.length);
    }

    /** Returns {@code ssx / (N - 1)} over the first {@code N} elements. */
    public static double sSquare(long[] array, int N) {
        return StatUtils.ssx(array, N) / (double) (N - 1);
    }

    /** Returns {@code ssx / (N - 1)} over the first {@code N} elements. */
    public static double sSquare(double[] array, int N) {
        return StatUtils.ssx(array, N) / (double) (N - 1);
    }

    /** Returns the estimated (sample) variance of all of {@code array}. */
    public static double varHat(long[] array) {
        return StatUtils.varHat(array, array.length);
    }

    /** Returns the estimated (sample) variance of all of {@code array}. */
    public static double varHat(double[] array) {
        return StatUtils.varHat(array, array.length);
    }

    /** Returns {@code ssx / (N - 1)} over the first {@code N} elements; same as variance. */
    public static double varHat(long[] array, int N) {
        return StatUtils.variance(array, N);
    }

    /** Returns {@code ssx / (N - 1)} over the first {@code N} elements; same as variance. */
    public static double varHat(double[] array, int N) {
        return StatUtils.variance(array, N);
    }

    /** Alias for {@link #mean(long[])}. */
    public static double mu(long[] array) {
        return StatUtils.mean(array, array.length);
    }

    /** Alias for {@link #mean(double[])}. */
    public static double mu(double[] array) {
        return StatUtils.mean(array, array.length);
    }

    /** Alias for {@link #mean(long[], int)}. */
    public static double mu(long[] array, int N) {
        return StatUtils.mean(array, N);
    }

    /** Alias for {@link #mean(double[], int)}. */
    public static double mu(double[] array, int N) {
        return StatUtils.mean(array, N);
    }

    /** Alias for {@link #mean(long[])}. */
    public static double muHat(long[] array) {
        return StatUtils.muHat(array, array.length);
    }

    /** Alias for {@link #mean(double[])}. */
    public static double muHat(double[] array) {
        return StatUtils.muHat(array, array.length);
    }

    /** Alias for {@link #mean(long[], int)}. */
    public static double muHat(long[] array, int N) {
        return StatUtils.mean(array, N);
    }

    /** Alias for {@link #mean(double[], int)}. */
    public static double muHat(double[] array, int N) {
        return StatUtils.mean(array, N);
    }

    // ------------------------------------------------------------------ higher moments

    /** Returns the mean absolute deviation from the mean over all of {@code array}. */
    public static double averageDeviation(long[] array) {
        return StatUtils.averageDeviation(array, array.length);
    }

    /** Returns the mean absolute deviation from the mean over all of {@code array}. */
    public static double averageDeviation(double[] array) {
        return StatUtils.averageDeviation(array, array.length);
    }

    /** Returns {@code sum_i |x_i - mean| / N} over the first {@code N} elements. */
    public static double averageDeviation(long[] array, int N) {
        double mean = StatUtils.mean(array, N);
        double adev = 0.0;
        for (int j = 0; j < N; ++j) {
            adev += Math.abs((double) array[j] - mean);
        }
        return adev / N;
    }

    /** Returns {@code sum_i |x_i - mean| / N} over the first {@code N} elements. */
    public static double averageDeviation(double[] array, int N) {
        double mean = StatUtils.mean(array, N);
        double adev = 0.0;
        for (int j = 0; j < N; ++j) {
            adev += Math.abs(array[j] - mean);
        }
        return adev / N;
    }

    /** Returns the skewness of all of {@code array}. */
    public static double skew(long[] array) {
        return StatUtils.skew(array, array.length);
    }

    /** Returns the skewness of all of {@code array}. */
    public static double skew(double[] array) {
        return StatUtils.skew(array, array.length);
    }

    /**
     * Returns {@code sum_i (x_i - mean)^3 / (N * s^3)}, where {@code s^2} is the sample
     * ({@code N - 1}) variance of the first {@code N} elements.
     *
     * @throws ArithmeticException if the variance is zero
     */
    public static double skew(long[] array, int N) {
        double mean = StatUtils.mean(array, N);
        double variance = StatUtils.variance(array, N);
        if (variance == 0.0) {
            throw new ArithmeticException("StatUtils.skew:  There is no skew when the variance is zero.");
        }
        double skew = 0.0;
        for (int j = 0; j < N; ++j) {
            double s = (double) array[j] - mean;
            skew += s * s * s;
        }
        return skew / ((double) N * Math.pow(variance, 1.5));
    }

    /**
     * Returns {@code sum_i (x_i - mean)^3 / (N * s^3)}, where {@code s^2} is the sample
     * ({@code N - 1}) variance of the first {@code N} elements.
     *
     * @throws ArithmeticException if the variance is zero
     */
    public static double skew(double[] array, int N) {
        double mean = StatUtils.mean(array, N);
        double variance = StatUtils.variance(array, N);
        if (variance == 0.0) {
            throw new ArithmeticException("StatUtils.skew:  There is no skew when the variance is zero.");
        }
        double skew = 0.0;
        for (int j = 0; j < N; ++j) {
            double s = array[j] - mean;
            skew += s * s * s;
        }
        return skew / ((double) N * Math.pow(variance, 1.5));
    }

    /** Returns the excess kurtosis of all of {@code array}. */
    public static double kurtosis(long[] array) {
        return StatUtils.kurtosis(array, array.length);
    }

    /** Returns the excess kurtosis of all of {@code array}. */
    public static double kurtosis(double[] array) {
        return StatUtils.kurtosis(array, array.length);
    }

    /**
     * Returns {@code sum_i (x_i - mean)^4 / (N * s^4) - 3}, where {@code s^2} is the sample
     * ({@code N - 1}) variance of the first {@code N} elements.
     *
     * @throws ArithmeticException if the variance is zero
     */
    public static double kurtosis(long[] array, int N) {
        double mean = StatUtils.mean(array, N);
        double variance = StatUtils.variance(array, N);
        if (variance == 0.0) {
            throw new ArithmeticException("Curtosis is undefined when variance is zero.");
        }
        double curt = 0.0;
        for (int j = 0; j < N; ++j) {
            double s = (double) array[j] - mean;
            curt += s * s * s * s;
        }
        return curt / ((double) N * variance * variance) - 3.0;
    }

    /**
     * Returns {@code sum_i (x_i - mean)^4 / (N * s^4) - 3}, where {@code s^2} is the sample
     * ({@code N - 1}) variance of the first {@code N} elements.
     *
     * @throws ArithmeticException if the variance is zero
     */
    public static double kurtosis(double[] array, int N) {
        double mean = StatUtils.mean(array, N);
        double variance = StatUtils.variance(array, N);
        if (variance == 0.0) {
            throw new ArithmeticException("StatUtils.curtosis:  There is no curtosis when the variance is zero.");
        }
        double curt = 0.0;
        for (int j = 0; j < N; ++j) {
            double s = array[j] - mean;
            curt += s * s * s * s;
        }
        return curt / ((double) N * variance * variance) - 3.0;
    }

    // ------------------------------------------------------------------ special functions

    /**
     * Approximates Gamma(z). For {@code z < 2} the reciprocal-gamma power series is evaluated
     * directly; otherwise Gauss's multiplication formula reduces the problem to series
     * evaluations at arguments in [1.2, 3.4).
     *
     * <p>NOTE(review): the truncated series is most accurate near the origin; accuracy at the
     * upper end of that range has not been verified here.
     */
    public static double gamma(double z) {
        if (z < 2.0) {
            return StatUtils.gammaBySeries(z);
        }
        // With n = multiplier and z = n * remainder:
        // Gamma(z) = (2 pi)^((1 - n) / 2) * n^(z - 1/2) * prod_{k=0}^{n-1} Gamma(remainder + k / n)
        double multiplier = Math.floor(z / 1.2);
        double remainder = z / multiplier;
        double coef1 = Math.pow(Math.PI * 2, 0.5 * (1.0 - multiplier));
        double coef2 = Math.pow(multiplier, multiplier * remainder - 0.5);
        int n = (int) multiplier;
        double prod = 1.0;
        for (int k = 0; k < n; ++k) {
            prod *= StatUtils.gammaBySeries(remainder + (double) k / multiplier);
        }
        return coef1 * coef2 * prod;
    }

    /** Returns {@code 1 / sum_{i=1}^{26} c_i z^i}, i.e. Gamma(z) from the truncated series. */
    private static double gammaBySeries(double z) {
        double sum = 0.0;
        for (int i = 0; i < RECIPROCAL_GAMMA_COEFFICIENTS.length; ++i) {
            sum += RECIPROCAL_GAMMA_COEFFICIENTS[i] * Math.pow(z, i + 1);
        }
        return 1.0 / sum;
    }

    /** Returns the beta function {@code Gamma(x1) Gamma(x2) / Gamma(x1 + x2)}. */
    public static double beta(double x1, double x2) {
        return StatUtils.gamma(x1) * StatUtils.gamma(x2) / StatUtils.gamma(x1 + x2);
    }

    /**
     * Returns the regularized lower incomplete gamma function P(a, x), summed as
     * {@code e^-x x^a / Gamma(a) * sum_i Gamma(a) / Gamma(a + 1 + i) * x^i} over a fixed
     * number of terms.
     */
    public static double igamma(double a, double x) {
        // Gamma(a) does not depend on the loop index; compute it once.
        double gammaA = StatUtils.gamma(a);
        double coef = Math.exp(-x) * Math.pow(x, a) / gammaA;
        double sum = 0.0;
        for (int i = 0; i < IGAMMA_TERMS; ++i) {
            sum += gammaA / StatUtils.gamma(a + 1.0 + (double) i) * Math.pow(x, i);
        }
        return coef * sum;
    }

    /**
     * Returns the error function, using {@code erf(x) = sign(x) * P(1/2, x^2)}. erf is odd, so
     * negative arguments yield negative results.
     */
    public static double erf(double x) {
        double magnitude = StatUtils.igamma(0.5, x * x);
        return x < 0.0 ? -magnitude : magnitude;
    }

    /**
     * Poisson distribution with mean {@code x}, evaluated at count {@code k}.
     *
     * @param k   number of events (must be at least 1)
     * @param x   the mean (must be non-negative)
     * @param cum if true return {@code P(X <= k)}, otherwise {@code P(X = k)}
     * @throws ArithmeticException if {@code x < 0} or {@code k < 1}
     */
    public static double poisson(double k, double x, boolean cum) {
        if (x < 0.0 || k < 1.0) {
            throw new ArithmeticException("The Poisson Distribution Function requires x>=0 and k >= 1");
        }
        if (cum) {
            // P(X <= k) = Q(k + 1, x) = 1 - P(k + 1, x)
            return 1.0 - StatUtils.igamma(k + 1.0, x);
        }
        // P(X = k) = e^-x x^k / k! = e^-x x^k / Gamma(k + 1)
        return Math.exp(-x) * Math.pow(x, k) / StatUtils.gamma(k + 1.0);
    }

    /**
     * Returns the upper-tail chi-square probability {@code P(chi^2_df > x) = 1 - P(df/2, x/2)}.
     *
     * @throws ArithmeticException if {@code x < 0} or {@code degreesOfFreedom < 1}
     */
    public static double chidist(double x, int degreesOfFreedom) {
        // df = 0 makes igamma evaluate Gamma(0) and return NaN, so reject it as the message says.
        if (x < 0.0 || degreesOfFreedom < 1) {
            throw new ArithmeticException("The Chi Distribution Function requires x > 0.0 and degrees of freedom > 0");
        }
        return 1.0 - StatUtils.igamma((double) degreesOfFreedom / 2.0, x / 2.0);
    }

    // ------------------------------------------------------------------ contingency tables

    /**
     * Computes statistics of an r x c contingency table of counts.
     *
     * @param IArray rectangular table of counts; the column count is taken from row 0
     * @return {@code {chi-square, Cramer's V, contingency coefficient}}
     */
    public static double[] ContTable1(int[][] IArray) {
        int rowCount = IArray.length;
        int colCount = IArray[0].length;
        // Marginal totals are computed once; long avoids int overflow for large tables.
        long[] rowSums = new long[rowCount];
        long[] colSums = new long[colCount];
        long total = 0L;
        for (int i = 0; i < rowCount; ++i) {
            for (int j = 0; j < colCount; ++j) {
                rowSums[i] += IArray[i][j];
                colSums[j] += IArray[i][j];
                total += IArray[i][j];
            }
        }
        double chisquare = 0.0;
        for (int i = 0; i < rowCount; ++i) {
            for (int j = 0; j < colCount; ++j) {
                double expected = (double) rowSums[i] * (double) colSums[j] / (double) total;
                double diff = (double) IArray[i][j] - expected;
                chisquare += diff * diff / expected;
            }
        }
        int minCount = Math.min(rowCount, colCount) - 1;
        double cramerV = Math.sqrt(chisquare / ((double) total * minCount));
        double contCoef = Math.sqrt(chisquare / (chisquare + (double) total));
        return new double[] {chisquare, cramerV, contCoef};
    }

    /** Returns a pseudo-random integer in {@code [0, n)} for positive {@code n}. */
    public static int dieToss(int n) {
        return (int) Math.floor((double) n * Math.random());
    }

    // ------------------------------------------------------------------ false discovery rate

    /**
     * Returns the Benjamini-Hochberg p-value cutoff: the largest sorted p value {@code p_(k)}
     * with {@code p_(k) <= alpha * k / (n * c)}, where {@code c = 1}, or the harmonic number
     * {@code H_n} when {@code negativelyCorrelated} (Benjamini-Yekutieli). Returns 0.0 if no p
     * value qualifies or {@code p} is empty.
     *
     * @throws IllegalArgumentException if any p value is outside [0, 1] or NaN
     */
    public static double fdr(double alpha, double[] p, boolean negativelyCorrelated) {
        int n = p.length;
        for (double value : p) {
            // Positive range test so that NaN is rejected as well.
            if (!(value >= 0.0 && value <= 1.0)) {
                throw new IllegalArgumentException("P values should be in range [0, 1]: " + value);
            }
        }
        // 1-based sorted copy; sorted[0] == 0.0 is the result when nothing qualifies.
        double[] sorted = new double[n + 1];
        System.arraycopy(p, 0, sorted, 1, n);
        Arrays.sort(sorted, 1, n + 1);
        double c = 1.0;
        if (negativelyCorrelated) {
            c = 0.0;
            for (int i = 1; i <= n; ++i) {
                c += 1.0 / (double) i;
            }
        }
        // Scan down from k = n (which the old post-decrement loop skipped).
        int k = n;
        while (k > 0 && !(sorted[k] <= alpha * ((double) k / ((double) n * c)))) {
            --k;
        }
        return sorted[k];
    }

    // ------------------------------------------------------------------ partial covariance

    /**
     * Returns the partial covariance of variables 0 and 1 given variables 2..m-1 of the m x m
     * covariance matrix {@code submatrix}: {@code S01 - S0Z inv(SZZ) SZ1}.
     */
    public static double partialCovariance(DoubleMatrix2D submatrix) {
        double covXy = submatrix.get(0, 1);
        int[] _z = new int[submatrix.rows() - 2];
        for (int i = 0; i < _z.length; ++i) {
            _z[i] = i + 2;
        }
        DoubleMatrix2D covXz = submatrix.viewSelection(new int[] {0}, _z);
        DoubleMatrix2D covZy = submatrix.viewSelection(_z, new int[] {1}).copy();
        DoubleMatrix2D covZ = submatrix.viewSelection(_z, _z);
        Algebra algebra = new Algebra(1.0E-20);
        DoubleMatrix2D _zInverse = algebra.inverse(covZ);
        DoubleMatrix2D temp1 = algebra.mult(covXz, _zInverse);
        DoubleMatrix2D temp2 = algebra.mult(temp1, covZy);
        return covXy - temp2.get(0, 0);
    }

    /**
     * Returns the partial covariance of variables {@code x} and {@code y} given {@code z} in the
     * covariance matrix {@code covariance}.
     *
     * @throws IllegalArgumentException if any index is outside {@code [0, covariance.rows())}
     */
    public static double partialCovariance(DoubleMatrix2D covariance, int x, int y, int ... z) {
        int[] selection = StatUtils.partialSelection(covariance, x, y, z);
        return StatUtils.partialCovariance(covariance.viewSelection(selection, selection));
    }

    /** Returns the partial variance of {@code x} given {@code z}. */
    public static double partialVariance(DoubleMatrix2D covariance, int x, int ... z) {
        return StatUtils.partialCovariance(covariance, x, x, z);
    }

    /** Returns the square root of the partial variance of {@code x} given {@code z}. */
    public static double partialStandardDeviation(DoubleMatrix2D covariance, int x, int ... z) {
        double var = StatUtils.partialVariance(covariance, x, z);
        return Math.sqrt(var);
    }

    /**
     * Returns the partial correlation of variables 0 and 1 given variables 2..m-1 of the m x m
     * covariance matrix {@code submatrix}.
     */
    public static double partialCorrelation(DoubleMatrix2D submatrix) {
        double cov = StatUtils.partialCovariance(submatrix);
        int[] selection1 = new int[submatrix.rows()];
        int[] selection2 = new int[submatrix.rows()];
        // Selections {v, v, 2, 3, ...} turn partialCovariance into the partial variance of v.
        selection1[0] = 0;
        selection1[1] = 0;
        selection2[0] = 1;
        selection2[1] = 1;
        for (int i = 2; i < selection1.length; ++i) {
            selection1[i] = i;
            selection2[i] = i;
        }
        double var1 = StatUtils.partialCovariance(submatrix.viewSelection(selection1, selection1));
        double var2 = StatUtils.partialCovariance(submatrix.viewSelection(selection2, selection2));
        return cov / Math.sqrt(var1 * var2);
    }

    /**
     * Returns the partial correlation of variables {@code x} and {@code y} given {@code z} in
     * the covariance matrix {@code covariance}.
     *
     * @throws IllegalArgumentException if any index is outside {@code [0, covariance.rows())}
     */
    public static double partialCorrelation(DoubleMatrix2D covariance, int x, int y, int ... z) {
        int[] selection = StatUtils.partialSelection(covariance, x, y, z);
        return StatUtils.partialCorrelation(covariance.viewSelection(selection, selection));
    }

    // ------------------------------------------------------------------ private helpers

    /** Validates indices and returns the selection {@code {x, y, z...}}. */
    private static int[] partialSelection(DoubleMatrix2D covariance, int x, int y, int[] z) {
        int rows = covariance.rows();
        StatUtils.checkVariableIndex(x, rows);
        StatUtils.checkVariableIndex(y, rows);
        for (int index : z) {
            StatUtils.checkVariableIndex(index, rows);
        }
        int[] selection = new int[z.length + 2];
        selection[0] = x;
        selection[1] = y;
        System.arraycopy(z, 0, selection, 2, z.length);
        return selection;
    }

    /** Throws if {@code index} is not a valid row index for a matrix with {@code rows} rows. */
    private static void checkVariableIndex(int index, int rows) {
        if (index < 0 || index >= rows) {
            throw new IllegalArgumentException(
                    "Variable index " + index + " is out of range [0, " + rows + ").");
        }
    }

    /** Throws unless {@code quartileNumber} is 1, 2 or 3. */
    private static void checkQuartileNumber(int quartileNumber) {
        if (quartileNumber < 1 || quartileNumber > 3) {
            throw new IllegalArgumentException("StatUtils.quartile:  Quartile number must be 1, 2, or 3.");
        }
    }

    /** Throws if fewer than one value is supplied to an order statistic. */
    private static void requireNonEmpty(int n, String method) {
        if (n < 1) {
            throw new IllegalArgumentException(
                    "StatUtils." + method + ":  At least one value is required; N = " + n);
        }
    }

    /**
     * Returns the zero-based interpolation position {@code q/4 * (n + 1) - 1} of quartile
     * {@code quartileNumber} among {@code n} sorted values, clamped to {@code [0, n - 1]} so
     * small samples never index outside the data.
     */
    private static double quartilePosition(int n, int quartileNumber) {
        double position = quartileNumber / 4.0 * (n + 1.0) - 1.0;
        return Math.max(0.0, Math.min(n - 1.0, position));
    }

    /**
     * Partially orders {@code a[0..n-1]} (Sedgewick-style quickselect) so that {@code a[k1]} and
     * {@code a[k2]} hold the values they would have if that range were sorted ascending.
     * {@code a.length} must be at least {@code n + 1}: {@code a[n]} is overwritten with a
     * sentinel that bounds the left-to-right scan. Requires
     * {@code 0 <= k1 <= k2 <= k1 + 1 <= n - 1}.
     */
    private static void selectInPlace(long[] a, int n, int k1, int k2) {
        a[n] = Long.MAX_VALUE;
        int l = 0;
        int r = n - 1;
        while (r > l) {
            long pivot = a[l];
            int i = l;
            int j = r + 1;
            while (true) {
                while (a[++i] < pivot) {
                    // scan right past elements smaller than the pivot
                }
                while (a[--j] > pivot) {
                    // scan left past larger elements; a[l] == pivot stops this scan
                }
                if (i >= j) {
                    break;
                }
                StatUtils.swap(a, i, j);
            }
            // Put the pivot into its final sorted position j.
            StatUtils.swap(a, j, l);
            // Keep only the side(s) still containing an unresolved target index.
            if (j <= k1) {
                l = j + 1;
            }
            if (j >= k2) {
                r = j - 1;
            }
        }
    }

    /** Same as {@link #selectInPlace(long[], int, int, int)} for doubles (sentinel +infinity). */
    private static void selectInPlace(double[] a, int n, int k1, int k2) {
        a[n] = Double.POSITIVE_INFINITY;
        int l = 0;
        int r = n - 1;
        while (r > l) {
            double pivot = a[l];
            int i = l;
            int j = r + 1;
            while (true) {
                while (a[++i] < pivot) {
                    // scan right past elements smaller than the pivot
                }
                while (a[--j] > pivot) {
                    // scan left past larger elements; a[l] == pivot stops this scan
                }
                if (i >= j) {
                    break;
                }
                StatUtils.swap(a, i, j);
            }
            StatUtils.swap(a, j, l);
            if (j <= k1) {
                l = j + 1;
            }
            if (j >= k2) {
                r = j - 1;
            }
        }
    }

    /** Swaps {@code a[i]} and {@code a[j]}. */
    private static void swap(long[] a, int i, int j) {
        long tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }

    /** Swaps {@code a[i]} and {@code a[j]}. */
    private static void swap(double[] a, int i, int j) {
        double tmp = a[i];
        a[i] = a[j];
        a[j] = tmp;
    }

    // ------------------------------------------------------------------ cross-validation

    /**
     * K-fold cross-validation splitter over the rows of a continuous {@link DataSet}.
     *
     * <p>Each row is labelled with one of folds {@code 1..k} (fold sizes differ by at most one)
     * and the labels are shuffled once at construction. Each {@link #nextIteration()} call makes
     * the next fold the validation set and all other rows the training set; iterate with
     * {@code while (cv.hasNext()) { cv.nextIteration(); ... }}.
     *
     * <p>NOTE(review): this is a non-static inner class, so constructing it requires a
     * {@code StatUtils} instance; it is left non-static only for source compatibility.
     */
    public class CrossValidation {
        private final DataSet dataSet;
        private DataSet tDataSet;
        private DataSet vDataSet;
        private final int k;
        private int curk;
        private final List<Integer> chunks = new ArrayList<Integer>();

        /**
         * @param ds the data set to split
         * @param k  number of folds
         * @throws IllegalArgumentException if {@code k < 1} (which previously looped forever)
         */
        public CrossValidation(DataSet ds, int k) {
            if (k < 1) {
                throw new IllegalArgumentException("Number of folds must be at least 1: " + k);
            }
            this.dataSet = ds;
            this.k = k;
            this.curk = 1;
            int rows = this.dataSet.getNumRows();
            for (int row = 0; row < rows; ++row) {
                // Round-robin labels 1..k keep fold sizes within one of each other.
                this.chunks.add(row % k + 1);
            }
            Collections.shuffle(this.chunks);
        }

        /** Returns true while some of the {@code k} folds have not yet been used. */
        public boolean hasNext() {
            return this.curk <= this.k;
        }

        /**
         * Splits the data with the next fold as validation set and advances the fold counter.
         *
         * @throws IllegalStateException if all folds have already been used
         */
        public void nextIteration() {
            if (this.curk > this.k) {
                throw new IllegalStateException("All " + this.k + " folds have already been used.");
            }
            List<Node> vars = this.dataSet.getVariables();
            DoubleMatrix2D data = this.dataSet.getDoubleData();
            int rows = this.dataSet.getNumRows();
            int cols = this.dataSet.getNumColumns();
            // Rows labelled curk: the round-robin labelling gives the first rows % k folds one extra.
            int targetRows = rows / this.k;
            if (this.curk <= rows % this.k) {
                ++targetRows;
            }
            double[][] vd = new double[targetRows][cols];
            double[][] td = new double[rows - targetRows][cols];
            int v = 0;
            int t = 0;
            for (int i = 0; i < rows; ++i) {
                double[] target = this.chunks.get(i) == this.curk ? vd[v++] : td[t++];
                for (int j = 0; j < cols; ++j) {
                    target[j] = data.get(i, j);
                }
            }
            DenseDoubleMatrix2D vdata = new DenseDoubleMatrix2D(vd);
            DenseDoubleMatrix2D tdata = new DenseDoubleMatrix2D(td);
            this.tDataSet = ColtDataSet.makeContinuousData(vars, tdata);
            this.vDataSet = ColtDataSet.makeContinuousData(vars, vdata);
            ++this.curk;
        }

        /** Returns the training split produced by the last {@link #nextIteration()} call. */
        public DataSet getTrainingData() {
            return this.tDataSet;
        }

        /** Returns the validation split produced by the last {@link #nextIteration()} call. */
        public DataSet getValidationData() {
            return this.vDataSet;
        }
    }
}

