/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.fastica;

import edu.cmu.tetrad.search.fastica.BelowEVFilter;
import edu.cmu.tetrad.search.fastica.CompositeEVFilter;
import edu.cmu.tetrad.search.fastica.ContrastFunction;
import edu.cmu.tetrad.search.fastica.EigenValueFilter;
import edu.cmu.tetrad.search.fastica.FastICAConfig;
import edu.cmu.tetrad.search.fastica.FastICAException;
import edu.cmu.tetrad.search.fastica.PCA;
import edu.cmu.tetrad.search.fastica.Power3CFunction;
import edu.cmu.tetrad.search.fastica.ProgressListener;
import edu.cmu.tetrad.search.fastica.SortingEVFilter;
import edu.cmu.tetrad.search.fastica.TanhCFunction;
import edu.cmu.tetrad.search.fastica.math.EigenValueDecompositionSymm;
import edu.cmu.tetrad.search.fastica.math.Matrix;
import edu.cmu.tetrad.search.fastica.math.Vector;
import edu.cmu.tetrad.search.fastica.util.AudioBuffer;
import java.io.File;
import javax.sound.sampled.AudioFileFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;

/**
 * Independent component analysis using the FastICA fixed-point iteration.
 *
 * <p>All estimation work happens in the constructor. The input is centred and eigen-decomposed by
 * {@link PCA}, and the eigenpairs are passed through an {@link EigenValueFilter}. The centred data is
 * then whitened, and a weight matrix is estimated with either the symmetric or the deflation
 * approach selected in {@link FastICAConfig}. The resulting mixing and separating matrices, and
 * lazily the independent components themselves, are then available through the getters.
 *
 * <p>Layout: each column of the whitened data is treated as one sample and each row as one signal
 * dimension (see the sample extraction in {@code algorithm}; the demo in {@link #main} also builds
 * its input as a signals-by-samples matrix). NOTE(review): presumably {@code inVectors} follows the
 * same rows-are-signals layout; confirm against {@link PCA}.
 */
public class FastICA {
    /** The input matrix exactly as passed to the constructor (not copied). */
    private double[][] inVectors;
    /** Mean values as reported by the {@link PCA} step. */
    private double[] meanValues;
    /** Centred input as reported by the {@link PCA} step. */
    private double[][] vectorsZeroMean;
    /** diag(1 / sqrt(lambda)) * E^T, built from the filtered eigenpairs. */
    private double[][] whiteningMatrix;
    /** E * diag(sqrt(lambda)), built from the filtered eigenpairs. */
    private double[][] dewhiteningMatrix;
    /** whiteningMatrix * vectorsZeroMean. */
    private double[][] whitenedVectors;
    /** Estimated weight matrix (numICs x whitened dimensions); each row is applied to whitened samples. */
    private double[][] weightMatrix;
    /** dewhiteningMatrix * weightMatrix^T. */
    private double[][] mixingMatrix;
    /** weightMatrix * whiteningMatrix. */
    private double[][] separatingMatrix;
    /** Lazily computed result of {@link #getICVectors()}; null until first requested. */
    private double[][] icVectors;
    /**
     * When true, {@code makeRandomWeightMatrix} returns a fixed 2x2 start matrix instead of a
     * random one (deterministic debugging).
     *
     * <p>NOTE(review): the estimation runs inside the constructor, before any caller can assign
     * this field, so it is always {@code false} when read.
     */
    boolean fakeRandom = false;

    /**
     * Returns the estimated weight matrix.
     *
     * @return the internal array itself (not a copy)
     */
    public double[][] getWeightMatrix() {
        return this.weightMatrix;
    }

