/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.function.DoubleFunction;
import cern.colt.list.DoubleArrayList;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.colt.matrix.linalg.SingularValueDecomposition;
import cern.jet.math.Mult;
import cern.jet.math.PlusMult;
import cern.jet.stat.Descriptive;
import edu.cmu.tetrad.util.RandomUtil;

public class FastIca {
    public static int PARALLEL = 0;
    public static int DEFLATION = 1;
    public static int LOGCOSH = 2;
    public static int EXP = 3;
    private DoubleMatrix2D X;
    private int numComponents;
    private int algorithmType = PARALLEL;
    private int function = LOGCOSH;
    private double alpha = 1.0;
    private boolean rowNorm = false;
    private int maxIterations = 200;
    private double tolerance = 1.0E-4;
    private boolean verbose = true;
    private DoubleMatrix2D wInit = null;

    public FastIca(DoubleMatrix2D X, int numComponents) {
        this.X = X;
        this.numComponents = numComponents;
    }

    public int getAlgorithmType() {
        return this.algorithmType;
    }

    public void setAlgorithmType(int algorithmType) {
        if (algorithmType != DEFLATION && algorithmType != PARALLEL) {
            throw new IllegalArgumentException("Value should be DEFLATION or PARALLEL.");
        }
        this.algorithmType = algorithmType;
    }

    public int getFunction() {
        return this.function;
    }

    public void setFunction(int function) {
        if (function != LOGCOSH && function != EXP) {
            throw new IllegalArgumentException("Value should be LOGCOSH or EXP.");
        }
        this.function = function;
    }

    public double getAlpha() {
        return this.alpha;
    }

    public void setAlpha(double alpha) {
        if (!(alpha >= 1.0) || !(alpha <= 2.0)) {
            throw new IllegalArgumentException("Alpha should be in range [1, 2].");
        }
        this.alpha = alpha;
    }

    public boolean isRowNorm() {
        return this.rowNorm;
    }

    public void setRowNorm(boolean rowNorm) {
        this.rowNorm = rowNorm;
    }

