/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.Vector;
import java.util.LinkedList;
import org.apache.commons.math3.util.FastMath;

/**
 * Factor extraction by the successive-residual method, plus a successive
 * varimax-style rotation of a factor-loading matrix.
 *
 * <p>Instances hold mutable state (the extracted loading vectors and the last
 * residual) that is written without any synchronization.
 */
public class FactorAnalysis {
    /** Covariance the analysis is based on; this class only reads a copy of its matrix. */
    private final CovarianceMatrix covariance;

    /** Loading vectors (n x 1 matrices) collected by the last {@link #successiveResidual()} call. */
    private LinkedList<Matrix> factorLoadingVectors;

    /** Convergence tolerance shared by both iterative procedures; defaults to 0.001. */
    private double threshold = 0.001;

    /** Upper bound on the number of factors {@link #successiveResidual()} tries to extract; defaults to 2. */
    private int numFactors = 2;

    /** Residual left by the last {@link #successiveResidual()} call; {@code null} before any call. */
    private Matrix residual;

    /**
     * Creates an analysis over {@code new CovarianceMatrix(covarianceMatrix)}.
     *
     * @param covarianceMatrix the covariance matrix to analyze
     */
    public FactorAnalysis(ICovarianceMatrix covarianceMatrix) {
        this.covariance = new CovarianceMatrix(covarianceMatrix);
    }

    /**
     * Creates an analysis over the covariance matrix {@code new CovarianceMatrix(dataSet)}.
     *
     * @param dataSet the data whose covariance is analyzed
     */
    public FactorAnalysis(DataSet dataSet) {
        this.covariance = new CovarianceMatrix(dataSet);
    }

    /**
     * Extracts up to {@link #getNumFactors()} factors by the successive-residual method.
     *
     * <p>Starting from a copy of the covariance matrix R and a column of ones, each step
     * estimates one loading vector {@code f} and replaces R by {@code R - f f^T}. Extraction
     * stops early when a step finds no positive variance left to explain. The FIRST extracted
     * loading vector is then discarded from the result (NOTE(review): the rationale is not
     * visible in this class — confirm intent). The final residual is stored and can be read
     * with {@link #getResidual()}.
     *
     * @return a matrix with one row per variable and one column per retained loading vector;
     *     it has zero columns when at most one factor was extracted
     */
    public Matrix successiveResidual() {
        this.factorLoadingVectors = new LinkedList<>();

        Matrix residual = this.covariance.getMatrix().copy();
        Matrix unitVector = onesColumn(residual.rows());

        for (int i = 0; i < this.getNumFactors(); ++i) {
            if (!this.successiveResidualHelper(residual, unitVector)) {
                break;
            }
            Matrix f = this.factorLoadingVectors.getLast();
            residual = residual.minus(f.times(f.transpose()));
        }

        // With no extracted factor there is nothing to discard; removeFirst() would throw.
        if (!this.factorLoadingVectors.isEmpty()) {
            this.factorLoadingVectors.removeFirst();
        }

        this.residual = residual;
        return fillColumns(new Matrix(residual.rows(), this.factorLoadingVectors.size()),
                this.factorLoadingVectors);
    }