    /**
     * Runs FastICA with a fixed default configuration.
     *
     * <p>The defaults are:
     * <ul>
     *   <li>{@code FastICAConfig(numICs, SYMMETRIC, 1.0, 1.0E-12, 1000, null)};</li>
     *   <li>a {@code TanhCFunction(1.0)} contrast;</li>
     *   <li>a {@code BelowEVFilter(1.0E-12, false)} eigenvalue filter;</li>
     *   <li>a listener that ignores progress.</li>
     * </ul>
     *
     * @param inVectors input data
     * @param numICs requested number of independent components (capped at the whitened dimension)
     * @throws FastICAException if the eigenvalue filter leaves no eigenvalues
     */
    public FastICA(double[][] inVectors, int numICs) throws FastICAException {
        this.algorithm(inVectors, new FastICAConfig(numICs, FastICAConfig.Approach.SYMMETRIC, 1.0, 1.0E-12, 1000, null), new TanhCFunction(1.0), new BelowEVFilter(1.0E-12, false), new ProgressListener(){

            @Override
            public void progressMade(ProgressListener.ComputationState state, int component, int iteration, int maxComps) {
            }
        });
    }

    /**
     * Runs FastICA with a caller-supplied configuration.
     *
     * @param inVectors input data
     * @param config number of components, approach, epsilon, iteration limit and optional initial mixing matrix
     * @param conFunction contrast function g and its derivative g'
     * @param evFilter filter applied to the PCA eigenpairs before whitening
     * @param listener receives progress notifications
     * @throws FastICAException if the eigenvalue filter leaves no eigenvalues
     */
    public FastICA(double[][] inVectors, FastICAConfig config, ContrastFunction conFunction, EigenValueFilter evFilter, ProgressListener listener) throws FastICAException {
        this.algorithm(inVectors, config, conFunction, evFilter, listener);
    }

    /**
     * Performs the full estimation and stores every intermediate and final matrix in the fields.
     */
    private synchronized void algorithm(double[][] inVectors, FastICAConfig config, ContrastFunction conFunction, EigenValueFilter evFilter, ProgressListener listener) throws FastICAException {
        listener.progressMade(ProgressListener.ComputationState.WHITENING, 0, 0, config.getNumICs());
        this.inVectors = inVectors;
        this.icVectors = null;

        // Centre the data and eigen-decompose it.
        PCA pca = new PCA(inVectors);
        this.meanValues = pca.getMeanValues();
        this.vectorsZeroMean = pca.getVectorsZeroMean();
        double[] eigenValues = pca.getEigenValues();
        double[][] eigenVectors = pca.getEigenVectors();

        // Let the filter drop and/or reorder eigenpairs; give up if none survive.
        evFilter.passEigenValues(eigenValues, eigenVectors);
        eigenValues = evFilter.getEigenValues();
        if (eigenValues == null || eigenValues.length == 0) {
            this.mixingMatrix = null;
            this.separatingMatrix = null;
            this.icVectors = null;
            throw new FastICAException(FastICAException.Reason.NO_MORE_EIGENVALUES);
        }
        eigenVectors = evFilter.getEigenVectors();

        // Whitening: diag(1/sqrt(lambda)) * E^T; dewhitening: E * diag(sqrt(lambda)).
        double[] sqrtEigenValues = FastICA.sqrtVector(eigenValues);
        this.whiteningMatrix = Matrix.mult(Matrix.diag(FastICA.invVector(sqrtEigenValues)), Matrix.transpose(eigenVectors));
        this.dewhiteningMatrix = Matrix.mult(eigenVectors, Matrix.diag(sqrtEigenValues));
        this.whitenedVectors = Matrix.mult(this.whiteningMatrix, this.vectorsZeroMean);

        int m = Matrix.getNumOfRows(this.whitenedVectors);
        int n = Matrix.getNumOfColumns(this.whitenedVectors);
        // No more components than whitened dimensions can be extracted.
        int numICs = Math.min(config.getNumICs(), m);

        // The sample columns never change, so extract them once instead of on every iteration.
        double[][] samples = new double[n][];
        for (int j = 0; j < n; ++j) {
            samples[j] = Matrix.getVecOfCol(this.whitenedVectors, j);
        }

        this.weightMatrix = FastICA.symmetricDecorrelation(this.initialWeightMatrix(config.getInitialMixingMatrix(), numICs, m));
        int maxIter = config.getMaxIterations();
        switch (config.getApproach()) {
            case SYMMETRIC: {
                // Update every row with one fixed-point step, then re-decorrelate all rows jointly.
                boolean ready = false;
                for (int iter = 0; iter < maxIter && !ready; ++iter) {
                    listener.progressMade(ProgressListener.ComputationState.SYMMETRIC, 0, iter, numICs);
                    double[][] weightMatrixOld = Matrix.clone(this.weightMatrix);
                    for (int i = 0; i < numICs; ++i) {
                        double[] updated = FastICA.fixedPointStep(Matrix.getVecOfRow(this.weightMatrix, i), samples, conFunction);
                        System.arraycopy(updated, 0, this.weightMatrix[i], 0, m);
                    }
                    this.weightMatrix = FastICA.symmetricDecorrelation(this.weightMatrix);
                    ready = FastICA.deltaMatrices(this.weightMatrix, weightMatrixOld) < config.getEpsilon();
                }
                break;
            }
            case DEFLATION: {
                // Estimate rows one at a time, keeping each orthogonal to the rows already estimated.
                for (int i = 0; i < numICs; ++i) {
                    double[] w = Matrix.getVecOfRow(this.weightMatrix, i);
                    boolean ready = false;
                    for (int iter = 0; iter < maxIter && !ready; ++iter) {
                        listener.progressMade(ProgressListener.ComputationState.DEFLATION, i, iter, numICs);
                        double[] wOld = Vector.clone(w);
                        w = FastICA.fixedPointStep(wOld, samples, conFunction);
                        // Gram-Schmidt against rows 0..i-1, then normalise to unit length.
                        for (int j = 0; j < i; ++j) {
                            w = Vector.sub(w, Vector.scale(Vector.dot(w, this.weightMatrix[j]), this.weightMatrix[j]));
                        }
                        w = Vector.scale(1.0 / Math.sqrt(Vector.dot(w, w)), w);
                        System.arraycopy(w, 0, this.weightMatrix[i], 0, m);
                        ready = FastICA.deltaVectors(w, wOld) < config.getEpsilon();
                    }
                }
                break;
            }
        }
        this.mixingMatrix = Matrix.mult(this.dewhiteningMatrix, Matrix.transpose(this.weightMatrix));
        this.separatingMatrix = Matrix.mult(this.weightMatrix, this.whiteningMatrix);
        listener.progressMade(ProgressListener.ComputationState.READY, numICs, maxIter, numICs);
    }

