/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.data;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.ForkJoinPoolInstance;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.SublistGenerator;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;

/**
 * Static helpers for inspecting data sets, extracting covariance sub-matrices, computing simple
 * summary statistics and simulating data from a covariance matrix.
 */
public final class DataUtils {
    /**
     * Prevents instantiation; this class only holds static methods.
     */
    private DataUtils() {
    }

    /**
     * Tells whether every value in the given column is 0 or 1.
     *
     * @param data   the data set to inspect.
     * @param column the index of the column to inspect.
     * @return {@code true} if every value in the column equals 0 or 1 (the int value for a discrete
     * column, the double value for a continuous column); {@code false} otherwise.
     * @throws IllegalArgumentException if the column's variable is neither a
     *                                  {@link DiscreteVariable} nor a {@link ContinuousVariable}.
     */
    public static boolean isBinary(DataSet data, int column) {
        Node node = data.getVariable(column);
        int size = data.getNumRows();
        if (node instanceof DiscreteVariable) {
            for (int i = 0; i < size; ++i) {
                int value = data.getInt(i, column);
                if (value != 0 && value != 1) {
                    return false;
                }
            }
        } else if (node instanceof ContinuousVariable) {
            for (int i = 0; i < size; ++i) {
                double value = data.getDouble(i, column);
                // NaN fails both comparisons, so a missing value makes the column non-binary.
                if (value != 0.0 && value != 1.0) {
                    return false;
                }
            }
        } else {
            throw new IllegalArgumentException("The given column is not discrete or continuous");
        }
        return true;
    }

    /**
     * Returns the default category label for a category index.
     *
     * @param index the category index.
     * @return the decimal string form of {@code index}.
     */
    public static String defaultCategory(int index) {
        return Integer.toString(index);
    }

    /**
     * Builds a tiny example data set with one binary discrete variable "X" and two rows holding 0
     * and 1.
     *
     * @return the example data set.
     */
    public static DataSet discreteSerializableInstance() {
        LinkedList<Node> variables = new LinkedList<Node>();
        variables.add(new DiscreteVariable("X", 2));
        BoxDataSet dataSet = new BoxDataSet(new VerticalDoubleDataBox(2, variables.size()), variables);
        dataSet.setInt(0, 0, 0);
        dataSet.setInt(1, 0, 1);
        return dataSet;
    }

