/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.data;

import cern.colt.matrix.DoubleMatrix2D;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataBox;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.ForkJoinPoolInstance;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.Vector;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.concurrent.RecursiveTask;
import org.apache.commons.math3.util.FastMath;

/**
 * A covariance matrix over a continuous data set whose off-diagonal entries are computed on
 * demand from demeaned per-variable column vectors, instead of being stored as a full p x p
 * matrix. Only the diagonal (the variances) is precomputed, in parallel, at construction time.
 *
 * <p>Each entry sums products over the rows where neither value is NaN and divides by
 * (number of such rows - 1).
 */
public class CovarianceMatrixOnTheFly implements ICovarianceMatrix {
    static final long serialVersionUID = 23L;
    // Progress-output flag; set from the constructor argument and exposed via isVerbose().
    private boolean verbose = false;
    private String name;
    private List<Node> variables;
    private int sampleSize;
    // Only set through setMatrix()/readObject(); getValue() reads `vectors`, not this field.
    private Matrix matrix;
    // If present after deserialization, converted into `matrix` and cleared in readObject().
    private DoubleMatrix2D matrixC;
    private Set<Node> selectedVariables = new HashSet<Node>();
    private Knowledge knowledge = new Knowledge();
    // vectors[i] holds the demeaned values of variable i, one entry per data row.
    private double[][] vectors = null;
    // Cached diagonal: variances[i] is what getValue(i, i) returns.
    private final double[] variances;

    /**
     * Creates an on-the-fly covariance matrix for the given data set without progress output.
     *
     * @param dataSet a continuous data set.
     * @throws IllegalArgumentException if the data set is not continuous.
     */
    public CovarianceMatrixOnTheFly(DataSet dataSet) {
        this(dataSet, false);
    }