    /**
     * Chooses the starting weight matrix.
     *
     * <p>A supplied initial mixing matrix is mapped into whitened space. This happens only when it
     * has {@code numICs} columns and as many rows as the centred input. Otherwise a random start
     * is used.
     */
    private double[][] initialWeightMatrix(double[][] initialMixing, int numICs, int m) {
        if (initialMixing != null
                && Matrix.getNumOfColumns(initialMixing) == numICs
                && Matrix.getNumOfRows(initialMixing) == Matrix.getNumOfRows(this.vectorsZeroMean)) {
            return Matrix.transpose(Matrix.mult(this.whiteningMatrix, initialMixing));
        }
        return this.makeRandomWeightMatrix(numICs, m);
    }

    /**
     * Performs one stabilised fixed-point (Newton) update of a single weight vector.
     *
     * <p>With {@code beta = mean(w.x * g(w.x))}, the update is
     * {@code w - (mean(x * g(w.x)) - beta * w) / (mean(g'(w.x)) - beta)}.
     *
     * @param w current weight vector (not modified)
     * @param samples whitened samples, one vector per sample
     * @param conFunction contrast function g and its derivative g'
     * @return the updated (not yet decorrelated or normalised) weight vector
     */
    private static double[] fixedPointStep(double[] w, double[][] samples, ContrastFunction conFunction) {
        int n = samples.length;
        double beta = 0.0;
        double egfd = 0.0;
        double[] exgf = Vector.newVector(w.length, 0.0);
        for (double[] x : samples) {
            double wx = Vector.dot(w, x);
            double g = conFunction.function(wx);
            double gPrime = conFunction.derivative(wx);
            beta += wx * g;
            egfd += gPrime;
            exgf = Vector.add(exgf, Vector.scale(g, x));
        }
        exgf = Vector.scale(1.0 / (double)n, exgf);
        egfd /= (double)n;
        beta /= (double)n;
        return Vector.sub(w, Vector.scale(1.0 / (egfd - beta), Vector.sub(exgf, Vector.scale(beta, w))));
    }

