/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.data;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.data.AndersonDarlingTest;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousDiscretizationSpec;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Discretizer;
import edu.cmu.tetrad.data.Variable;
import edu.cmu.tetrad.data.VariableSource;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import java.rmi.MarshalledObject;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;

public final class DataUtils {
    /**
     * Copies the values of one variable's column from {@code source} into the
     * column for the same variable in {@code dest}. Only rows that exist in
     * both data sets are copied; any extra rows are left untouched.
     *
     * @param node   the variable whose column is copied; must be a
     *               {@link ContinuousVariable} or a {@link DiscreteVariable}
     * @param source the data set the values are read from
     * @param dest   the data set the values are written to
     * @throws NullPointerException     if either data set reports a negative
     *                                  column index for {@code node}
     * @throws IllegalArgumentException if {@code node} is neither continuous
     *                                  nor discrete
     */
    public static void copyColumn(Node node, DataSet source, DataSet dest) {
        final int from = source.getColumn(node);
        final int to = dest.getColumn(node);
        if (from < 0) {
            throw new NullPointerException("The given node was not in the source dataset");
        }
        if (to < 0) {
            throw new NullPointerException("The given node was not in the destination dataset");
        }

        final int sourceRows = source.getNumRows();
        final int destRows = dest.getNumRows();
        // Only the overlapping prefix of rows can be copied.
        final int sharedRows = Math.min(sourceRows, destRows);

        final boolean continuous = node instanceof ContinuousVariable;
        if (!continuous && !(node instanceof DiscreteVariable)) {
            throw new IllegalArgumentException("The given variable most be discrete or continuous");
        }

        for (int row = 0; row < sharedRows; ++row) {
            if (continuous) {
                dest.setDouble(row, to, source.getDouble(row, from));
            } else {
                dest.setInt(row, to, source.getInt(row, from));
            }
        }
    }

    /**
     * Tests whether every value stored in a column is either 0 or 1.
     *
     * @param data   the data set to inspect
     * @param column index of the column to inspect
     * @return {@code true} if all rows hold 0 or 1 (also when there are no
     *         rows), {@code false} as soon as any other value is found
     * @throws IllegalArgumentException if the column's variable is neither a
     *                                  {@link DiscreteVariable} nor a
     *                                  {@link ContinuousVariable}
     */
    public static boolean isBinary(DataSet data, int column) {
        final Node variable = data.getVariable(column);
        final int rows = data.getNumRows();

        if (variable instanceof DiscreteVariable) {
            for (int row = 0; row < rows; ++row) {
                final int cell = data.getInt(row, column);
                if (cell != 0 && cell != 1) {
                    return false;
                }
            }
            return true;
        }

        if (variable instanceof ContinuousVariable) {
            for (int row = 0; row < rows; ++row) {
                final double cell = data.getDouble(row, column);
                // NaN fails both comparisons and is therefore reported as non-binary.
                if (cell != 0.0 && cell != 1.0) {
                    return false;
                }
            }
            return true;
        }

        throw new IllegalArgumentException("The given column is not discrete or continuous");
    }

    /**
     * Verifies that every variable of {@code source1} also occurs in
     * {@code source2}.
     *
     * <p>The check runs on a private copy of {@code source1}'s variable list,
     * so the list returned by {@code source1.getVariables()} is never modified
     * (and an unmodifiable list does not make the check fail).
     *
     * @param source1 the source whose variables are required
     * @param source2 the source that is expected to contain those variables
     * @throws IllegalArgumentException if at least one variable of
     *                                  {@code source1} is not found in
     *                                  {@code source2}; the message lists the
     *                                  missing variables
     */
    public static void ensureVariablesExist(VariableSource source1, VariableSource source2) {
        // Copy first: removeAll on the returned list could otherwise alter
        // source1's own state or throw UnsupportedOperationException.
        List<Node> variablesNotFound = new ArrayList<Node>(source1.getVariables());
        variablesNotFound.removeAll(source2.getVariables());
        if (!variablesNotFound.isEmpty()) {
            throw new IllegalArgumentException("Expected to find these variables from the given Bayes PM \nin the given discrete data set, but didn't (note: \ncategories might be different or in the wrong order): \n" + variablesNotFound);
        }
    }