    public int getMaxIterations() {
        return this.maxIterations;
    }

    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            System.out.println("maxIterations should be positive.");
        }
        this.maxIterations = maxIterations;
    }

    public double getTolerance() {
        return this.tolerance;
    }

    public void setTolerance(double tolerance) {
        if (!(tolerance > 0.0)) {
            System.out.println("Tolerance should be positive.");
        }
        this.tolerance = tolerance;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public DoubleMatrix2D getWInit() {
        return this.wInit;
    }

    public void setWInit(DoubleMatrix2D wInit) {
        this.wInit = wInit;
    }

    public IcaResult findComponents() {
        DoubleMatrix2D b;
        int p;
        Algebra alg = new Algebra();
        int n = this.X.rows();
        if (this.numComponents > Math.min(n, p = this.X.columns())) {
            System.out.println("numComponents is too large.");
            System.out.println("numComponents set to " + Math.min(n, p));
            this.numComponents = Math.min(n, p);
        }
        if (this.wInit == null) {
            this.wInit = new DenseDoubleMatrix2D(this.numComponents, this.numComponents);
            for (int i = 0; i < this.wInit.rows(); ++i) {
                for (int j = 0; j < this.wInit.columns(); ++j) {
                    this.wInit.set(i, j, RandomUtil.getInstance().nextNormal(0.0, 1.0));
                }
            }
        } else if (this.wInit.rows() != this.wInit.columns()) {
            throw new IllegalArgumentException("wInit is the wrong size.");
        }
        if (this.verbose) {
            System.out.println("Centering");
        }
        this.X = this.scale(this.X, false);
        this.X = this.rowNorm ? this.scale(this.X, true).viewDice() : this.X.viewDice();
        if (this.verbose) {
            System.out.println("Whitening");
        }
        DoubleMatrix2D V = alg.mult(this.X, this.X.viewDice());
        V.assign(Mult.div(n));
        SingularValueDecomposition s = new SingularValueDecomposition(V);
        DoubleMatrix2D D = s.getS();
        for (int i = 0; i < D.rows(); ++i) {
            D.set(i, i, 1.0 / Math.sqrt(D.get(i, i)));
        }
        DoubleMatrix2D K = alg.mult(D, s.getU().viewDice());
        K = K.assign(Mult.mult(-1.0));
        K = K.viewPart(0, 0, this.numComponents, p);
        DoubleMatrix2D X1 = alg.mult(K, this.X);
        if (this.algorithmType == DEFLATION) {
            b = this.icaDeflation(X1, this.numComponents, this.tolerance, this.function, this.alpha, this.maxIterations, this.verbose, this.wInit);
        } else if (this.algorithmType == PARALLEL) {
            b = this.icaParallel(X1, this.numComponents, this.tolerance, this.function, this.alpha, this.maxIterations, this.verbose, this.wInit);
        } else {
            throw new IllegalStateException();
        }
        DoubleMatrix2D w = alg.mult(b, K);
        DoubleMatrix2D S = alg.mult(w, this.X);
        DoubleMatrix2D A = alg.mult(w.viewDice(), alg.inverse(alg.mult(w, w.viewDice())));
        return new IcaResult(this.X.viewDice(), K.viewDice(), b.viewDice(), A.viewDice(), S.viewDice());
    }

    private DoubleMatrix2D icaDeflation(DoubleMatrix2D X, int numComponents, double tolerance, int function, double alpha, int maxIterations, boolean verbose, DoubleMatrix2D wInit) {
        if (verbose && function == LOGCOSH) {
            System.out.println("Deflation FastIca using lgcosh approx. to neg-entropy function");
        }
        if (verbose && function == EXP) {
            System.out.println("Deflation FastIca using exponential approx. to neg-entropy function");
        }
        int p = X.columns();
        DenseDoubleMatrix2D W = new DenseDoubleMatrix2D(numComponents, numComponents);
        for (int i = 0; i < numComponents; ++i) {
            int k;
            double _rms;
            int j;
            int j2;
            int u;
            DoubleMatrix1D w1;
            double meanGwx;
            DoubleMatrix1D v2;
            DenseDoubleMatrix1D g_wx;
            DenseDoubleMatrix1D v1;
            DenseDoubleMatrix2D xgwx;
            DenseDoubleMatrix2D gwx;
            DenseDoubleMatrix1D gwx0;
            DoubleMatrix1D wx;
            if (verbose) {
                System.out.println("Component " + (i + 1));
            }
            DoubleMatrix1D w = wInit.viewRow(i).copy();
            if (i > 0) {
                DoubleMatrix1D t = w.copy().assign(0.0);
                for (int u2 = 0; u2 < i; ++u2) {
                    int j3;
                    double k2 = 0.0;
                    for (j3 = 0; j3 < numComponents; ++j3) {
                        k2 += w.get(j3) * W.get(u2, j3);
                    }
                    for (j3 = 0; j3 < numComponents; ++j3) {
                        t.set(j3, t.get(j3) + k2 * W.get(u2, j3));
                    }
                }
                for (int j4 = 0; j4 < numComponents; ++j4) {
                    w.set(j4, w.get(j4) - t.get(j4));
                }
            }
            double rms = this.rms(w);
            for (int j5 = 0; j5 < numComponents; ++j5) {
                w.set(j5, w.get(j5) / rms);
            }
            int it = 0;
            double _tolerance = Double.POSITIVE_INFINITY;
            if (function == LOGCOSH) {
                while (_tolerance > tolerance && ++it <= maxIterations) {
                    wx = new Algebra().mult(X.viewDice(), w);
                    gwx0 = new DenseDoubleMatrix1D(p);
                    for (int j6 = 0; j6 < p; ++j6) {
                        gwx0.set(j6, Math.tanh(alpha * wx.get(j6)));
                    }
                    gwx = new DenseDoubleMatrix2D(numComponents, p);
                    for (int _i = 0; _i < numComponents; ++_i) {
                        for (int j7 = 0; j7 < p; ++j7) {
                            gwx.set(_i, j7, gwx0.get(j7));
                        }
                    }
                    xgwx = new DenseDoubleMatrix2D(numComponents, p);
                    for (int _i = 0; _i < numComponents; ++_i) {
                        for (int j8 = 0; j8 < p; ++j8) {
                            xgwx.set(_i, j8, X.get(_i, j8) * gwx0.get(j8));
                        }
                    }
                    v1 = new DenseDoubleMatrix1D(numComponents);
                    for (int k3 = 0; k3 < numComponents; ++k3) {
                        v1.set(k3, this.mean(xgwx.viewRow(k3)));
                    }
                    g_wx = new DenseDoubleMatrix1D(p);
                    for (int k4 = 0; k4 < p; ++k4) {
                        double tmp1 = Math.tanh(alpha * wx.get(k4));
                        g_wx.set(k4, alpha * (1.0 - tmp1 * tmp1));
                    }
                    v2 = w.copy();
                    meanGwx = this.mean(g_wx);
                    v2.assign(Mult.mult(meanGwx));
                    w1 = v1.copy();
                    w1.assign(v2, PlusMult.plusMult(-1.0));
                    if (i > 0) {
                        DoubleMatrix1D t = w1.copy().assign(0.0);
                        for (u = 0; u < i; ++u) {
                            double k5 = 0.0;
                            for (j2 = 0; j2 < numComponents; ++j2) {
                                k5 += w1.get(j2) * W.get(u, j2);
                            }
                            for (j2 = 0; j2 < numComponents; ++j2) {
                                t.set(j2, t.get(j2) + k5 * W.get(u, j2));
                            }
                        }
                        for (j = 0; j < numComponents; ++j) {
                            w1.set(j, w1.get(j) - t.get(j));
                        }
                    }
                    _rms = this.rms(w1);
                    for (k = 0; k < numComponents; ++k) {
                        w1.set(k, w1.get(k) / _rms);
                    }
                    _tolerance = 0.0;
                    for (k = 0; k < numComponents; ++k) {
                        _tolerance += w1.get(k) * w.get(k);
                    }
                    _tolerance = Math.abs(Math.abs(_tolerance) - 1.0);
                    if (verbose) {
                        System.out.println("Iteration " + it + " tol = " + _tolerance);
                    }
                    w = w1;
                }
            } else if (function == EXP) {
                while (_tolerance > tolerance && ++it <= maxIterations) {
                    wx = new Algebra().mult(X.viewDice(), w);
                    gwx0 = new DenseDoubleMatrix1D(p);
                    for (int j9 = 0; j9 < p; ++j9) {
                        gwx0.set(j9, wx.get(j9) * Math.exp(-(wx.get(j9) * wx.get(j9)) / 2.0));
                    }
                    gwx = new DenseDoubleMatrix2D(numComponents, p);
                    for (int _i = 0; _i < numComponents; ++_i) {
                        for (int j10 = 0; j10 < p; ++j10) {
                            gwx.set(_i, j10, gwx0.get(j10));
                        }
                    }
                    xgwx = new DenseDoubleMatrix2D(numComponents, p);
                    for (int _i = 0; _i < numComponents; ++_i) {
                        for (int j11 = 0; j11 < p; ++j11) {
                            xgwx.set(_i, j11, X.get(_i, j11) * gwx0.get(j11));
                        }
                    }
                    v1 = new DenseDoubleMatrix1D(numComponents);
                    for (int k6 = 0; k6 < numComponents; ++k6) {
                        v1.set(k6, this.mean(xgwx.viewRow(k6)));
                    }
                    g_wx = new DenseDoubleMatrix1D(p);
                    for (int j12 = 0; j12 < p; ++j12) {
                        g_wx.set(j12, (1.0 - wx.get(j12) * wx.get(j12)) * Math.exp(-(wx.get(j12) * wx.get(j12)) / 2.0));
                    }
                    v2 = w.copy();
                    meanGwx = this.mean(g_wx);
                    v2.assign(Mult.mult(meanGwx));
                    w1 = v1.copy();
                    w1.assign(v2, PlusMult.plusMult(-1.0));
                    if (i > 0) {
                        DoubleMatrix1D t = w1.copy().assign(0.0);
                        for (u = 0; u < i; ++u) {
                            double k7 = 0.0;
                            for (j2 = 0; j2 < numComponents; ++j2) {
                                k7 += w1.get(j2) * W.get(u, j2);
                            }
                            for (j2 = 0; j2 < numComponents; ++j2) {
                                t.set(j2, t.get(j2) + k7 * W.get(u, j2));
                            }
                        }
                        for (j = 0; j < numComponents; ++j) {
                            w1.set(j, w1.get(j) - t.get(j));
                        }
                    }
                    _rms = this.rms(w1);
                    for (k = 0; k < numComponents; ++k) {
                        w1.set(k, w1.get(k) / _rms);
                    }
                    _tolerance = 0.0;
                    for (k = 0; k < numComponents; ++k) {
                        _tolerance += w1.get(k) * w.get(k);
                    }
                    _tolerance = Math.abs(Math.abs(_tolerance) - 1.0);
                    if (verbose) {
                        System.out.println("Iteration " + it + " tol = " + _tolerance);
                    }
                    w = w1;
                }
            }
            W.viewRow(i).assign(w);
        }
        return W;
    }

    private double mean(DoubleMatrix1D v) {
        return Descriptive.mean(new DoubleArrayList(v.toArray()));
    }

    private double rms(DoubleMatrix1D w) {
        double ssq = Descriptive.sumOfSquares(new DoubleArrayList(w.toArray()));
        return Math.sqrt(ssq);
    }

    private DoubleMatrix2D icaParallel(DoubleMatrix2D X, int numComponents, double tolerance, int function, final double alpha, int maxIterations, boolean verbose, DoubleMatrix2D wInit) {
        DoubleMatrix2D W;
        block18: {
            int it;
            double _tolerance;
            int p;
            Algebra alg;
            block17: {
                alg = new Algebra();
                p = X.columns();
                W = wInit;
                SingularValueDecomposition sW = new SingularValueDecomposition(W);
                DoubleMatrix2D D = sW.getS();
                for (int i = 0; i < D.rows(); ++i) {
                    D.set(i, i, 1.0 / D.get(i, i));
                }
                DoubleMatrix2D WTemp = alg.mult(sW.getU(), D);
                WTemp = alg.mult(WTemp, sW.getU().viewDice());
                W = WTemp = alg.mult(WTemp, W);
                _tolerance = Double.POSITIVE_INFINITY;
                it = 0;
                if (function != LOGCOSH) break block17;
                if (verbose) {
                    System.out.println("Symmetric FastICA using logcosh approx. to neg-entropy function");
                }
                while (_tolerance > tolerance && it < maxIterations) {
                    DoubleMatrix2D wx = alg.mult(W, X);
                    DenseDoubleMatrix2D gwx = new DenseDoubleMatrix2D(numComponents, p);
                    for (int i = 0; i < numComponents; ++i) {
                        for (int j = 0; j < p; ++j) {
                            gwx.set(i, j, Math.tanh(alpha * wx.get(i, j)));
                        }
                    }
                    DoubleMatrix2D v1 = alg.mult((DoubleMatrix2D)gwx, X.viewDice().copy().assign(Mult.div(p)));
                    DoubleMatrix2D g_wx = gwx.copy();
                    g_wx.assign(new DoubleFunction(){

                        @Override
                        public double apply(double v) {
                            return alpha * (1.0 - v * v);
                        }
                    });
                    DenseDoubleMatrix1D V20 = new DenseDoubleMatrix1D(numComponents);
                    for (int k = 0; k < numComponents; ++k) {
                        V20.set(k, this.mean(g_wx.viewRow(k)));
                    }
                    DoubleMatrix2D v2 = DoubleFactory2D.dense.diagonal(V20);
                    v2 = alg.mult(v2, W);
                    DoubleMatrix2D W1 = v1.copy().assign(v2, PlusMult.plusMult(-1.0));
                    SingularValueDecomposition sW1 = new SingularValueDecomposition(W1);
                    DoubleMatrix2D U = sW1.getU();
                    DoubleMatrix2D sD = sW1.getS();
                    for (int i = 0; i < sD.rows(); ++i) {
                        sD.set(i, i, 1.0 / sD.get(i, i));
                    }
                    DoubleMatrix2D W1Temp = alg.mult(U, sD);
                    W1Temp = alg.mult(W1Temp, U.viewDice());
                    W1 = W1Temp = alg.mult(W1Temp, W1);
                    DoubleMatrix2D d1 = alg.mult(W1, W.viewDice());
                    DoubleMatrix1D d = DoubleFactory2D.dense.diagonal(d1);
                    _tolerance = Double.NEGATIVE_INFINITY;
                    for (int i = 0; i < d.size(); ++i) {
                        double m = Math.abs(Math.abs(d.get(i)) - 1.0);
                        if (!(m > _tolerance)) continue;
                        _tolerance = m;
                    }
                    W = W1;
                    if (verbose) {
                        System.out.println("Iteration " + (it + 1) + " tol = " + _tolerance);
                    }
                    ++it;
                }
                break block18;
            }
            if (function != EXP) break block18;
            if (verbose) {
                System.out.println("Symmetric FastICA using exponential approx. to neg-entropy function");
            }
            while (_tolerance > tolerance && it < maxIterations) {
                DoubleMatrix2D wx = alg.mult(W, X);
                DenseDoubleMatrix2D gwx = new DenseDoubleMatrix2D(numComponents, p);
                for (int i = 0; i < numComponents; ++i) {
                    for (int j = 0; j < p; ++j) {
                        double v = wx.get(i, j);
                        gwx.set(i, j, v * Math.exp(-(v * v) / 2.0));
                    }
                }
                DoubleMatrix2D v1 = alg.mult((DoubleMatrix2D)gwx, X.viewDice().copy().assign(Mult.div(p)));
                DoubleMatrix2D g_wx = wx.copy();
                g_wx.assign(new DoubleFunction(){

                    @Override
                    public double apply(double v) {
                        return (1.0 - v * v) * Math.exp(-(v * v) / 2.0);
                    }
                });
                DenseDoubleMatrix1D V20 = new DenseDoubleMatrix1D(numComponents);
                for (int k = 0; k < numComponents; ++k) {
                    V20.set(k, this.mean(g_wx.viewRow(k)));
                }
                DoubleMatrix2D v2 = DoubleFactory2D.dense.diagonal(V20);
                v2 = alg.mult(v2, W);
                DoubleMatrix2D W1 = v1.copy().assign(v2, PlusMult.plusMult(-1.0));
                SingularValueDecomposition sW1 = new SingularValueDecomposition(W1);
                DoubleMatrix2D U = sW1.getU();
                DoubleMatrix2D sD = sW1.getS();
                for (int i = 0; i < sD.rows(); ++i) {
                    sD.set(i, i, 1.0 / sD.get(i, i));
                }
                DoubleMatrix2D W1Temp = alg.mult(U, sD);
                W1Temp = alg.mult(W1Temp, U.viewDice());
                W1 = W1Temp = alg.mult(W1Temp, W1);
                DoubleMatrix2D d1 = alg.mult(W1, W.viewDice());
                DoubleMatrix1D d = DoubleFactory2D.dense.diagonal(d1);
                _tolerance = Double.NEGATIVE_INFINITY;
                for (int i = 0; i < d.size(); ++i) {
                    double m = Math.abs(Math.abs(d.get(i)) - 1.0);
                    if (!(m > _tolerance)) continue;
                    _tolerance = m;
                }
                W.assign(W1);
                if (verbose) {
                    System.out.println("Iteration " + (it + 1) + " tol = " + _tolerance);
                }
                ++it;
            }
        }
        return W;
    }

    private DoubleMatrix2D scale(DoubleMatrix2D x, boolean scale) {
        for (int j = 0; j < x.columns(); ++j) {
            DoubleArrayList u = new DoubleArrayList(x.viewColumn(j).toArray());
            double mean = Descriptive.mean(u);
            for (int i = 0; i < x.rows(); ++i) {
                x.set(i, j, x.get(i, j) - mean);
            }
            if (!scale) continue;
            double rms = this.rms(x.viewColumn(j));
            for (int i = 0; i < x.rows(); ++i) {
                x.set(i, j, x.get(i, j) / rms);
            }
        }
        return x;
    }

    public static class IcaResult {
        private final DoubleMatrix2D X;
        private final DoubleMatrix2D K;
        private final DoubleMatrix2D W;
        private final DoubleMatrix2D S;
        private final DoubleMatrix2D A;

        public IcaResult(DoubleMatrix2D X, DoubleMatrix2D K, DoubleMatrix2D W, DoubleMatrix2D A, DoubleMatrix2D S) {
            this.X = X;
            this.K = K;
            this.W = W;
            this.A = A;
            this.S = S;
        }

        public DoubleMatrix2D getX() {
            return this.X;
        }

        public DoubleMatrix2D getK() {
            return this.K;
        }

        public DoubleMatrix2D getW() {
            return this.W;
        }

        public DoubleMatrix2D getS() {
            return this.S;
        }

        public DoubleMatrix2D getA() {
            return this.A;
        }

        public String toString() {
            StringBuilder buf = new StringBuilder();
            buf.append("\n\nX:\n");
            buf.append(this.X);
            buf.append("\n\nK:\n");
            buf.append(this.K);
            buf.append("\n\nW:\n");
            buf.append(this.W);
            buf.append("\n\nA:\n");
            buf.append(this.A);
            buf.append("\n\nS:\n");
            buf.append(this.S);
            return buf.toString();
        }
    }
}