    /**
     * Rotates a factor-loading matrix with a successive varimax-style procedure.
     *
     * <p>The rows of {@code factorLoadingMatrix} are first scaled to unit length, giving R.
     * A pivot row is chosen once: the index of the smallest entry of {@code R w}, where
     * {@code w = R^T 1 / sqrt(1^T R R^T 1)}. Then, for each column of the input, one rotated
     * factor {@code b} is found by iterating (at most 200 times, until the norm
     * {@code alpha} changes by less than {@link #getThreshold()}):
     * {@code b = R h}, {@code beta = b^3 - mean(b^2) * b} (element-wise),
     * {@code u = R^T beta}, {@code h = u / |u|}, starting with h equal to the pivot row of the
     * current R. After each factor, {@code b h^T} is subtracted from R.
     *
     * @param factorLoadingMatrix loadings with one row per variable and one column per factor
     * @return the same instance when it has exactly one column; otherwise a matrix created by
     *     {@code factorLoadingMatrix.like()} whose k-th column holds the k-th rotated factor
     */
    public Matrix successiveFactorVarimax(Matrix factorLoadingMatrix) {
        if (factorLoadingMatrix.columns() == 1) {
            return factorLoadingMatrix;
        }

        Matrix normalizedFactorLoadings = FactorAnalysis.normalizeRows(factorLoadingMatrix);
        Matrix unitColumn = onesColumn(factorLoadingMatrix.rows());

        Matrix r = normalizedFactorLoadings;

        // The pivot direction depends only on the initial normalized loadings, so the pivot
        // row is the same for every rotated factor.
        Matrix sumCols = r.transpose().times(unitColumn);
        Matrix wVector = sumCols.scalarMult(
                1.0 / FastMath.sqrt(unitColumn.transpose().times(r).times(sumCols).get(0, 0)));
        Matrix vVector = r.times(wVector);
        int lIndex = indexOfMinimum(vVector);

        LinkedList<Matrix> rotatedFactorVectors = new LinkedList<>();

        for (int k = 0; k < normalizedFactorLoadings.columns(); ++k) {
            Matrix h = rowAsColumn(r, lIndex);
            Matrix b = null;
            double alpha1 = Double.NaN;

            for (int i = 0; i < 200; ++i) {
                b = r.times(h);
                double averageSumSquares = unitColumn.transpose()
                        .times(elementwisePower(b, 2.0))
                        .scalarMult(1.0 / (double) b.rows())
                        .get(0, 0);
                Matrix beta = elementwisePower(b, 3.0).minus(b.scalarMult(averageSumSquares));
                Matrix u = r.transpose().times(beta);
                double alpha2 = FastMath.sqrt(u.transpose().times(u).get(0, 0));
                h = u.scalarMult(1.0 / alpha2);
                if (!Double.isNaN(alpha1) && FastMath.abs(alpha2 - alpha1) < this.getThreshold()) {
                    break;
                }
                alpha1 = alpha2;
            }

            rotatedFactorVectors.add(b);
            // Remove the rank-one part explained by this factor before finding the next one.
            r = r.minus(b.times(h.transpose()));
        }

        return fillColumns(factorLoadingMatrix.like(), rotatedFactorVectors);
    }

    /**
     * Estimates one loading vector from {@code residual} and appends it to
     * {@link #factorLoadingVectors}.
     *
     * <p>Starts from {@code f = R a / sqrt(a^T R a)} (a = {@code approximationVector}) and
     * repeats {@code f = R f / sqrt(f^T R f)} until the scale {@code sqrt(f^T R f)} changes by
     * at most {@link #getThreshold()}, at most 100 times. The iteration also stops, keeping
     * the last well-defined f, if {@code f^T R f} becomes non-positive or NaN.
     *
     * @return {@code false} (appending nothing) when {@code a^T R a} is not positive or is
     *     NaN; {@code true} otherwise
     */
    private boolean successiveResidualHelper(Matrix residual, Matrix approximationVector) {
        double l0 = approximationVector.transpose().times(residual).times(approximationVector).get(0, 0);
        // A zero (or NaN) quadratic form would make the division below produce Inf/NaN loadings.
        if (Double.isNaN(l0) || l0 <= 0.0) {
            return false;
        }

        double d = FastMath.sqrt(l0);
        Matrix f = residual.times(approximationVector).scalarMult(1.0 / d);

        for (int i = 0; i < 100; ++i) {
            Matrix ui = residual.times(f);
            double li = f.transpose().times(ui).get(0, 0);
            if (Double.isNaN(li) || li <= 0.0) {
                break; // sqrt / division would poison f with NaN or Inf
            }
            double di = FastMath.sqrt(li);
            if (FastMath.abs(d - di) <= this.getThreshold()) {
                break;
            }
            d = di;
            f = ui.scalarMult(1.0 / d);
        }

        this.factorLoadingVectors.add(f);
        return true;
    }