    /**
     * Returns the default name of the category at the given index, which is
     * simply the decimal representation of the index.
     *
     * @param index the category index
     * @return the index rendered as a decimal string
     */
    public static String defaultCategory(int index) {
        return String.valueOf(index);
    }

    /**
     * Returns a copy of {@code inData} in which cells have been randomly
     * replaced by their variable's missing-value marker.
     *
     * <p>The copy is produced by a serialization round trip through
     * {@link MarshalledObject}; the input data set itself is not written to
     * by this method.
     *
     * @param inData the data set to copy
     * @param probs  one entry per column; a cell in column {@code j} is marked
     *               missing whenever the random draw for that cell is below
     *               {@code probs[j]}
     * @return the copy with missing values inserted
     * @throws RuntimeException         wrapping any exception raised while
     *                                  copying the data set
     * @throws IllegalArgumentException if {@code probs} does not have one entry
     *                                  per column, or an entry is below 0 or
     *                                  above 1
     */
    public static DataSet addMissingData(DataSet inData, double[] probs) {
        DataSet copy;
        try {
            MarshalledObject<DataSet> marshalled = new MarshalledObject<DataSet>(inData);
            copy = marshalled.get();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }

        if (copy.getNumColumns() != probs.length) {
            throw new IllegalArgumentException("Wrong number of elements in prob array");
        }
        for (int k = 0; k < probs.length; ++k) {
            if (probs[k] < 0.0 || probs[k] > 1.0) {
                throw new IllegalArgumentException("Probability out of range");
            }
        }

        for (int col = 0; col < copy.getNumColumns(); ++col) {
            final Node columnVariable = copy.getVariable(col);
            for (int row = 0; row < copy.getNumRows(); ++row) {
                // presumably a uniform draw in [0, 1) -- confirm against RandomUtil
                final double draw = RandomUtil.getInstance().nextDouble();
                if (draw < probs[col]) {
                    copy.setObject(row, col, ((Variable) columnVariable).getMissingValueMarker());
                }
            }
        }
        return copy;
    }

    /**
     * Builds a small continuous data set: a single {@link ContinuousVariable}
     * named "X" with 10 rows, each cell filled from
     * {@code RandomUtil.getInstance().nextDouble()}.
     *
     * @return the freshly created data set
     */
    public static DataSet continuousSerializableInstance() {
        List<Node> nodes = new LinkedList<Node>();
        nodes.add(new ContinuousVariable("X"));

        ColtDataSet sample = new ColtDataSet(10, nodes);
        for (int row = 0; row < sample.getNumRows(); ++row) {
            for (int col = 0; col < sample.getNumColumns(); ++col) {
                final double draw = RandomUtil.getInstance().nextDouble();
                sample.setDouble(row, col, draw);
            }
        }
        return sample;
    }

    /**
     * Builds a small discrete data set: a single two-category
     * {@link DiscreteVariable} named "X" with two rows holding the values
     * 0 and 1, in that order.
     *
     * @return the freshly created data set
     */
    public static DataSet discreteSerializableInstance() {
        List<Node> nodes = new LinkedList<Node>();
        nodes.add(new DiscreteVariable("X", 2));

        ColtDataSet sample = new ColtDataSet(2, nodes);
        // Row r of the only column receives the value r.
        for (int row = 0; row < 2; ++row) {
            sample.setInt(row, 0, row);
        }
        return sample;
    }

