/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.data;

import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.TetradSerializable;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

/**
 * A covariance matrix over an ordered list of variables, together with the
 * sample size associated with it, an optional name, a set of selected
 * variables and background {@link Knowledge}.
 *
 * <p>Row/column {@code i} of the matrix corresponds to {@code getVariables().get(i)}.
 * Instances are mutable (name, sample size, matrix, selection and knowledge can
 * be changed); no synchronization is performed.
 */
public class CovarianceMatrix implements DataModel, TetradSerializable {
    static final long serialVersionUID = 23L;

    /** Optional display name; may be null. */
    private String name;

    /** Unmodifiable list of the variables indexing the rows/columns of the matrix. */
    private List<Node> variables;

    /** Sample size associated with the matrix. */
    private int sampleSize;

    /**
     * Legacy array form of the matrix. This class never populates it; if a
     * deserialized instance carries it, {@code readObject} converts it into
     * {@code matrixC} and clears it. Presumably kept so older serialized forms
     * still load — do not remove without checking serialization compatibility.
     */
    private double[][] matrix;

    /** The covariance values, in the same order as {@code variables}. */
    private DoubleMatrix2D matrixC;

    /** Variables marked as selected; {@link #select(Node)} only adds members of {@code variables}. */
    private Set<Node> selectedVariables = new HashSet<Node>();

    /** Background knowledge; getters and setters exchange defensive copies. */
    private Knowledge knowledge = new Knowledge();