    /**
     * Scales every row of {@code matrix} to unit Euclidean length.
     * A row of all zeros yields NaN entries (0 * 1/0).
     */
    private static Matrix normalizeRows(Matrix matrix) {
        Matrix result = new Matrix(matrix.rows(), matrix.columns());
        for (int i = 0; i < matrix.rows(); ++i) {
            Matrix normalizedRow = FactorAnalysis.normalizeVector(rowAsColumn(matrix, i));
            for (int j = 0; j < matrix.columns(); ++j) {
                result.set(i, j, normalizedRow.get(j, 0));
            }
        }
        return result;
    }

    /** Returns the column vector divided by its Euclidean norm. */
    private static Matrix normalizeVector(Matrix vector) {
        double scalar = FastMath.sqrt(vector.transpose().times(vector).get(0, 0));
        return vector.scalarMult(1.0 / scalar);
    }

    /** Raises every entry to {@code exponent} (element-wise power, not a matrix exponential). */
    private static Matrix elementwisePower(Matrix matrix, double exponent) {
        Matrix result = new Matrix(matrix.rows(), matrix.columns());
        for (int i = 0; i < matrix.rows(); ++i) {
            for (int j = 0; j < matrix.columns(); ++j) {
                result.set(i, j, FastMath.pow(matrix.get(i, j), exponent));
            }
        }
        return result;
    }

    /** Builds an n x 1 matrix of ones. */
    private static Matrix onesColumn(int rows) {
        Matrix ones = new Matrix(rows, 1);
        for (int i = 0; i < rows; ++i) {
            ones.set(i, 0, 1.0);
        }
        return ones;
    }

    /** Copies row {@code row} of {@code matrix} into a new (columns x 1) column matrix. */
    private static Matrix rowAsColumn(Matrix matrix, int row) {
        Vector values = matrix.getRow(row);
        Matrix column = new Matrix(matrix.columns(), 1);
        for (int j = 0; j < matrix.columns(); ++j) {
            column.set(j, 0, values.get(j));
        }
        return column;
    }

    /**
     * Returns the index of the smallest entry of a column matrix, or 0 if no entry is below
     * positive infinity (e.g. all NaN); ties keep the first index.
     */
    private static int indexOfMinimum(Matrix column) {
        int best = 0;
        double min = Double.POSITIVE_INFINITY;
        for (int i = 0; i < column.rows(); ++i) {
            double value = column.get(i, 0);
            if (value < min) {
                min = value;
                best = i;
            }
        }
        return best;
    }

    /** Writes the k-th n x 1 matrix of {@code columns} into column k of {@code target}. */
    private static Matrix fillColumns(Matrix target, LinkedList<Matrix> columns) {
        int j = 0;
        for (Matrix column : columns) {
            for (int i = 0; i < column.rows(); ++i) {
                target.set(i, j, column.get(i, 0));
            }
            ++j;
        }
        return target;
    }

    /**
     * Sets the convergence tolerance used by both iterative procedures.
     *
     * @param threshold the new tolerance
     */
    public void setThreshold(double threshold) {
        this.threshold = threshold;
    }

    /** @return the convergence tolerance used by both iterative procedures */
    public double getThreshold() {
        return this.threshold;
    }

    /** @return the maximum number of factors {@link #successiveResidual()} tries to extract */
    public int getNumFactors() {
        return this.numFactors;
    }

    /**
     * Sets the maximum number of factors {@link #successiveResidual()} tries to extract.
     *
     * @param numFactors the new upper bound
     */
    public void setNumFactors(int numFactors) {
        this.numFactors = numFactors;
    }

    /** @return the residual left by the last {@link #successiveResidual()} call, or {@code null} */
    public Matrix getResidual() {
        return this.residual;
    }
}