    /**
     * Tells whether a matrix holds at least one {@code NaN} entry, which this
     * method treats as a missing value.
     *
     * @param data the matrix to scan
     * @return {@code true} if any entry is {@code NaN}, {@code false} otherwise
     */
    public static boolean containsMissingValue(DoubleMatrix2D data) {
        final int rowCount = data.rows();
        final int columnCount = data.columns();
        for (int row = 0; row < rowCount; ++row) {
            for (int col = 0; col < columnCount; ++col) {
                final double entry = data.getQuick(row, col);
                if (Double.isNaN(entry)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Tells whether a data set holds at least one missing value. A cell of a
     * {@link ContinuousVariable} column counts as missing when it is
     * {@code NaN}; a cell of a {@link DiscreteVariable} column counts as
     * missing when it reads as {@code -99.0}. Columns of any other variable
     * type are not examined.
     *
     * @param data the data set to scan
     * @return {@code true} if a missing value was found, {@code false} otherwise
     */
    public static boolean containsMissingValue(DataSet data) {
        for (int col = 0; col < data.getNumColumns(); ++col) {
            final Node variable = data.getVariable(col);

            if (variable instanceof ContinuousVariable) {
                for (int row = 0; row < data.getNumRows(); ++row) {
                    if (Double.isNaN(data.getDouble(row, col))) {
                        return true;
                    }
                }
            }

            if (variable instanceof DiscreteVariable) {
                for (int row = 0; row < data.getNumRows(); ++row) {
                    // -99 is the sentinel this method uses for a missing discrete value.
                    if (data.getDouble(row, col) == -99.0) {
                        return true;
                    }
                }
            }
        }
        return false;
    }

    /**
     * Returns a new matrix in which every column of {@code data} has been
     * centered on its mean and divided by its sample standard deviation
     * (computed with an {@code n - 1} denominator). The input matrix is not
     * modified.
     *
     * <p>No special handling exists for constant columns or for matrices with
     * fewer than two rows; the divisions are carried out as-is.
     *
     * @param data the matrix whose columns are standardized
     * @return a matrix created with {@code data.like()} holding the
     *         standardized values
     */
    public static DoubleMatrix2D standardizeData(DoubleMatrix2D data) {
        final DoubleMatrix2D standardized = data.like();
        final int rowCount = data.rows();
        final int columnCount = data.columns();

        for (int col = 0; col < columnCount; ++col) {
            // Pass 1: column mean.
            double total = 0.0;
            for (int row = 0; row < rowCount; ++row) {
                total += data.get(row, col);
            }
            final double mean = total / (double) rowCount;

            // Pass 2: center the column and accumulate the squared deviations.
            double squares = 0.0;
            for (int row = 0; row < rowCount; ++row) {
                standardized.set(row, col, data.get(row, col) - mean);
                final double centered = standardized.get(row, col);
                squares += centered * centered;
            }
            final double deviation = Math.sqrt(squares / (double) (rowCount - 1));

            // Pass 3: scale to unit sample standard deviation.
            for (int row = 0; row < rowCount; ++row) {
                standardized.set(row, col, standardized.get(row, col) / deviation);
            }
        }
        return standardized;
    }

    /**
     * Discretizes a data set by registering an {@code equalCounts} request
     * with {@code numCategories} categories for every variable on a
     * {@link Discretizer} and returning that discretizer's result.
     *
     * @param dataSet         the data set to discretize
     * @param numCategories   number of categories requested for each variable
     * @param variablesCopied forwarded to
     *                        {@code Discretizer.setVariablesCopied}
     * @return the data set produced by {@code Discretizer.discretize()}
     */
    public static DataSet discretize(DataSet dataSet, int numCategories, boolean variablesCopied) {
        final Discretizer worker = new Discretizer(dataSet);
        worker.setVariablesCopied(variablesCopied);

        final List<Node> allVariables = dataSet.getVariables();
        for (Node variable : allVariables) {
            worker.equalCounts(variable, numCategories);
        }

        final DataSet discretized = worker.discretize();
        return discretized;
    }

    /**
     * Creates a discretization spec from the break points returned by
     * {@code Discretizer.getEqualFrequencyBreakPoints} and the default
     * category names "0" .. "numCategories - 1".
     *
     * @param numCategories number of categories to create
     * @param data          the values from which the break points are computed
     * @return the resulting spec
     */
    public static ContinuousDiscretizationSpec getEqualFreqDiscretizationSpec(int numCategories, double[] data) {
        // Arguments are evaluated left to right: break points first, then names.
        return new ContinuousDiscretizationSpec(
                Discretizer.getEqualFrequencyBreakPoints(data, numCategories),
                DataUtils.defaultCategories(numCategories));
    }

    /**
     * Returns the default category names for {@code numCategories}
     * categories: the names of the indices 0 .. {@code numCategories - 1},
     * in ascending order.
     *
     * @param numCategories how many names to produce; values of zero or less
     *                      yield an empty list
     * @return a new mutable list of category names
     */
    public static List<String> defaultCategories(int numCategories) {
        List<String> names = new LinkedList<String>();
        int index = 0;
        while (index < numCategories) {
            names.add(DataUtils.defaultCategory(index));
            ++index;
        }
        return names;
    }

    /**
     * Creates one {@link ContinuousVariable} per supplied name, preserving
     * the order of the names.
     *
     * @param varNames the names of the variables to create
     * @return a new list holding the created variables
     */
    public static List<Node> createContinuousVariables(String[] varNames) {
        List<Node> created = new LinkedList<Node>();
        for (int k = 0; k < varNames.length; ++k) {
            created.add(new ContinuousVariable(varNames[k]));
        }
        return created;
    }

    /**
     * Selects the rows and columns of a covariance matrix that belong to
     * {@code x}, {@code y} and the nodes in {@code z}, in that order.
     *
     * @param m the covariance matrix; supplies both the variable list and
     *          the matrix values
     * @param x the first node
     * @param y the second node
     * @param z further nodes to include; may be empty but not contain
     *          {@code null}
     * @return the result of {@code viewSelection} on the covariance matrix
     *         for the selected indices
     * @throws NullPointerException     if {@code x}, {@code y}, {@code z} or
     *                                  an element of {@code z} is {@code null}
     * @throws IllegalArgumentException if the selection contains a
     *                                  {@code NaN} entry
     */
    public static DoubleMatrix2D subMatrix(CovarianceMatrix m, Node x, Node y, List<Node> z) {
        if (x == null || y == null || z == null) {
            throw new NullPointerException();
        }
        for (int k = 0; k < z.size(); ++k) {
            if (z.get(k) == null) {
                throw new NullPointerException();
            }
        }

        final List<Node> variables = m.getVariables();
        final DoubleMatrix2D full = m.getMatrix();

        // Positions of x, y, z[0], z[1], ... within the matrix's variable list.
        final int[] selection = new int[z.size() + 2];
        int slot = 0;
        selection[slot++] = variables.indexOf(x);
        selection[slot++] = variables.indexOf(y);
        for (Node conditioning : z) {
            selection[slot++] = variables.indexOf(conditioning);
        }

        final DoubleMatrix2D selected = full.viewSelection(selection, selection);
        if (DataUtils.containsMissingValue(selected)) {
            throw new IllegalArgumentException("Please remove or impute missing values first.");
        }
        return selected;
    }

    /**
     * Selects the rows and columns of {@code m} that belong to {@code x},
     * {@code y} and the nodes in {@code z}, in that order. The position of a
     * node in {@code variables} is used as its row/column index in {@code m}.
     *
     * @param m         the matrix to select from
     * @param variables the nodes that index the rows and columns of {@code m}
     * @param x         the first node
     * @param y         the second node
     * @param z         further nodes to include; may be empty but not contain
     *                  {@code null}
     * @return the result of {@code m.viewSelection} for the selected indices
     * @throws NullPointerException     if {@code x}, {@code y}, {@code z} or
     *                                  an element of {@code z} is {@code null}
     * @throws IllegalArgumentException if the selection contains a
     *                                  {@code NaN} entry
     */
    public static DoubleMatrix2D subMatrix(DoubleMatrix2D m, List<Node> variables, Node x, Node y, List<Node> z) {
        if (x == null || y == null || z == null) {
            throw new NullPointerException();
        }
        for (int k = 0; k < z.size(); ++k) {
            if (z.get(k) == null) {
                throw new NullPointerException();
            }
        }

        // Positions of x, y, z[0], z[1], ... within the supplied variable list.
        final int[] selection = new int[z.size() + 2];
        int slot = 0;
        selection[slot++] = variables.indexOf(x);
        selection[slot++] = variables.indexOf(y);
        for (Node conditioning : z) {
            selection[slot++] = variables.indexOf(conditioning);
        }

        final DoubleMatrix2D selected = m.viewSelection(selection, selection);
        if (DataUtils.containsMissingValue(selected)) {
            throw new IllegalArgumentException("Please remove or impute missing values first.");
        }
        return selected;
    }

    /**
     * Returns the column subset of {@code dataSet} obtained by asking for
     * all of its columns in a random order. The order is produced by
     * {@link Collections#shuffle(List)}.
     *
     * @param dataSet the data set whose columns are permuted
     * @return the result of {@code dataSet.subsetColumns} for the shuffled
     *         column indices
     */
    public static DataSet shuffleColumns(DataSet dataSet) {
        final int columnCount = dataSet.getNumColumns();

        List<Integer> order = new ArrayList<Integer>();
        for (int col = 0; col < columnCount; ++col) {
            order.add(col);
        }
        Collections.shuffle(order);

        // Unbox the shuffled order into the int[] that subsetColumns expects.
        final int[] permutation = new int[columnCount];
        int next = 0;
        for (Integer col : order) {
            permutation[next++] = col;
        }
        return dataSet.subsetColumns(permutation);
    }

    /**
     * Converts a data set to an all-continuous data set. Continuous columns
     * are copied as they are; every other column is treated as a
     * {@link DiscreteVariable} column whose category names are parsed as
     * numbers.
     *
     * <p>A discrete cell holding the missing-value sentinel {@code -99} is
     * written as {@code Double.NaN} instead of being looked up as a category.
     * These are the same two markers that
     * {@code containsMissingValue(DataSet)} tests for.
     *
     * @param dataSet the data set to convert
     * @return a new data set with the same number of rows and columns in
     *         which every variable is continuous
     * @throws NumberFormatException if the category name of a non-missing
     *                               discrete value cannot be parsed as a
     *                               double
     */
    public static DataSet convertNumericalDiscreteToContinuous(DataSet dataSet) throws NumberFormatException {
        // Sentinel for a missing discrete value.
        // NOTE(review): assumes getInt() reports missing cells as -99, matching the
        // -99.0 check in containsMissingValue(DataSet) -- confirm against DiscreteVariable.
        final int missingDiscreteValue = -99;

        ArrayList<Node> variables = new ArrayList<Node>();
        for (Node variable : dataSet.getVariables()) {
            if (variable instanceof ContinuousVariable) {
                variables.add(variable);
                continue;
            }
            // Non-continuous variables are replaced by a continuous one of the same name.
            variables.add(new ContinuousVariable(variable.getName()));
        }
        ColtDataSet continuousData = new ColtDataSet(dataSet.getNumRows(), variables);
        for (int j = 0; j < dataSet.getNumColumns(); ++j) {
            Node variable = dataSet.getVariable(j);
            if (variable instanceof ContinuousVariable) {
                for (int i = 0; i < dataSet.getNumRows(); ++i) {
                    continuousData.setDouble(i, j, dataSet.getDouble(i, j));
                }
                continue;
            }
            DiscreteVariable discreteVariable = (DiscreteVariable)variable;
            for (int i = 0; i < dataSet.getNumRows(); ++i) {
                int index = dataSet.getInt(i, j);
                if (index == missingDiscreteValue) {
                    // A missing cell has no category name to parse; keep it missing.
                    continuousData.setDouble(i, j, Double.NaN);
                    continue;
                }
                String catName = discreteVariable.getCategory(index);
                double value = Double.parseDouble(catName);
                continuousData.setDouble(i, j, value);
            }
        }
        return continuousData;
    }

    /**
     * Stacks the rows of {@code dataSet2} below the rows of {@code dataSet1}.
     * Columns follow the variable order of {@code dataSet1}; the matching
     * column of {@code dataSet2} is located by variable name. Variables that
     * occur only in {@code dataSet2} are ignored.
     *
     * @param dataSet1 the data set supplying the first rows and the variables
     *                 of the result
     * @param dataSet2 the data set supplying the remaining rows
     * @return a new data set built with {@code ColtDataSet.makeData} from
     *         the variables of {@code dataSet1} and the stacked values
     * @throws IllegalArgumentException if a variable of {@code dataSet1} has
     *                                  no variable of the same name in
     *                                  {@code dataSet2}
     */
    public static DataSet concatenateData(DataSet dataSet1, DataSet dataSet2) {
        List<Node> vars1 = dataSet1.getVariables();
        List<Node> vars2 = dataSet2.getVariables();

        // Variable name -> column index in dataSet2.
        HashMap<String, Integer> varMap2 = new HashMap<String, Integer>();
        for (int i = 0; i < vars2.size(); ++i) {
            varMap2.put(vars2.get(i).getName(), i);
        }

        int rows1 = dataSet1.getNumRows();
        int rows2 = dataSet2.getNumRows();
        int cols1 = dataSet1.getNumColumns();
        DenseDoubleMatrix2D concatMatrix = new DenseDoubleMatrix2D(rows1 + rows2, cols1);
        DoubleMatrix2D matrix1 = dataSet1.getDoubleData();
        DoubleMatrix2D matrix2 = dataSet2.getDoubleData();

        for (int i = 0; i < vars1.size(); ++i) {
            String name = vars1.get(i).getName();
            Integer var2 = varMap2.get(name);
            if (var2 == null) {
                // Fail with a clear message instead of a bare NullPointerException on unboxing.
                throw new IllegalArgumentException("Variable " + name
                        + " of the first data set was not found in the second data set");
            }
            int column2 = var2;
            for (int j = 0; j < rows1; ++j) {
                concatMatrix.set(j, i, matrix1.get(j, i));
            }
            for (int j = 0; j < rows2; ++j) {
                concatMatrix.set(j + rows1, i, matrix2.get(j, column2));
            }
        }
        return ColtDataSet.makeData(vars1, concatMatrix);
    }

    /**
     * Stacks the rows of all given data sets, in list order, into one data
     * set. The variables of the first data set are used for the result, and
     * column {@code j} of every data set is copied into column {@code j} of
     * the result; no check is made that the data sets actually share the same
     * variables or column order.
     *
     * @param datasets the data sets to stack; must contain at least one
     *                 element
     * @return a new data set built with {@code ColtDataSet.makeData} from
     *         the first data set's variables and the stacked values
     * @throws IllegalArgumentException if {@code datasets} is empty
     */
    public static DataSet concatenateDataNoChecks(List<DataSet> datasets) {
        if (datasets.isEmpty()) {
            throw new IllegalArgumentException("At least one data set is required for concatenation");
        }
        // Take the variables from the FIRST data set so that a single-element list works too.
        List<Node> vars1 = datasets.get(0).getVariables();
        int cols = vars1.size();

        int rows = 0;
        for (DataSet dataset : datasets) {
            rows += dataset.getNumRows();
        }

        DenseDoubleMatrix2D concatMatrix = new DenseDoubleMatrix2D(rows, cols);
        int index = 0; // next free row of the result
        for (DataSet dataset : datasets) {
            for (int i = 0; i < dataset.getNumRows(); ++i) {
                for (int j = 0; j < cols; ++j) {
                    concatMatrix.set(index, j, dataset.getDouble(i, j));
                }
                ++index;
            }
        }
        return ColtDataSet.makeData(vars1, concatMatrix);
    }

    /**
     * Stacks the rows of {@code dataSet2} below the rows of {@code dataSet1},
     * reading and writing every cell as an int. The result uses the variables
     * of {@code dataSet1}; columns are located in each data set by looking up
     * the variable's {@code toString()} value.
     *
     * @param dataSet1 the data set supplying the first rows and the variables
     *                 of the result
     * @param dataSet2 the data set supplying the remaining rows
     * @return a new data set holding the rows of both inputs
     */
    public static DataSet concatenateDiscreteData(DataSet dataSet1, DataSet dataSet2) {
        final List<Node> variables = dataSet1.getVariables();
        final int firstRows = dataSet1.getNumRows();
        final int secondRows = dataSet2.getNumRows();
        final ColtDataSet combined = new ColtDataSet(firstRows + secondRows, variables);

        for (Node variable : variables) {
            final int firstColumn = dataSet1.getColumn(dataSet1.getVariable(variable.toString()));
            final int targetColumn = combined.getColumn(combined.getVariable(variable.toString()));
            for (int row = 0; row < firstRows; ++row) {
                combined.setInt(row, targetColumn, dataSet1.getInt(row, firstColumn));
            }

            final int secondColumn = dataSet2.getColumn(dataSet2.getVariable(variable.toString()));
            for (int row = 0; row < secondRows; ++row) {
                // Rows of the second data set start right after those of the first.
                combined.setInt(firstRows + row, targetColumn, dataSet2.getInt(row, secondColumn));
            }
        }
        return combined;
    }

    /**
     * Returns a copy of the given data set in which every column consisting
     * solely of zeroes has been overwritten with values drawn from
     * {@code RandomUtil.getInstance().nextNormal(0.0, 1.0)}. Columns holding
     * at least one non-zero value are left as they are.
     *
     * @param dataSet the data set to copy; it is cast to {@link ColtDataSet}
     *                for the copy constructor
     * @return the modified copy
     */
    public static DataSet noisyZeroes(DataSet dataSet) {
        DataSet noisy = new ColtDataSet((ColtDataSet) dataSet);

        for (int col = 0; col < noisy.getNumColumns(); ++col) {
            // Scan until the first value that does not compare equal to zero.
            boolean sawNonZero = false;
            for (int row = 0; row < noisy.getNumRows() && !sawNonZero; ++row) {
                sawNonZero = noisy.getDouble(row, col) != 0.0;
            }
            if (sawNonZero) {
                continue;
            }

            for (int row = 0; row < noisy.getNumRows(); ++row) {
                noisy.setDouble(row, col, RandomUtil.getInstance().nextNormal(0.0, 1.0));
            }
        }
        return noisy;
    }

    /**
     * Prints to standard output, for every column of the data set, the p
     * value reported by {@link AndersonDarlingTest} together with a label:
     * "Gaussian" when p exceeds 0.05 and "Nongaussian" otherwise.
     *
     * @param dataSet the data set whose columns are tested
     */
    public static void printAndersonDarlingPs(DataSet dataSet) {
        System.out.println("Anderson Darling P value for Variables\n");

        final DecimalFormat pFormat = new DecimalFormat("0.0000");
        final DoubleMatrix2D values = dataSet.getDoubleData();

        for (int col = 0; col < dataSet.getNumColumns(); ++col) {
            final double[] column = values.viewColumn(col).toArray();
            final double p = new AndersonDarlingTest(column).getP();

            final String verdict;
            if (p > 0.05) {
                verdict = " = Gaussian";
            } else {
                verdict = " = Nongaussian";
            }

            StringBuilder line = new StringBuilder();
            line.append("For ").append(dataSet.getVariable(col));
            line.append(", Anderson-Darling p = ").append(pFormat.format(p));
            line.append(verdict);
            System.out.println(line.toString());
        }
    }

    /**
     * Returns the column subset of the data set made up of the variables
     * whose node type is {@code NodeType.MEASURED}, in their original order.
     *
     * @param fullDataSet the data set to restrict
     * @return the result of {@code fullDataSet.subsetColumns} for the
     *         measured variables
     */
    public static DataSet restrictToMeasured(DataSet fullDataSet) {
        List<Node> measured = new ArrayList<Node>();
        for (Node candidate : fullDataSet.getVariables()) {
            if (candidate.getNodeType() == NodeType.MEASURED) {
                measured.add(candidate);
            }
        }
        return fullDataSet.subsetColumns(measured);
    }

    /**
     * Builds the square matrix whose entry (i, j) is
     * {@code StatUtils.covariance} applied to columns i and j of
     * {@code data}.
     *
     * @param data the matrix whose columns are compared pairwise
     * @return a new dense matrix with one row and one column per column of
     *         {@code data}
     */
    public static DoubleMatrix2D cov(DoubleMatrix2D data) {
        final int size = data.columns();
        final DoubleMatrix2D result = new DenseDoubleMatrix2D(size, size);
        for (int first = 0; first < size; ++first) {
            for (int second = 0; second < size; ++second) {
                final double[] left = data.viewColumn(first).toArray();
                final double[] right = data.viewColumn(second).toArray();
                result.set(first, second, StatUtils.covariance(left, right));
            }
        }
        return result;
    }

    /**
     * Builds the square matrix whose entry (i, j) is
     * {@code StatUtils.correlation} applied to columns i and j of
     * {@code data}.
     *
     * @param data the matrix whose columns are compared pairwise
     * @return a new dense matrix with one row and one column per column of
     *         {@code data}
     */
    public static DoubleMatrix2D corr(DoubleMatrix2D data) {
        final int size = data.columns();
        final DoubleMatrix2D result = new DenseDoubleMatrix2D(size, size);
        for (int first = 0; first < size; ++first) {
            for (int second = 0; second < size; ++second) {
                final double[] left = data.viewColumn(first).toArray();
                final double[] right = data.viewColumn(second).toArray();
                result.set(first, second, StatUtils.correlation(left, right));
            }
        }
        return result;
    }

    /**
     * Builds the vector whose entry i is {@code StatUtils.mean} applied to
     * column i of {@code data}.
     *
     * @param data the matrix whose columns are summarized
     * @return a new dense vector with one entry per column of {@code data}
     */
    public static DoubleMatrix1D mean(DoubleMatrix2D data) {
        final int size = data.columns();
        final DoubleMatrix1D result = new DenseDoubleMatrix1D(size);
        for (int col = 0; col < size; ++col) {
            final double[] column = data.viewColumn(col).toArray();
            result.set(col, StatUtils.mean(column));
        }
        return result;
    }

    /**
     * Simulates {@code sampleSize} rows for the variables of a covariance
     * matrix. For every row a vector of independent draws from
     * {@code RandomUtil.getInstance().nextNormal(0.0, 1.0)} is multiplied by
     * the lower triangle (entries with column index not exceeding the row
     * index) of the matrix returned by {@code MatrixUtils.choleskyC}.
     *
     * <p>Both the covariance matrix and the matrix returned by
     * {@code choleskyC} are printed to standard output.
     *
     * @param cov        the covariance matrix to simulate from
     * @param sampleSize number of rows to generate
     * @return a new data set over the variables of {@code cov}
     * @throws IllegalArgumentException if a generated value is {@code NaN}
     *                                  or infinite
     */
    public static DataSet choleskySimulation(CovarianceMatrix cov, int sampleSize) {
        // NOTE(review): looks like leftover debug output -- confirm before removing.
        System.out.println(cov);

        final List<Node> nodes = cov.getVariables();
        final ColtDataSet simulated = new ColtDataSet(sampleSize, nodes);
        final DoubleMatrix2D factor = MatrixUtils.choleskyC(cov.getMatrix().copy());
        System.out.println(factor);

        for (int sample = 0; sample < sampleSize; ++sample) {
            // Independent noise terms, one per row of the factor.
            final double[] noise = new double[factor.rows()];
            for (int k = 0; k < noise.length; ++k) {
                noise[k] = RandomUtil.getInstance().nextNormal(0.0, 1.0);
            }

            // Lower-triangular product: mixed[r] = sum over c <= r of factor[r][c] * noise[c].
            final double[] mixed = new double[noise.length];
            for (int r = 0; r < noise.length; ++r) {
                double acc = 0.0;
                for (int c = 0; c <= r; ++c) {
                    acc += factor.get(r, c) * noise[c];
                }
                mixed[r] = acc;
            }

            for (int col = 0; col < nodes.size(); ++col) {
                // NOTE(review): position of the first node equal to nodes.get(col);
                // presumably always col itself -- verify before simplifying.
                final int position = nodes.indexOf(nodes.get(col));
                final double cell = mixed[position];
                if (Double.isNaN(cell) || Double.isInfinite(cell)) {
                    throw new IllegalArgumentException("Value out of range: " + cell);
                }
                simulated.setDouble(sample, col, cell);
            }
        }
        return simulated;
    }
}