    /**
     * Tells whether the matrix contains a NaN entry.
     *
     * @param data the matrix to scan.
     * @return {@code true} if any entry is NaN.
     */
    public static boolean containsMissingValue(Matrix data) {
        for (int i = 0; i < data.getNumRows(); ++i) {
            for (int j = 0; j < data.getNumColumns(); ++j) {
                if (Double.isNaN(data.get(i, j))) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Tells whether the data set contains a missing value: NaN in a continuous column, or the code
     * -99 in a discrete column. Columns of any other variable type are ignored.
     *
     * @param data the data set to scan.
     * @return {@code true} if a missing value is found.
     */
    public static boolean containsMissingValue(DataSet data) {
        int numRows = data.getNumRows();
        for (int j = 0; j < data.getNumColumns(); ++j) {
            Node node = data.getVariable(j);
            if (node instanceof ContinuousVariable) {
                for (int i = 0; i < numRows; ++i) {
                    if (Double.isNaN(data.getDouble(i, j))) {
                        return true;
                    }
                }
            } else if (node instanceof DiscreteVariable) {
                for (int i = 0; i < numRows; ++i) {
                    // NOTE(review): -99 is presumably DiscreteVariable's missing-value code; confirm.
                    if (data.getInt(i, j) == -99) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /**
     * Creates one {@link ContinuousVariable} per name, in the given order.
     *
     * @param varNames the variable names.
     * @return a new mutable list of the created variables.
     */
    public static List<Node> createContinuousVariables(String[] varNames) {
        LinkedList<Node> variables = new LinkedList<Node>();
        for (String varName : varNames) {
            variables.add(new ContinuousVariable(varName));
        }
        return variables;
    }

    /**
     * Extracts the square sub-matrix of {@code m} for the variables x, y, z (in that order).
     *
     * @param m the covariance matrix.
     * @param x the first variable.
     * @param y the second variable.
     * @param z the conditioning variables.
     * @return the (2 + |z|) x (2 + |z|) sub-matrix.
     * @throws NullPointerException     if x, y, z or any element of z is null.
     * @throws IllegalArgumentException if a variable is not in {@code m}, or the sub-matrix
     *                                  contains a NaN entry.
     */
    public static Matrix subMatrix(ICovarianceMatrix m, Node x, Node y, List<Node> z) {
        DataUtils.checkSelectionArguments(x, y, z);
        int[] indices = DataUtils.selectionIndices(m.getVariables(), x, y, z);
        Matrix submatrix = m.getSelection(indices, indices);
        if (DataUtils.containsMissingValue(submatrix)) {
            throw new IllegalArgumentException("Please remove or impute missing values first.");
        }
        return submatrix;
    }

    /**
     * Extracts the square sub-matrix of {@code m} for the variables x, y, z (in that order), whose
     * positions are looked up in {@code variables}.
     *
     * @param m         the matrix, indexed in the same order as {@code variables}.
     * @param variables the variables labeling the rows/columns of {@code m}.
     * @param x         the first variable.
     * @param y         the second variable.
     * @param z         the conditioning variables.
     * @return the (2 + |z|) x (2 + |z|) sub-matrix.
     * @throws NullPointerException     if x, y, z or any element of z is null.
     * @throws IllegalArgumentException if a variable is not in {@code variables}.
     */
    public static Matrix subMatrix(Matrix m, List<Node> variables, Node x, Node y, List<Node> z) {
        DataUtils.checkSelectionArguments(x, y, z);
        int[] indices = DataUtils.selectionIndices(variables, x, y, z);
        return m.getSelection(indices, indices);
    }

    /**
     * Extracts the square sub-matrix of {@code m} for the variables x, y, z (in that order), whose
     * positions are looked up in {@code indexMap}.
     *
     * @param m        the matrix.
     * @param indexMap maps each variable to its row/column index in {@code m}.
     * @param x        the first variable.
     * @param y        the second variable.
     * @param z        the conditioning variables.
     * @return the (2 + |z|) x (2 + |z|) sub-matrix.
     * @throws NullPointerException     if x, y, z or any element of z is null.
     * @throws IllegalArgumentException if a variable has no entry in {@code indexMap}.
     */
    public static Matrix subMatrix(Matrix m, Map<Node, Integer> indexMap, Node x, Node y, List<Node> z) {
        DataUtils.checkSelectionArguments(x, y, z);
        int[] indices = DataUtils.selectionIndices(indexMap, x, y, z);
        return m.getSelection(indices, indices);
    }

    /**
     * Extracts the square sub-matrix of covariance matrix {@code m} for the variables x, y, z (in
     * that order), whose positions are looked up in {@code indexMap}.
     *
     * @param m        the covariance matrix.
     * @param indexMap maps each variable to its row/column index in {@code m}.
     * @param x        the first variable.
     * @param y        the second variable.
     * @param z        the conditioning variables.
     * @return the (2 + |z|) x (2 + |z|) sub-matrix.
     * @throws NullPointerException     if x, y, z or any element of z is null.
     * @throws IllegalArgumentException if a variable has no entry in {@code indexMap}.
     */
    public static Matrix subMatrix(ICovarianceMatrix m, Map<Node, Integer> indexMap, Node x, Node y, List<Node> z) {
        DataUtils.checkSelectionArguments(x, y, z);
        int[] indices = DataUtils.selectionIndices(indexMap, x, y, z);
        return m.getSelection(indices, indices);
    }

    /** Rejects null x, y, z, or null elements of z, with a NullPointerException. */
    private static void checkSelectionArguments(Node x, Node y, List<Node> z) {
        if (x == null) {
            throw new NullPointerException("x is null");
        }
        if (y == null) {
            throw new NullPointerException("y is null");
        }
        if (z == null) {
            throw new NullPointerException("z is null");
        }
        for (Node node : z) {
            if (node == null) {
                throw new NullPointerException("z contains a null node");
            }
        }
    }

    /** Builds the index array [x, y, z...] by position in {@code variables}. */
    private static int[] selectionIndices(List<Node> variables, Node x, Node y, List<Node> z) {
        int[] indices = new int[2 + z.size()];
        indices[0] = DataUtils.indexOf(variables, x);
        indices[1] = DataUtils.indexOf(variables, y);
        int k = 2;
        for (Node node : z) {
            indices[k++] = DataUtils.indexOf(variables, node);
        }
        return indices;
    }

    /** Builds the index array [x, y, z...] from {@code indexMap}. */
    private static int[] selectionIndices(Map<Node, Integer> indexMap, Node x, Node y, List<Node> z) {
        int[] indices = new int[2 + z.size()];
        indices[0] = DataUtils.indexOf(indexMap, x);
        indices[1] = DataUtils.indexOf(indexMap, y);
        int k = 2;
        for (Node node : z) {
            indices[k++] = DataUtils.indexOf(indexMap, node);
        }
        return indices;
    }

    /** Position of {@code node} in {@code variables}; fails clearly instead of returning -1. */
    private static int indexOf(List<Node> variables, Node node) {
        int index = variables.indexOf(node);
        if (index < 0) {
            throw new IllegalArgumentException("Variable not found: " + node);
        }
        return index;
    }

    /** Index of {@code node} in {@code indexMap}; fails clearly instead of an unboxing NPE. */
    private static int indexOf(Map<Node, Integer> indexMap, Node node) {
        Integer index = indexMap.get(node);
        if (index == null) {
            throw new IllegalArgumentException("Variable not found: " + node);
        }
        return index;
    }

    /**
     * Computes the mean of each column of {@code data}, skipping NaN entries. A column with no
     * non-NaN entries gets NaN (0 / 0).
     *
     * @param data the data, one variable per column.
     * @return the vector of column means.
     */
    public static Vector means(Matrix data) {
        Vector means = new Vector(data.getNumColumns());
        for (int j = 0; j < means.size(); ++j) {
            double sum = 0.0;
            int count = 0;
            for (int i = 0; i < data.getNumRows(); ++i) {
                double value = data.get(i, j);
                if (Double.isNaN(value)) {
                    continue;
                }
                sum += value;
                ++count;
            }
            means.set(j, sum / (double) count);
        }
        return means;
    }

    /**
     * Computes the mean of each {@code data[j]}, skipping NaN entries. The array is laid out one
     * variable per outer index: {@code data[j][i]} is row i of variable j. A variable with no
     * non-NaN entries gets NaN (0 / 0).
     *
     * @param data the data, one inner array per variable; may be empty.
     * @return the vector of means, of length {@code data.length}.
     */
    public static Vector means(double[][] data) {
        Vector means = new Vector(data.length);
        for (int j = 0; j < data.length; ++j) {
            double sum = 0.0;
            int count = 0;
            // Use each variable's own length rather than assuming data[0]'s, which also makes an
            // empty outer array safe.
            for (double value : data[j]) {
                if (Double.isNaN(value)) {
                    continue;
                }
                sum += value;
                ++count;
            }
            means.set(j, sum / (double) count);
        }
        return means;
    }

    /**
     * Computes the sample covariance matrix (divisor n - 1) of the columns of {@code data}. The
     * input matrix is not modified.
     *
     * @param data the data, one variable per column, n rows.
     * @return the p x p covariance matrix, p = number of columns.
     */
    public static Matrix cov(Matrix data) {
        int numRows = data.getNumRows();
        int numCols = data.getNumColumns();
        // The BlockRealMatrix constructor copies the raw array, so centering q leaves the
        // caller's matrix untouched.
        BlockRealMatrix q = new BlockRealMatrix(data.toArray());
        for (int j = 0; j < numCols; ++j) {
            double sum = 0.0;
            for (int i = 0; i < numRows; ++i) {
                sum += q.getEntry(i, j);
            }
            double mean = sum / (double) numRows;
            for (int i = 0; i < numRows; ++i) {
                q.setEntry(i, j, q.getEntry(i, j) - mean);
            }
        }
        RealMatrix q1 = MatrixUtils.transposeWithoutCopy(q);
        RealMatrix q2 = DataUtils.times(q1, q);
        Matrix prod = new Matrix(q2.getData());
        double factor = 1.0 / (double) (numRows - 1);
        for (int i = 0; i < prod.getNumRows(); ++i) {
            for (int j = 0; j < prod.getNumColumns(); ++j) {
                prod.set(i, j, prod.get(i, j) * factor);
            }
        }
        return prod;
    }

    /**
     * Multiplies {@code m} by {@code n} in parallel on the shared Tetrad fork-join pool. Output
     * rows are split into contiguous chunks, one task per available processor. The method waits
     * for every task to finish before returning.
     *
     * @param m the left factor.
     * @param n the right factor.
     * @return the product m * n.
     * @throws IllegalArgumentException if the inner dimensions differ.
     */
    private static RealMatrix times(RealMatrix m, RealMatrix n) {
        if (m.getColumnDimension() != n.getRowDimension()) {
            throw new IllegalArgumentException("Incompatible matrices.");
        }
        int rowDimension = m.getRowDimension();
        int columnDimension = n.getColumnDimension();
        int commonDimension = m.getColumnDimension();
        BlockRealMatrix out = new BlockRealMatrix(rowDimension, columnDimension);
        int numThreads = Runtime.getRuntime().availableProcessors();
        int chunk = rowDimension / numThreads + 1;
        ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();
        List<ForkJoinTask<?>> tasks = new ArrayList<ForkJoinTask<?>>(numThreads);
        for (int t = 0; t < numThreads; ++t) {
            int from = t * chunk;
            int to = FastMath.min(from + chunk, rowDimension);
            if (from >= to) {
                break;
            }
            // Each task writes a disjoint range of output rows.
            Runnable worker = () -> {
                for (int row = from; row < to; ++row) {
                    for (int col = 0; col < columnDimension; ++col) {
                        double sum = 0.0;
                        for (int i = 0; i < commonDimension; ++i) {
                            sum += m.getEntry(row, i) * n.getEntry(i, col);
                        }
                        out.setEntry(row, col, sum);
                    }
                }
            };
            tasks.add(pool.submit(worker));
        }
        // Join our own tasks rather than spinning on pool quiescence. The pool is shared, so
        // quiescence can be reached before our tasks start, or be delayed by unrelated work.
        // join() also publishes the workers' writes and rethrows any worker failure.
        for (ForkJoinTask<?> task : tasks) {
            task.join();
        }
        return out;
    }

    /**
     * Computes each column's mean with {@link StatUtils#mean(double[])}.
     *
     * @param data the data, one variable per column.
     * @return the vector of column means.
     */
    public static Vector mean(Matrix data) {
        Vector mean = new Vector(data.getNumColumns());
        for (int i = 0; i < data.getNumColumns(); ++i) {
            mean.set(i, StatUtils.mean(data.getColumn(i).toArray()));
        }
        return mean;
    }

    /**
     * Simulates {@code cov.getSampleSize()} rows of data. Each row is L * e, where L is the
     * Cholesky factor of the covariance matrix and e is a vector of independent N(0, 1) draws.
     *
     * @param cov the covariance matrix to simulate from; its variables label the columns.
     * @return the simulated data set.
     */
    public static DataSet choleskySimulation(CovarianceMatrix cov) {
        int sampleSize = cov.getSampleSize();
        List<Node> variables = cov.getVariables();
        int numVars = variables.size();
        BoxDataSet dataSet = new BoxDataSet(new VerticalDoubleDataBox(sampleSize, numVars), variables);
        Matrix cholesky = MatrixUtils.cholesky(cov.getMatrix().copy());
        int dim = cholesky.getNumRows();
        for (int row = 0; row < sampleSize; ++row) {
            double[] exoData = new double[dim];
            for (int i = 0; i < dim; ++i) {
                exoData[i] = RandomUtil.getInstance().nextNormal(0.0, 1.0);
            }
            // The Cholesky factor is lower triangular, so only j <= i contributes.
            double[] point = new double[dim];
            for (int i = 0; i < dim; ++i) {
                double sum = 0.0;
                for (int j = 0; j <= i; ++j) {
                    sum += cholesky.get(i, j) * exoData[j];
                }
                point[i] = sum;
            }
            for (int col = 0; col < numVars; ++col) {
                double value = point[col];
                if (Double.isNaN(value) || Double.isInfinite(value)) {
                    System.out.println("Value out of range: " + value);
                }
                dataSet.setDouble(row, col, value);
            }
        }
        return dataSet;
    }

    /**
     * Ranks each value as the number of elements of {@code x} that are {@code <=} it.
     *
     * <ul>
     *   <li>Tied values share the highest rank of the tie.</li>
     *   <li>NaN values get rank 0 and are never counted for other values.</li>
     * </ul>
     *
     * @param x the values to rank.
     * @return the ranks, aligned with {@code x}.
     */
    public static double[] ranks(double[] x) {
        int numRows = x.length;
        double[] ranks = new double[numRows];
        for (int i = 0; i < numRows; ++i) {
            double d = x[i];
            int count = 0;
            for (double v : x) {
                if (v <= d) {
                    ++count;
                }
            }
            ranks[i] = count;
        }
        return ranks;
    }

    /**
     * Searches variable subsets of size at least 2, as produced by
     * {@code SublistGenerator(numVariables, depth)}, and returns the first subset whose covariance
     * sub-matrix is <em>singular</em>. Despite the method name, that is the current behavior.
     *
     * @param covarianceMatrix the covariance matrix to search.
     * @param depth            passed to {@link SublistGenerator} to bound the subsets generated.
     * @return the first subset found with a singular sub-matrix, or {@code null} if none is found.
     */
    public static List<Node> getExampleNonsingular(ICovarianceMatrix covarianceMatrix, int depth) {
        List<Node> variables = covarianceMatrix.getVariables();
        SublistGenerator generator = new SublistGenerator(variables.size(), depth);
        int[] choice;
        while ((choice = generator.next()) != null) {
            if (choice.length < 2) {
                continue;
            }
            List<Node> _choice = GraphUtils.asList(choice, variables);
            ArrayList<String> names = new ArrayList<String>();
            for (Node node : _choice) {
                names.add(node.getName());
            }
            ICovarianceMatrix _dataSet = covarianceMatrix.getSubmatrix(names);
            if (new CovarianceMatrix(_dataSet).isSingular()) {
                return _choice;
            }
        }
        return null;
    }

    /**
     * Computes an effective sample size n / (1 + (n - 1) * rho), where:
     *
     * <ul>
     *   <li>n is the sample size;</li>
     *   <li>m is the number of variables;</li>
     *   <li>rho = (sum of all correlation-matrix entries - m) / (m * (n - 1)).</li>
     * </ul>
     *
     * <p>NOTE(review): rho divides by m(n - 1), not m(m - 1), which would give the usual mean
     * off-diagonal correlation. Confirm this is intended.
     *
     * @param covariances the covariance matrix, with its sample size.
     * @return the effective sample size.
     */
    public static double getEss(ICovarianceMatrix covariances) {
        Matrix C = new CorrelationMatrix(covariances).getMatrix();
        double m = covariances.getSize();
        double n = covariances.getSampleSize();
        double sum = 0.0;
        for (int i = 0; i < C.getNumRows(); ++i) {
            for (int j = 0; j < C.getNumColumns(); ++j) {
                sum += C.get(i, j);
            }
        }
        double rho = (n * sum - n * m) / (m * (n * n - n));
        return n / (1.0 + (n - 1.0) * rho);
    }
}