    /**
     * Returns {@code powerSymmMatrix(Matrix.square(w), -0.5) * w}.
     *
     * <p>NOTE(review): this assumes {@code Matrix.square(w)} is {@code w * w^T}, which makes the
     * result the usual symmetric orthogonalisation {@code (w w^T)^(-1/2) w}.
     */
    private static double[][] symmetricDecorrelation(double[][] w) {
        return Matrix.mult(FastICA.powerSymmMatrix(Matrix.square(w), -0.5), w);
    }

    /**
     * Creates the random starting weight matrix ({@code numICs x m}).
     *
     * <p>When {@link #fakeRandom} is set, a fixed 2x2 matrix is returned instead. That branch only
     * supports {@code numICs == 2 && m == 2}.
     *
     * @throws IllegalStateException if {@code fakeRandom} is set and the requested shape is not 2x2
     */
    private double[][] makeRandomWeightMatrix(int numICs, int m) {
        if (!this.fakeRandom) {
            return Matrix.random(numICs, m);
        }
        if (numICs != 2 || m != 2) {
            throw new IllegalStateException("fakeRandom only supports a 2x2 weight matrix, requested " + numICs + "x" + m);
        }
        return new double[][]{
            {0.068367, 0.482512},
            {0.544383, 0.343396}
        };
    }