    /**
     * Builds the covariance matrix of a continuous data set. The sample size is
     * the data set's row count.
     *
     * <p>Unlike the other constructors, the resulting matrix is not passed
     * through {@code checkMatrix()}.
     *
     * @param dataSet a continuous data set
     * @throws IllegalArgumentException if the data set is not continuous
     */
    public CovarianceMatrix(DataSet dataSet) {
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("Not a continuous data set.");
        }
        // Snapshot the variable list so later changes to the data set's list
        // cannot desynchronize it from the matrix.
        this.variables = Collections.unmodifiableList(new ArrayList<Node>(dataSet.getVariables()));
        this.sampleSize = dataSet.getNumRows();
        // NOTE(review): assumes getCovarianceMatrix() returns a matrix not shared
        // with the data set; if it may be cached/shared, copy() it here.
        this.matrixC = dataSet.getCovarianceMatrix();
    }

    /**
     * Wraps an existing covariance matrix.
     *
     * @param variables  the variables, one per row/column; none may be null
     * @param matrix     the covariance values; copied, so later changes by the
     *                   caller are not seen
     * @param sampleSize the sample size; must be at least 1
     * @throws NullPointerException     if any variable is null
     * @throws IllegalArgumentException if the sample size is below 1 or the
     *                                  matrix is not {@code variables.size()} square
     */
    public CovarianceMatrix(List<Node> variables, DoubleMatrix2D matrix, int sampleSize) {
        // Copy rather than only wrapping: unmodifiableList is a view, so the
        // caller could otherwise change the variables behind our back.
        this.variables = Collections.unmodifiableList(new ArrayList<Node>(variables));
        this.sampleSize = sampleSize;
        this.matrixC = matrix.copy();
        this.checkMatrix();
    }

    /**
     * Copies the variables, matrix and sample size of another covariance
     * matrix. The name, knowledge and selection are not copied.
     *
     * @param covMatrix the matrix to copy
     */
    public CovarianceMatrix(CovarianceMatrix covMatrix) {
        // The delegated constructor already copies both the list and the matrix.
        this(covMatrix.variables, covMatrix.matrixC, covMatrix.sampleSize);
    }

    /**
     * Returns a minimal instance (one variable "X", 1x1 identity, sample size
     * 100) for serialization tests.
     *
     * @return a small, valid covariance matrix
     */
    public static CovarianceMatrix serializableInstance() {
        List<Node> variables = new ArrayList<Node>();
        variables.add(new ContinuousVariable("X"));
        DoubleMatrix2D matrix = MatrixUtils.identityC(1);
        return new CovarianceMatrix(variables, matrix, 100);
    }

    /**
     * @return the unmodifiable list of variables, in matrix order
     */
    @Override
    public final List<Node> getVariables() {
        return this.variables;
    }

    /**
     * @return a new list holding the variable names, in matrix order
     */
    @Override
    public final List<String> getVariableNames() {
        List<String> names = new ArrayList<String>(this.variables.size());
        for (Node variable : this.getVariables()) {
            names.add(variable.getName());
        }
        return names;
    }

    /**
     * Returns the name of the variable at the given matrix index.
     *
     * @param index a row/column index
     * @return that variable's name
     * @throws IllegalArgumentException if the index is negative or not less
     *                                  than the number of variables
     */
    public final String getVariableName(int index) {
        if (index < 0 || index >= this.getVariables().size()) {
            throw new IllegalArgumentException("Index out of range: " + index);
        }
        return this.getVariables().get(index).getName();
    }

    /**
     * @return the number of variables
     */
    public final int getDimension() {
        return this.variables.size();
    }

    /**
     * @return the sample size associated with this matrix
     */
    public final int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * @return the name of this matrix, possibly null
     */
    @Override
    public final String getName() {
        return this.name;
    }

    /**
     * @param name the new name; may be null
     */
    @Override
    public final void setName(String name) {
        this.name = name;
    }

    /**
     * @return a copy of the background knowledge
     */
    @Override
    public final Knowledge getKnowledge() {
        return new Knowledge(this.knowledge);
    }

    /**
     * Stores a copy of the given background knowledge.
     *
     * @param knowledge the knowledge to copy
     * @throws NullPointerException if {@code knowledge} is null
     */
    @Override
    public final void setKnowledge(Knowledge knowledge) {
        if (knowledge == null) {
            throw new NullPointerException();
        }
        this.knowledge = new Knowledge(knowledge);
    }

    /**
     * Returns a new covariance matrix over the variables at the given indices,
     * in the given order, with the same sample size.
     *
     * @param indices row/column indices into this matrix
     * @return the submatrix
     */
    public final CovarianceMatrix getSubmatrix(int[] indices) {
        List<Node> submatrixVars = new ArrayList<Node>(indices.length);
        for (int index : indices) {
            submatrixVars.add(this.variables.get(index));
        }
        DoubleMatrix2D cov = this.matrixC.viewSelection(indices, indices);
        return new CovarianceMatrix(submatrixVars, cov, this.getSampleSize());
    }

    /**
     * Returns a new covariance matrix over the named variables, in the given
     * order, with the same sample size.
     *
     * @param submatrixVarNames names of variables of this matrix
     * @return the submatrix
     * @see #getSubmatrix(String[])
     */
    public final CovarianceMatrix getSubmatrix(List<String> submatrixVarNames) {
        return this.getSubmatrix(submatrixVarNames.toArray(new String[0]));
    }

    /**
     * Returns a new covariance matrix over the named variables, in the given
     * order, with the same sample size.
     *
     * @param submatrixVarNames names of variables of this matrix
     * @return the submatrix
     * @throws NullPointerException     if any name is null
     * @throws IllegalArgumentException if any name does not match a variable
     *                                  of this matrix
     */
    public final CovarianceMatrix getSubmatrix(String[] submatrixVarNames) {
        List<Node> submatrixVars = new ArrayList<Node>(submatrixVarNames.length);
        for (int i = 0; i < submatrixVarNames.length; ++i) {
            String varName = submatrixVarNames[i];
            if (varName == null) {
                throw new NullPointerException("The variable name at index " + i + " is null.");
            }
            // getVariable returns null for unknown names; rejected just below.
            submatrixVars.add(this.getVariable(varName));
        }
        if (!this.getVariables().containsAll(submatrixVars)) {
            throw new IllegalArgumentException("The variables in the submatrix must be in the original matrix: original==" + this.getVariables() + ", sub==" + submatrixVars);
        }
        int[] indices = new int[submatrixVars.size()];
        for (int i = 0; i < indices.length; ++i) {
            indices[i] = this.getVariables().indexOf(submatrixVars.get(i));
        }
        DoubleMatrix2D cov = this.matrixC.viewSelection(indices, indices);
        return new CovarianceMatrix(submatrixVars, cov, this.getSampleSize());
    }

    /**
     * Returns entry (i, j). Uses Colt's {@code getQuick}, which does not
     * bounds-check its arguments, so callers must pass valid indices.
     *
     * @param i row index
     * @param j column index
     * @return the covariance value at (i, j)
     */
    public final double getValue(int i, int j) {
        return this.matrixC.getQuick(i, j);
    }

    /**
     * Replaces the matrix with a copy of the given one and re-validates it.
     * If validation fails, the previous matrix is kept.
     *
     * @param matrix the new covariance values, one row/column per variable
     * @throws IllegalArgumentException if the matrix is not square with one
     *                                  row per variable (or the stored sample
     *                                  size is below 1)
     */
    public void setMatrix(DoubleMatrix2D matrix) {
        DoubleMatrix2D previous = this.matrixC;
        this.matrixC = matrix.copy();
        try {
            this.checkMatrix();
        } catch (RuntimeException e) {
            // Do not leave the object with a matrix that disagrees with its variables.
            this.matrixC = previous;
            throw e;
        }
    }

    /**
     * @param sampleSize the new sample size
     * @throws IllegalArgumentException if {@code sampleSize <= 0}
     */
    public final void setSampleSize(int sampleSize) {
        if (sampleSize <= 0) {
            throw new IllegalArgumentException("Sample size must be > 0.");
        }
        this.sampleSize = sampleSize;
    }

    /**
     * @return the number of rows of the stored matrix
     */
    public final int getSize() {
        return this.matrixC.rows();
    }

    /**
     * @return a copy of the covariance values
     */
    public final DoubleMatrix2D getMatrix() {
        return this.matrixC.copy();
    }

    /**
     * Marks a variable as selected; variables not in this matrix are ignored.
     *
     * @param variable the variable to select
     */
    public final void select(Node variable) {
        if (this.variables.contains(variable)) {
            this.getSelectedVariables().add(variable);
        }
    }

    /** Clears the set of selected variables. */
    public final void clearSelection() {
        this.getSelectedVariables().clear();
    }

    /**
     * @param variable the variable to test
     * @return whether the variable is currently selected
     * @throws NullPointerException if {@code variable} is null
     */
    public final boolean isSelected(Node variable) {
        if (variable == null) {
            throw new NullPointerException("Null variable. Try again.");
        }
        return this.getSelectedVariables().contains(variable);
    }

    /**
     * @return the names of the selected variables, in the iteration order of
     *         the underlying hash set (not matrix order)
     */
    public final List<String> getSelectedVariableNames() {
        List<String> selectedVariableNames = new ArrayList<String>(this.selectedVariables.size());
        for (Node variable : this.selectedVariables) {
            selectedVariableNames.add(variable.getName());
        }
        return selectedVariableNames;
    }

    /**
     * Renders the sample size, a tab-separated header of variable names, and
     * then one line per row j holding entries (i, j) for i = 0..j, formatted
     * with the shared Tetrad number format.
     */
    @Override
    public final String toString() {
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        // Computed once; the original rebuilt this list on every iteration.
        List<String> names = this.getVariableNames();
        int numVars = names.size();
        StringBuilder buf = new StringBuilder();
        buf.append(this.getSampleSize()).append("\n");
        for (String varName : names) {
            buf.append(varName).append("\t");
        }
        buf.append("\n");
        for (int j = 0; j < numVars; ++j) {
            for (int i = 0; i <= j; ++i) {
                buf.append(nf.format(this.getValue(i, j))).append("\t");
            }
            buf.append("\n");
        }
        return buf.toString();
    }

    /**
     * Looks up a variable by exact name.
     *
     * @param name the name to find; must not be null when this matrix has
     *             variables
     * @return the first variable with that name, or null if none matches
     */
    public Node getVariable(String name) {
        for (Node variable : this.getVariables()) {
            if (name.equals(variable.getName())) {
                return variable;
            }
        }
        return null;
    }

    /** @return the live set of selected variables */
    private Set<Node> getSelectedVariables() {
        return this.selectedVariables;
    }

    /**
     * Validates variables, sample size and matrix dimensions. When the sample
     * size does not exceed the number of rows, it prints a warning and
     * spot-checks diagonal blocks of consecutive variables for positive
     * definiteness, printing a message for each failing block.
     *
     * @throws NullPointerException     if any variable is null
     * @throws IllegalArgumentException if the sample size is below 1 or the
     *                                  matrix is not square with one row per variable
     */
    private void checkMatrix() {
        int numVars = this.variables.size();
        for (int i = 0; i < numVars; ++i) {
            if (this.variables.get(i) == null) {
                throw new NullPointerException();
            }
        }
        if (this.sampleSize < 1) {
            throw new IllegalArgumentException("Sample size must be at least 1.");
        }
        if (numVars != this.matrixC.rows() || numVars != this.matrixC.columns()) {
            throw new IllegalArgumentException("Number of variables does not equal the dimension of the matrix.");
        }
        if (this.sampleSize <= this.matrixC.rows()) {
            System.out.println("Covariance matrix cannot be positive definite since \nthere are more variables than sample points. Spot-checking \nsome submatrices.");
            // Blocks of about sampleSize / 2 consecutive variables. The step must
            // be at least 1: with sampleSize == 1, sampleSize / 2 is 0 and the
            // loop previously never advanced.
            int blockSize = Math.max(1, this.sampleSize / 2);
            for (int from = 0; from < numVars; from += blockSize) {
                int to = Math.min(from + blockSize, numVars);
                int[] indices = new int[to - from];
                for (int i = 0; i < indices.length; ++i) {
                    indices[i] = from + i;
                }
                DoubleMatrix2D block = this.matrixC.viewSelection(indices, indices);
                if (!MatrixUtils.isPositiveDefinite(block)) {
                    System.out.println("Positive definite spot-check failed.");
                }
            }
        }
    }

    /**
     * Restores a serialized instance: converts the legacy {@code matrix} array
     * into {@code matrixC} if present, validates required fields and
     * re-creates a missing selection set.
     */
    private void readObject(ObjectInputStream s) throws IOException, ClassNotFoundException {
        s.defaultReadObject();
        if (this.getVariables() == null) {
            throw new NullPointerException();
        }
        // Presumably an older serialized form stored the matrix as double[][].
        if (this.matrix != null) {
            this.matrixC = new DenseDoubleMatrix2D(this.matrix);
            this.matrix = null;
        }
        if (this.knowledge == null) {
            throw new NullPointerException();
        }
        // NOTE(review): -1 is accepted here although checkMatrix() requires at
        // least 1; presumably a legacy "unknown" sentinel — confirm before tightening.
        if (this.sampleSize < -1) {
            throw new IllegalStateException();
        }
        if (this.selectedVariables == null) {
            this.selectedVariables = new HashSet<Node>();
        }
    }
}