    /**
     * Creates an on-the-fly covariance matrix for the given data set.
     *
     * @param dataSet a continuous data set; it is not modified.
     * @param verbose whether to print progress messages to standard out.
     * @throws IllegalArgumentException if the data set is not continuous.
     */
    public CovarianceMatrixOnTheFly(DataSet dataSet, boolean verbose) {
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("Not a continuous data set.");
        }
        // Previously the argument was never stored, so isVerbose() always reported false.
        this.verbose = verbose;
        this.variables = Collections.unmodifiableList(dataSet.getVariables());
        this.sampleSize = dataSet.getNumRows();
        this.printIfVerbose("Calculating variable vectors");
        this.vectors = this.buildDemeanedVectors(dataSet);
        this.printIfVerbose("Calculating variances");
        this.variances = this.computeVariances();
        this.printIfVerbose("Done with variances.");
    }

    /**
     * Extracts one column vector per variable (indexed [variable][row]) and demeans it.
     * Always returns freshly allocated arrays so the data set itself is never altered.
     */
    private double[][] buildDemeanedVectors(DataSet dataSet) {
        if (dataSet instanceof BoxDataSet) {
            DataBox box = ((BoxDataSet) dataSet).getDataBox();
            double[][] fromBox = null;
            if (box instanceof VerticalDoubleDataBox) {
                this.printIfVerbose("Getting vectors from VerticalDoubleDataBox");
                // demean() works in place; copy first in case getVariableVectors() exposes the
                // box's backing arrays (NOTE(review): confirm against VerticalDoubleDataBox).
                fromBox = copyColumns(((VerticalDoubleDataBox) box).getVariableVectors());
            } else if (box instanceof DoubleDataBox) {
                this.printIfVerbose("Getting vectors from DoubleDataBox");
                double[][] horizData = ((DoubleDataBox) box).getData();
                this.printIfVerbose("Transposing data");
                fromBox = this.transpose(horizData);
            }
            if (fromBox != null) {
                this.printIfVerbose("Calculating means");
                demean(fromBox, DataUtils.means(fromBox));
                return fromBox;
            }
        }
        // Generic path for any other data set representation.
        this.printIfVerbose("Copying data");
        Matrix doubleData = dataSet.getDoubleData().copy();
        this.printIfVerbose("Calculating means");
        Vector means = DataUtils.means(doubleData);
        this.printIfVerbose("Getting vectors from data");
        double[][] columns = new double[this.variables.size()][];
        for (int i = 0; i < columns.length; ++i) {
            // Read from the copied data; the old code read the never-initialized `matrix` field.
            columns[i] = doubleData.getColumn(i).toArray();
        }
        this.printIfVerbose("Demeaning");
        // Demean the extracted columns; the old code demeaned `vectors` while it was still null.
        demean(columns, means);
        return columns;
    }

    /** Returns a deep copy of a jagged double array. */
    private static double[][] copyColumns(double[][] source) {
        double[][] copy = new double[source.length][];
        for (int i = 0; i < source.length; ++i) {
            copy[i] = Arrays.copyOf(source[i], source[i].length);
        }
        return copy;
    }

    /**
     * Transposes row-major data ([row][variable]) into column vectors ([variable][row]).
     * An empty row array yields one empty vector per variable instead of failing on row 0.
     */
    private double[][] transpose(double[][] rowMajor) {
        int numRows = rowMajor.length;
        int numCols = numRows == 0 ? this.variables.size() : rowMajor[0].length;
        double[][] columns = new double[numCols][numRows];
        for (int r = 0; r < numRows; ++r) {
            for (int c = 0; c < numCols; ++c) {
                columns[c][r] = rowMajor[r][c];
            }
        }
        return columns;
    }

    /** Computes every variable's variance in parallel on the shared fork/join pool. */
    private double[] computeVariances() {
        double[] result = new double[this.variables.size()];
        int numThreads = Runtime.getRuntime().availableProcessors() * 10;
        int minChunk = 100;
        int chunk = FastMath.max(this.variables.size() / numThreads + 1, minChunk);
        VarianceTask task = new VarianceTask(this.vectors, result, this.sampleSize, this.variables,
                chunk, 0, result.length);
        ForkJoinPoolInstance.getInstance().getPool().invoke(task);
        return result;
    }

    /** Prints a progress message when this instance is verbose. */
    private void printIfVerbose(String message) {
        if (this.verbose) {
            System.out.println(message);
        }
    }

    /**
     * Fork/join task that fills variances[from, to). A range no longer than {@code chunk} is
     * computed directly; otherwise it is split into four roughly equal subranges.
     */
    private static final class VarianceTask extends RecursiveTask<Boolean> {
        private static final long serialVersionUID = 1L;
        private static final int NUM_SPLITS = 4;
        private final double[][] vectors;
        private final double[] variances;
        private final int sampleSize;
        private final List<Node> variables;
        private final int chunk;
        private final int from;
        private final int to;

        VarianceTask(double[][] vectors, double[] variances, int sampleSize, List<Node> variables,
                     int chunk, int from, int to) {
            this.vectors = vectors;
            this.variances = variances;
            this.sampleSize = sampleSize;
            this.variables = variables;
            this.chunk = chunk;
            this.from = from;
            this.to = to;
        }

        @Override
        protected Boolean compute() {
            if (this.to - this.from <= this.chunk) {
                for (int i = this.from; i < this.to; ++i) {
                    double[] v1 = this.vectors[i];
                    double sumSquares = 0.0;
                    int count = 0;
                    for (int k = 0; k < this.sampleSize; ++k) {
                        if (Double.isNaN(v1[k])) {
                            continue;
                        }
                        sumSquares += v1[k] * v1[k];
                        ++count;
                    }
                    double variance = sumSquares / (count - 1);
                    this.variances[i] = variance;
                    if (variance == 0.0) {
                        System.out.println("Zero variance! " + this.variables.get(i));
                    }
                }
                return true;
            }
            int step = (this.to - this.from) / NUM_SPLITS + 1;
            List<VarianceTask> tasks = new ArrayList<VarianceTask>(NUM_SPLITS);
            for (int i = 0; i < NUM_SPLITS; ++i) {
                tasks.add(new VarianceTask(this.vectors, this.variances, this.sampleSize,
                        this.variables, this.chunk, this.from + i * step,
                        FastMath.min(this.from + (i + 1) * step, this.to)));
            }
            invokeAll(tasks);
            return true;
        }
    }

    /**
     * Subtracts {@code means.get(j)} from every entry of {@code data[j]}, in place.
     *
     * @param data  arrays to demean; modified in place.
     * @param means one mean per array in {@code data}.
     */
    public static void demean(double[][] data, Vector means) {
        for (int j = 0; j < data.length; ++j) {
            double[] values = data[j];
            for (int i = 0; i < values.length; ++i) {
                values[i] -= means.get(j);
            }
        }
    }

    /**
     * Returns a small example instance (a 1 x 1 identity {@link CovarianceMatrix} over variable
     * "X" with sample size 100) for serialization tests.
     */
    public static ICovarianceMatrix serializableInstance() {
        List<Node> variables = new ArrayList<Node>();
        variables.add(new ContinuousVariable("X"));
        return new CovarianceMatrix(variables, Matrix.identity(1), 100);
    }

    @Override
    public final List<Node> getVariables() {
        return this.variables;
    }

    /** Returns the names of the variables, in variable order. */
    @Override
    public final List<String> getVariableNames() {
        List<String> names = new ArrayList<String>(this.getVariables().size());
        for (Node variable : this.getVariables()) {
            names.add(variable.getName());
        }
        return names;
    }

    /**
     * Returns the name of the variable at {@code index}.
     *
     * @throws IllegalArgumentException if {@code index} is negative or not less than the
     *                                  number of variables.
     */
    @Override
    public final String getVariableName(int index) {
        if (index < 0 || index >= this.getVariables().size()) {
            throw new IllegalArgumentException("Index out of range: " + index);
        }
        return this.getVariables().get(index).getName();
    }

    @Override
    public final int getDimension() {
        return this.variables.size();
    }

    @Override
    public final int getSampleSize() {
        return this.sampleSize;
    }

    @Override
    public final String getName() {
        return this.name;
    }

    @Override
    public final void setName(String name) {
        this.name = name;
    }

    /** Returns a copy of the stored knowledge. */
    @Override
    public final Knowledge getKnowledge() {
        return this.knowledge.copy();
    }

    /**
     * Stores a copy of the given knowledge.
     *
     * @throws NullPointerException if {@code knowledge} is null.
     */
    @Override
    public final void setKnowledge(Knowledge knowledge) {
        if (knowledge == null) {
            throw new NullPointerException();
        }
        this.knowledge = knowledge.copy();
    }

    /**
     * Returns a materialized {@link CovarianceMatrix} over the variables at {@code indices},
     * computed from all rows.
     */
    @Override
    public final ICovarianceMatrix getSubmatrix(int[] indices) {
        return new CovarianceMatrix(this.variablesAt(indices),
                this.fillBlock(indices, indices, null), this.getSampleSize());
    }

    /**
     * Returns a materialized {@link CovarianceMatrix} over the variables at {@code indices},
     * computed only from the given data rows.
     *
     * <p>NOTE(review): the reported sample size is still {@link #getSampleSize()}, not
     * {@code dataRows.length}; kept as-is for compatibility.
     */
    public final ICovarianceMatrix getSubmatrix(int[] indices, int[] dataRows) {
        return new CovarianceMatrix(this.variablesAt(indices),
                this.fillBlock(indices, indices, dataRows), this.getSampleSize());
    }

    /** Not supported by this implementation. */
    @Override
    public final ICovarianceMatrix getSubmatrix(List<String> submatrixVarNames) {
        throw new UnsupportedOperationException();
    }

    /** Not supported by this implementation. */
    @Override
    public final CovarianceMatrixOnTheFly getSubmatrix(String[] submatrixVarNames) {
        throw new UnsupportedOperationException();
    }

    /**
     * Returns the covariance of variables i and j over all rows where both are non-NaN,
     * divided by (that row count - 1). The diagonal is served from the precomputed variances.
     *
     * <p>NOTE(review): the loop bound is the current sample size, so calling
     * setSampleSize() with a value larger than the row count makes this throw.
     */
    @Override
    public final double getValue(int i, int j) {
        if (i == j) {
            return this.variances[i];
        }
        double[] v1 = this.vectors[i];
        double[] v2 = this.vectors[j];
        double sum = 0.0;
        int count = 0;
        for (int k = 0; k < this.sampleSize; ++k) {
            if (Double.isNaN(v1[k]) || Double.isNaN(v2[k])) {
                continue;
            }
            sum += v1[k] * v2[k];
            ++count;
        }
        return sum / (count - 1);
    }

    /**
     * Like {@link #getValue(int, int)} but only over the given row indices (the diagonal is
     * recomputed too). The vectors were centered with full-sample means, not subset means.
     */
    public final double getValue(int i, int j, int[] rows) {
        double[] v1 = this.vectors[i];
        double[] v2 = this.vectors[j];
        double sum = 0.0;
        int count = 0;
        for (int k : rows) {
            if (Double.isNaN(v1[k]) || Double.isNaN(v2[k])) {
                continue;
            }
            sum += v1[k] * v2[k];
            ++count;
        }
        return sum / (count - 1);
    }

    /**
     * Stores a matrix and validates its dimension. The stored matrix is not used by
     * {@link #getValue(int, int)}.
     */
    @Override
    public void setMatrix(Matrix matrix) {
        this.matrix = matrix;
        this.checkMatrix();
    }

    /**
     * Sets the sample size.
     *
     * @throws IllegalArgumentException if {@code sampleSize <= 0}.
     */
    @Override
    public final void setSampleSize(int sampleSize) {
        if (sampleSize <= 0) {
            throw new IllegalArgumentException("Sample size must be > 0.");
        }
        this.sampleSize = sampleSize;
    }

    @Override
    public final int getSize() {
        return this.getVariables().size();
    }

    /** Materializes the full covariance matrix over all rows. */
    @Override
    public final Matrix getMatrix() {
        int[] all = this.allIndices();
        return this.fillBlock(all, all, null);
    }

    /** Materializes the full covariance matrix over the given data rows only. */
    public final Matrix getMatrix(int[] rows) {
        int[] all = this.allIndices();
        return this.fillBlock(all, all, rows);
    }

    /** Marks the variable as selected if it belongs to this matrix. */
    @Override
    public final void select(Node variable) {
        if (this.variables.contains(variable)) {
            this.getSelectedVariables().add(variable);
        }
    }

    @Override
    public final void clearSelection() {
        this.getSelectedVariables().clear();
    }

    /**
     * Returns whether the variable is currently selected.
     *
     * @throws NullPointerException if {@code variable} is null.
     */
    @Override
    public final boolean isSelected(Node variable) {
        if (variable == null) {
            throw new NullPointerException("Null variable. Try again.");
        }
        return this.getSelectedVariables().contains(variable);
    }

    /** Returns the selected variable names, in the set's iteration order. */
    @Override
    public final List<String> getSelectedVariableNames() {
        List<String> selectedVariableNames = new ArrayList<String>(this.selectedVariables.size());
        for (Node variable : this.selectedVariables) {
            selectedVariableNames.add(variable.getName());
        }
        return selectedVariableNames;
    }

    /**
     * Renders the sample size, a tab-separated header of variable names, and the lower
     * triangle of the matrix, one row per line.
     */
    @Override
    public final String toString() {
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        StringBuilder buf = new StringBuilder();
        // Build the name list once instead of once per loop iteration.
        List<String> names = this.getVariableNames();
        int numVars = names.size();
        buf.append(this.getSampleSize()).append("\n");
        for (String varName : names) {
            buf.append(varName).append("\t");
        }
        buf.append("\n");
        for (int j = 0; j < numVars; ++j) {
            for (int i = 0; i <= j; ++i) {
                buf.append(nf.format(this.getValue(i, j))).append("\t");
            }
            buf.append("\n");
        }
        return buf.toString();
    }

    @Override
    public boolean isContinuous() {
        return true;
    }

    @Override
    public boolean isDiscrete() {
        return false;
    }

    @Override
    public boolean isMixed() {
        return false;
    }

    /**
     * Replaces the variable list; the new list must have the same size.
     *
     * @throws IllegalArgumentException if the sizes differ.
     */
    @Override
    public void setVariables(List<Node> variables) {
        if (variables.size() != this.variables.size()) {
            throw new IllegalArgumentException("Wrong # of variables.");
        }
        this.variables = variables;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Returns the block of covariances for the given row and column variable indices. */
    @Override
    public Matrix getSelection(int[] rows, int[] cols) {
        return this.fillBlock(rows, cols, null);
    }

    /**
     * Returns the block of covariances for the given row and column variable indices,
     * computed only over {@code dataRows} (all rows when {@code dataRows} is null).
     */
    public Matrix getSelection(int[] rows, int[] cols, int[] dataRows) {
        return this.fillBlock(rows, cols, dataRows);
    }

    /** Returns the variable with the given name, or null if there is none. */
    @Override
    public Node getVariable(String name) {
        for (Node variable : this.getVariables()) {
            if (name.equals(variable.getName())) {
                return variable;
            }
        }
        return null;
    }

    /**
     * Returns this same instance, not a copy.
     *
     * <p>NOTE(review): mutable state (name, knowledge, selection, variables) is therefore
     * shared with the "copy".
     */
    @Override
    public DataModel copy() {
        return this;
    }

    /** Entries are derived from the data and cannot be set. */
    @Override
    public void setValue(int i, int j, double v) {
        throw new IllegalArgumentException("CovarianceMatrixOnTheFly is read-only.");
    }

    /**
     * Not supported. This method always failed before, because getSubmatrix(List) is
     * unsupported; it now says so directly.
     *
     * @throws UnsupportedOperationException always.
     */
    @Override
    public void removeVariables(List<String> remaining) {
        throw new UnsupportedOperationException(
                "removeVariables is not supported by CovarianceMatrixOnTheFly.");
    }

    private Set<Node> getSelectedVariables() {
        return this.selectedVariables;
    }

    /** Returns the variables at the given indices, in order. */
    private List<Node> variablesAt(int[] indices) {
        List<Node> result = new ArrayList<Node>(indices.length);
        for (int index : indices) {
            result.add(this.variables.get(index));
        }
        return result;
    }

    /** Returns {0, 1, ..., dimension - 1}. */
    private int[] allIndices() {
        int[] all = new int[this.getDimension()];
        for (int i = 0; i < all.length; ++i) {
            all[i] = i;
        }
        return all;
    }

    /**
     * Fills a rows.length x cols.length matrix with covariances, over {@code dataRows} when
     * non-null or over all rows otherwise. When rows and cols are identical, only the upper
     * triangle is computed and mirrored; getValue is symmetric in (i, j).
     */
    private Matrix fillBlock(int[] rows, int[] cols, int[] dataRows) {
        Matrix m = new Matrix(rows.length, cols.length);
        if (Arrays.equals(rows, cols)) {
            for (int i = 0; i < rows.length; ++i) {
                for (int j = i; j < cols.length; ++j) {
                    double value = this.entry(rows[i], cols[j], dataRows);
                    m.set(i, j, value);
                    m.set(j, i, value);
                }
            }
        } else {
            for (int i = 0; i < rows.length; ++i) {
                for (int j = 0; j < cols.length; ++j) {
                    m.set(i, j, this.entry(rows[i], cols[j], dataRows));
                }
            }
        }
        return m;
    }

    /** Dispatches to the all-rows or row-subset covariance. */
    private double entry(int i, int j, int[] dataRows) {
        return dataRows == null ? this.getValue(i, j) : this.getValue(i, j, dataRows);
    }

    /**
     * Validates a matrix stored via setMatrix.
     *
     * @throws NullPointerException     if any variable is null.
     * @throws IllegalArgumentException if the sample size is below 1 or the matrix is not
     *                                  numVars x numVars.
     */
    private void checkMatrix() {
        int numVars = this.variables.size();
        for (Node variable : this.variables) {
            if (variable == null) {
                throw new NullPointerException();
            }
        }
        if (this.sampleSize < 1) {
            throw new IllegalArgumentException("Sample size must be at least 1.");
        }
        if (numVars != this.matrix.rows() || numVars != this.matrix.columns()) {
            throw new IllegalArgumentException("Number of variables does not equal the dimension of the matrix.");
        }
    }

    /**
     * Restores default-serialized state. It converts a legacy {@code matrixC} into
     * {@code matrix} and validates the required fields.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.getVariables() == null) {
            throw new NullPointerException();
        }
        if (this.matrixC != null) {
            this.matrix = new Matrix(this.matrixC.toArray());
            this.matrixC = null;
        }
        if (this.knowledge == null) {
            throw new NullPointerException();
        }
        // NOTE(review): accepts -1 and 0, unlike setSampleSize (> 0); kept for compatibility.
        if (this.sampleSize < -1) {
            throw new IllegalStateException();
        }
        if (this.selectedVariables == null) {
            this.selectedVariables = new HashSet<Node>();
        }
    }
}