    /**
     * Returns the mean absolute element-wise difference between two matrices.
     *
     * <p>The shape is taken from {@code mat1}. An empty matrix yields NaN (0/0).
     */
    private static double deltaMatrices(double[][] mat1, double[][] mat2) {
        double[][] test = Matrix.sub(mat1, mat2);
        double delta = 0.0;
        int m = Matrix.getNumOfRows(mat1);
        int n = Matrix.getNumOfColumns(mat1);
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                delta += Math.abs(test[i][j]);
            }
        }
        return delta / (double)(m * n);
    }

    /**
     * Returns the mean absolute element-wise difference between two vectors.
     *
     * <p>An empty vector yields NaN (0/0).
     */
    private static double deltaVectors(double[] vec1, double[] vec2) {
        double[] test = Vector.sub(vec1, vec2);
        double delta = 0.0;
        int m = vec1.length;
        for (int i = 0; i < m; ++i) {
            delta += Math.abs(test[i]);
        }
        return delta / (double)m;
    }

    /**
     * Raises a symmetric matrix to a real power via its eigendecomposition: {@code E diag(lambda^power) E^T}.
     *
     * <p>With a negative power, zero eigenvalues give Infinity and negative ones give NaN (per
     * {@link Math#pow}).
     */
    private static double[][] powerSymmMatrix(double[][] inMatrix, double power) {
        EigenValueDecompositionSymm eigenDeco = new EigenValueDecompositionSymm(inMatrix);
        int m = Matrix.getNumOfRows(inMatrix);
        double[][] eigenVectors = eigenDeco.getEigenVectors();
        double[] eigenValues = eigenDeco.getEigenValues();
        for (int i = 0; i < m; ++i) {
            eigenValues[i] = Math.pow(eigenValues[i], power);
        }
        return Matrix.mult(Matrix.mult(eigenVectors, Matrix.diag(eigenValues)), Matrix.transpose(eigenVectors));
    }

    /** Returns a new vector holding the element-wise reciprocal of {@code inVector}. */
    private static double[] invVector(double[] inVector) {
        int m = inVector.length;
        double[] outVector = new double[m];
        for (int i = 0; i < m; ++i) {
            outVector[i] = 1.0 / inVector[i];
        }
        return outVector;
    }

    /** Returns a new vector holding the element-wise square root of {@code inVector}. */
    private static double[] sqrtVector(double[] inVector) {
        int m = inVector.length;
        double[] outVector = new double[m];
        for (int i = 0; i < m; ++i) {
            outVector[i] = Math.sqrt(inVector[i]);
        }
        return outVector;
    }

    /**
     * Returns {@code separatingMatrix * inVectors}.
     *
     * <p>The result is computed on first call and cached.
     *
     * @return the cached internal array (not a copy)
     */
    public synchronized double[][] getICVectors() {
        if (this.icVectors == null) {
            this.icVectors = Matrix.mult(this.separatingMatrix, this.inVectors);
        }
        return this.icVectors;
    }

    /**
     * Returns {@code dewhiteningMatrix * weightMatrix^T}.
     *
     * @return the internal array itself (not a copy)
     */
    public double[][] getMixingMatrix() {
        return this.mixingMatrix;
    }

    /**
     * Returns {@code weightMatrix * whiteningMatrix}.
     *
     * @return the internal array itself (not a copy)
     */
    public double[][] getSeparatingMatrix() {
        return this.separatingMatrix;
    }

    /**
     * Demo: mixes a 2-channel wave file into 5 signals, separates them, and writes the ICs as a wave.
     *
     * <p>Arguments: input wave, number of independent components, output wave.
     */
    public static void main(String[] args) {
        if (args.length != 3) {
            System.out.println("Usage:");
            System.out.println("java edu.cmu.tetrad.search.fastica.FastICA [input wave] [number of independent components] [output wave]");
            System.out.println();
            return;
        }
        try {
            AudioBuffer buffer1 = new AudioBuffer(new File(args[0]));
            double[][] sourceSignals = buffer1.getData();
            // The fixed 5x2 mixing matrix below requires exactly two source channels.
            int channels = Matrix.getNumOfRows(sourceSignals);
            if (channels != 2) {
                System.err.println("The input wave must have exactly 2 channels, found " + channels + ".");
                return;
            }
            double[][] mixingMatrix = Matrix.newMatrix(5, 2);
            mixingMatrix[0][0] = 0.5;
            mixingMatrix[0][1] = 0.5;
            mixingMatrix[1][0] = 0.3;
            mixingMatrix[1][1] = 0.7;
            mixingMatrix[2][0] = 0.6;
            mixingMatrix[2][1] = 0.2;
            mixingMatrix[3][0] = 0.2;
            mixingMatrix[3][1] = 0.6;
            mixingMatrix[4][0] = 0.3;
            mixingMatrix[4][1] = 0.5;
            double[][] mixedSignal = Matrix.mult(mixingMatrix, sourceSignals);
            CompositeEVFilter filter = new CompositeEVFilter();
            filter.add(new BelowEVFilter(1.0E-8, false));
            filter.add(new SortingEVFilter(true, true));
            FastICAConfig config = new FastICAConfig(Integer.parseInt(args[1]), FastICAConfig.Approach.SYMMETRIC, 1.0, 1.0E-16, 1000, null);
            ProgressListener listener = new ProgressListener(){

                @Override
                public void progressMade(ProgressListener.ComputationState state, int component, int iteration, int maxComps) {
                    System.out.print("\r" + Integer.toString(component) + " - " + Integer.toString(iteration) + " ");
                }
            };
            System.out.println("Performing ICA");
            FastICA fica = new FastICA(mixedSignal, config, new Power3CFunction(), filter, listener);
            System.out.println();
            AudioBuffer buffer2 = new AudioBuffer(fica.getICVectors(), buffer1.getSampleRate());
            AudioInputStream stream2 = buffer2.getStream();
            try {
                AudioSystem.write(stream2, AudioFileFormat.Type.WAVE, new File(args[2]));
            } finally {
                // Release the stream even if writing fails.
                stream2.close();
            }
        }
        catch (Exception exc) {
            exc.printStackTrace(System.err);
        }
    }
}

