/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.Vector;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.util.FastMath;

/**
 * Independent component analysis by the FastICA fixed-point algorithm.
 *
 * <p>The data matrix is laid out with variables in rows and samples in columns: each row is
 * centered by its own mean, and the covariance used for whitening is X X' / (#columns).
 * {@link #findComponents()} modifies the supplied matrix in place (centering and, optionally,
 * row scaling).
 */
public class FastIca {

    /** Algorithm type: estimate all components simultaneously, using symmetric decorrelation. */
    public static final int PARALLEL = 0;

    /** Algorithm type: estimate components one at a time, orthogonalizing against earlier ones. */
    public static final int DEFLATION = 1;

    /** Nonlinearity g(u) = tanh(alpha * u) (log-cosh approximation to neg-entropy). */
    public static final int LOGCOSH = 2;

    /** Nonlinearity g(u) = u * exp(-u^2 / 2) (exponential approximation to neg-entropy). */
    public static final int EXP = 3;

    /** Data matrix, variables in rows and samples in columns; centered/scaled in place. */
    private final Matrix X;

    /** Number of components to extract; may be reduced by {@link #findComponents()}. */
    private int numComponents;

    /** {@link #PARALLEL} or {@link #DEFLATION}. */
    private int algorithmType = PARALLEL;

    /** {@link #LOGCOSH} or {@link #EXP}. */
    private int function = LOGCOSH;

    /** Scale factor inside tanh for {@link #LOGCOSH} (ignored by {@link #EXP}); in [1, 2]. */
    private double alpha = 1.1;

    /** Whether each row of X is scaled to unit Euclidean norm after centering. */
    private boolean rowNorm;

    /** Maximum number of fixed-point iterations (per component in deflation mode). */
    private int maxIterations = 200;

    /** Convergence threshold on | |w_new . w_old| - 1 | (maximum over components in parallel mode). */
    private double tolerance = 1.0E-4;

    /** Whether progress messages are logged. */
    private boolean verbose;

    /** Starting unmixing matrix (numComponents x numComponents); drawn at random if null. */
    private Matrix wInit;

    /**
     * Creates a FastICA runner.
     *
     * @param X             the data, variables in rows and samples in columns; this matrix is
     *                      centered (and optionally row-scaled) in place by {@link #findComponents()}
     * @param numComponents the number of components to extract; reduced to
     *                      min(#rows, #columns) by {@link #findComponents()} if larger
     */
    public FastIca(Matrix X, int numComponents) {
        this.X = X;
        this.numComponents = numComponents;
    }

    /**
     * Selects how components are estimated.
     *
     * @param algorithmType {@link #DEFLATION} or {@link #PARALLEL}
     * @throws IllegalArgumentException for any other value
     */
    public void setAlgorithmType(int algorithmType) {
        if (algorithmType != DEFLATION && algorithmType != PARALLEL) {
            throw new IllegalArgumentException("Value should be DEFLATION or PARALLEL.");
        }
        this.algorithmType = algorithmType;
    }

    /**
     * Selects the nonlinearity used in the fixed-point update.
     *
     * @param function {@link #LOGCOSH} or {@link #EXP}
     * @throws IllegalArgumentException for any other value
     */
    public void setFunction(int function) {
        if (function != LOGCOSH && function != EXP) {
            throw new IllegalArgumentException("Value should be LOGCOSH or EXP.");
        }
        this.function = function;
    }

    /**
     * Sets the tanh scale factor used by {@link #LOGCOSH}.
     *
     * @param alpha a value in [1, 2]
     * @throws IllegalArgumentException if alpha is outside [1, 2] or NaN
     */
    public void setAlpha(double alpha) {
        if (!(alpha >= 1.0 && alpha <= 2.0)) {
            throw new IllegalArgumentException("Alpha should be in range [1, 2].");
        }
        this.alpha = alpha;
    }

    /**
     * @param rowNorm whether each (centered) row of X is scaled to unit Euclidean norm
     */
    public void setRowNorm(boolean rowNorm) {
        this.rowNorm = rowNorm;
    }

    /**
     * Sets the iteration cap. Non-positive values are logged as a warning but still stored.
     *
     * @param maxIterations the maximum number of fixed-point iterations
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            TetradLogger.getInstance().log("info", "maxIterations should be positive.");
        }
        this.maxIterations = maxIterations;
    }

    /**
     * Sets the convergence threshold. Non-positive values are logged as a warning but still stored.
     *
     * @param tolerance the convergence threshold
     */
    public void setTolerance(double tolerance) {
        if (!(tolerance > 0.0)) {
            TetradLogger.getInstance().log("info", "Tolerance should be positive.");
        }
        this.tolerance = tolerance;
    }

    /**
     * @param verbose whether progress messages are logged
     */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * @param wInit the starting unmixing matrix; must be numComponents x numComponents (checked
     *              when {@link #findComponents()} runs), or null for a random start
     */
    public void setWInit(Matrix wInit) {
        this.wInit = wInit;
    }

    /**
     * Runs FastICA on the data supplied to the constructor.
     *
     * <p>Steps: clamp the number of components to min(#rows, #columns); draw a standard-normal
     * starting matrix if none was set (it is stored and reused by later calls); center each row
     * of X in place (and scale rows to unit norm if enabled); whiten with K = D^{-1/2} U',
     * restricted to its first numComponents rows, where U D U' is the SVD of X X' / n; then
     * estimate the unmixing matrix b of the whitened data by deflation or parallel FastICA.
     *
     * @return the result holding X, K, the overall unmixing matrix W = b K and sources S = W X
     * @throws IllegalArgumentException if a supplied starting matrix is not
     *                                  numComponents x numComponents
     */
    public IcaResult findComponents() {
        final int p = this.X.getNumRows();     // number of variables
        final int n = this.X.getNumColumns();  // number of samples
        final int maxComponents = FastMath.min(n, p);

        if (this.numComponents > maxComponents) {
            TetradLogger.getInstance().log("info", "Requested number of components is too large.");
            TetradLogger.getInstance().log("info", "Reset to " + maxComponents);
            this.numComponents = maxComponents;
        }

        if (this.wInit == null) {
            this.wInit = new Matrix(this.numComponents, this.numComponents);
            for (int i = 0; i < this.wInit.getNumRows(); ++i) {
                for (int j = 0; j < this.wInit.getNumColumns(); ++j) {
                    this.wInit.set(i, j, RandomUtil.getInstance().nextNormal(0.0, 1.0));
                }
            }
        } else if (this.wInit.getNumRows() != this.numComponents
                || this.wInit.getNumColumns() != this.numComponents) {
            // Both ICA variants need a numComponents x numComponents start; fail fast here
            // instead of with a dimension mismatch deep inside the iteration.
            throw new IllegalArgumentException("wInit is the wrong size.");
        }

        if (this.verbose) {
            TetradLogger.getInstance().log("info", "Centering");
        }
        center(this.X);
        if (this.rowNorm) {
            scale(this.X);
        }

        if (this.verbose) {
            TetradLogger.getInstance().log("info", "Whitening");
        }
        Matrix cov = this.X.times(this.X.transpose()).scalarMult(1.0 / (double) n);
        SingularValueDecomposition s = new SingularValueDecomposition(cov.getApacheData());
        Matrix D = new Matrix(s.getS().getData());
        Matrix U = new Matrix(s.getU().getData());
        for (int i = 0; i < D.getNumRows(); ++i) {
            D.set(i, i, 1.0 / FastMath.sqrt(D.get(i, i)));
        }
        Matrix K = D.times(U.transpose());
        K = K.getPart(0, this.numComponents - 1, 0, p - 1);
        Matrix X1 = K.times(this.X);

        Matrix b;
        if (this.algorithmType == DEFLATION) {
            b = icaDeflation(X1, this.tolerance, this.function, this.alpha, this.maxIterations,
                    this.verbose, this.wInit);
        } else if (this.algorithmType == PARALLEL) {
            b = icaParallel(X1, this.numComponents, this.tolerance, this.alpha, this.maxIterations,
                    this.verbose, this.wInit);
        } else {
            throw new IllegalStateException("Unknown algorithm type: " + this.algorithmType);
        }

        Matrix w = b.times(K);
        Matrix S = w.times(this.X);
        return new IcaResult(this.X, K, w, S);
    }

    /**
     * Deflation FastICA: estimates the rows of the unmixing matrix one at a time, keeping each
     * new row orthogonal to those already found.
     *
     * @param X             whitened data (components in rows, samples in columns)
     * @param tolerance     convergence threshold on | |w_new . w_old| - 1 |
     * @param function      {@link #LOGCOSH} or {@link #EXP} (used for the verbose message)
     * @param alpha         tanh scale factor for {@link #LOGCOSH}
     * @param maxIterations iteration cap per component
     * @param verbose       whether to log progress
     * @param wInit         starting rows
     * @return the unmixing matrix for the whitened data
     */
    private Matrix icaDeflation(Matrix X, double tolerance, int function, double alpha,
                                int maxIterations, boolean verbose, Matrix wInit) {
        if (verbose && function == LOGCOSH) {
            TetradLogger.getInstance().log("info", "Deflation FastIca using lgcosh approx. to neg-entropy function");
        }
        if (verbose && function == EXP) {
            TetradLogger.getInstance().log("info", "Deflation FastIca using exponential approx. to neg-entropy function");
        }

        final int m = X.getNumRows();     // number of components
        final int n = X.getNumColumns();  // number of samples
        final Matrix Xt = X.transpose();  // loop-invariant
        Matrix W = new Matrix(m, m);

        for (int i = 0; i < m; ++i) {
            if (verbose) {
                TetradLogger.getInstance().log("fastIcaDetails", "Component " + (i + 1));
            }

            // Start from the i-th initial row, orthogonalized against the rows already found.
            Vector w = wInit.getRow(i);
            for (int u = 0; u < i; ++u) {
                double k = w.dotProduct(W.getRow(u));
                w = w.minus(W.getRow(u).scalarMult(k));
            }
            w = w.scalarMult(1.0 / norm(w));

            int it = 0;
            double _tolerance = Double.POSITIVE_INFINITY;
            while (_tolerance > tolerance && ++it <= maxIterations) {
                Vector wx = Xt.times(w);

                // Fixed-point update: w1 = E[x g(w'x)] - E[g'(w'x)] w.
                Vector gwx = new Vector(n);
                Vector gPrimeWx = new Vector(n);
                for (int j = 0; j < n; ++j) {
                    gwx.set(j, g(alpha, wx.get(j)));
                    gPrimeWx.set(j, gPrime(alpha, wx.get(j)));
                }
                Vector v1 = new Vector(m);
                for (int k = 0; k < m; ++k) {
                    double sum = 0.0;
                    for (int j = 0; j < n; ++j) {
                        sum += X.get(k, j) * gwx.get(j);
                    }
                    v1.set(k, sum / (double) n);
                }
                Vector w1 = v1.minus(w.scalarMult(mean(gPrimeWx)));

                // Remove the projections onto previously estimated rows (computed from the
                // un-deflated w1, then subtracted all at once).
                Vector t = w1.like();
                for (int u = 0; u < i; ++u) {
                    double k = 0.0;
                    for (int j = 0; j < m; ++j) {
                        k += w1.get(j) * W.get(u, j);
                    }
                    for (int j = 0; j < m; ++j) {
                        t.set(j, t.get(j) + k * W.get(u, j));
                    }
                }
                for (int j = 0; j < m; ++j) {
                    w1.set(j, w1.get(j) - t.get(j));
                }

                w1 = w1.scalarMult(1.0 / norm(w1));

                // Converged when the new and old unit vectors are (anti)parallel.
                double dot = 0.0;
                for (int k = 0; k < m; ++k) {
                    dot += w1.get(k) * w.get(k);
                }
                _tolerance = FastMath.abs(FastMath.abs(dot) - 1.0);

                if (verbose) {
                    TetradLogger.getInstance().log("fastIcaDetails", "Iteration " + it + " tol = " + _tolerance);
                }
                w = w1;
            }
            W.assignRow(i, w);
        }
        return W;
    }

    /**
     * The nonlinearity g selected by {@link #setFunction(int)}.
     *
     * @param alpha tanh scale factor (ignored by {@link #EXP})
     * @param y     the argument
     * @return tanh(alpha y) for LOGCOSH, y exp(-y^2/2) for EXP
     */
    private double g(double alpha, double y) {
        if (this.function == LOGCOSH) {
            return FastMath.tanh(alpha * y);
        }
        if (this.function == EXP) {
            return y * FastMath.exp(-(y * y) / 2.0);
        }
        throw new IllegalArgumentException("That function is not configured.");
    }

    /**
     * The derivative g' of {@link #g(double, double)}, needed by the fixed-point update.
     *
     * @param alpha tanh scale factor (ignored by {@link #EXP})
     * @param y     the argument
     * @return alpha (1 - tanh^2(alpha y)) for LOGCOSH, (1 - y^2) exp(-y^2/2) for EXP
     */
    private double gPrime(double alpha, double y) {
        if (this.function == LOGCOSH) {
            double t = FastMath.tanh(alpha * y);
            return alpha * (1.0 - t * t);
        }
        if (this.function == EXP) {
            double y2 = y * y;
            return (1.0 - y2) * FastMath.exp(-y2 / 2.0);
        }
        throw new IllegalArgumentException("That function is not configured.");
    }

    /** Arithmetic mean of the entries of v. */
    private double mean(Vector v) {
        double sum = 0.0;
        for (int i = 0; i < v.size(); ++i) {
            sum += v.get(i);
        }
        return sum / (double) v.size();
    }

    /** Euclidean (L2) norm of v. */
    private double norm(Vector v) {
        double ssq = 0.0;
        for (int i = 0; i < v.size(); ++i) {
            ssq += v.get(i) * v.get(i);
        }
        return FastMath.sqrt(ssq);
    }

    /**
     * Symmetric decorrelation W -> (W W')^{-1/2} W, computed as U D^{-1} U' W from the SVD
     * W = U D V'.
     */
    private Matrix symmetricDecorrelation(Matrix W) {
        SingularValueDecomposition svd = new SingularValueDecomposition(W.getApacheData());
        Matrix U = new Matrix(svd.getU());
        Matrix dInv = new Matrix(svd.getS());
        for (int i = 0; i < dInv.getNumRows(); ++i) {
            dInv.set(i, i, 1.0 / dInv.get(i, i));
        }
        return U.times(dInv).times(U.transpose()).times(W);
    }

    /**
     * Parallel (symmetric) FastICA: updates all rows of the unmixing matrix at once and
     * re-decorrelates them symmetrically after each step.
     *
     * @param X             whitened data (components in rows, samples in columns)
     * @param numComponents number of rows of X
     * @param tolerance     convergence threshold on max_i | |diag(W_new W_old')_i| - 1 |
     * @param alpha         tanh scale factor for {@link #LOGCOSH}
     * @param maxIterations iteration cap
     * @param verbose       whether to log progress
     * @param wInit         starting matrix
     * @return the unmixing matrix for the whitened data
     */
    private Matrix icaParallel(Matrix X, int numComponents, double tolerance, double alpha,
                               int maxIterations, boolean verbose, Matrix wInit) {
        final int p = X.getNumColumns();  // number of samples
        final Matrix XtOverP = X.transpose().scalarMult(1.0 / (double) p);  // loop-invariant

        Matrix W = symmetricDecorrelation(wInit);
        double _tolerance = Double.POSITIVE_INFINITY;
        int it = 0;

        if (verbose) {
            if (this.function == EXP) {
                TetradLogger.getInstance().log("info", "Symmetric FastICA using exponential approx. to neg-entropy function");
            } else {
                TetradLogger.getInstance().log("info", "Symmetric FastICA using logcosh approx. to neg-entropy function");
            }
        }

        while (_tolerance > tolerance && it < maxIterations) {
            Matrix wx = W.times(X);

            // gwx = g(W X); meanGPrime_i = mean over samples of g'((W X)_ij).
            Matrix gwx = new Matrix(numComponents, p);
            Vector meanGPrime = new Vector(numComponents);
            for (int i = 0; i < numComponents; ++i) {
                double sum = 0.0;
                for (int j = 0; j < p; ++j) {
                    double y = wx.get(i, j);
                    gwx.set(i, j, g(alpha, y));
                    sum += gPrime(alpha, y);
                }
                meanGPrime.set(i, sum / (double) p);
            }

            // Fixed-point update W1 = E[g(WX) X'] - diag(E[g'(WX)]) W, then decorrelate.
            Matrix v1 = gwx.times(XtOverP);
            Matrix v2 = meanGPrime.diag().times(W);
            Matrix W1 = symmetricDecorrelation(v1.minus(v2));

            Vector d = W1.times(W.transpose()).diag();
            _tolerance = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < d.size(); ++i) {
                _tolerance = FastMath.max(_tolerance, FastMath.abs(FastMath.abs(d.get(i)) - 1.0));
            }

            W = W1;
            if (verbose) {
                TetradLogger.getInstance().log("fastIcaDetails", "Iteration " + (it + 1) + " tol = " + _tolerance);
            }
            ++it;
        }
        return W;
    }

    /** Scales each row of x (in place) to unit Euclidean norm. */
    private void scale(Matrix x) {
        for (int i = 0; i < x.getNumRows(); ++i) {
            Vector row = x.getRow(i);
            x.assignRow(i, row.scalarMult(1.0 / norm(row)));
        }
    }

    /** Subtracts each row's mean from that row of x (in place). */
    private void center(Matrix x) {
        for (int i = 0; i < x.getNumRows(); ++i) {
            double rowMean = mean(x.getRow(i));
            for (int j = 0; j < x.getNumColumns(); ++j) {
                x.set(i, j, x.get(i, j) - rowMean);
            }
        }
    }

    /**
     * Output of {@link FastIca#findComponents()}.
     */
    public static class IcaResult {
        /** The data after centering (and optional row scaling). */
        private final Matrix X;
        /** Pre-whitening matrix (numComponents x #variables). */
        private final Matrix K;
        /** Overall unmixing matrix, b K, mapping X to S. */
        private final Matrix W;
        /** Estimated source signals, W X (components in rows). */
        private final Matrix S;

        /**
         * @param X the centered (possibly row-scaled) data
         * @param K the pre-whitening matrix
         * @param W the overall unmixing matrix
         * @param S the estimated sources
         */
        public IcaResult(Matrix X, Matrix K, Matrix W, Matrix S) {
            this.X = X;
            this.K = K;
            this.W = W;
            this.S = S;
        }

        /** @return the centered (possibly row-scaled) data */
        public Matrix getX() {
            return this.X;
        }

        /** @return the pre-whitening matrix */
        public Matrix getK() {
            return this.K;
        }

        /** @return the overall unmixing matrix */
        public Matrix getW() {
            return this.W;
        }

        /** @return the estimated sources */
        public Matrix getS() {
            return this.S;
        }

        @Override
        public String toString() {
            return "\n\nX:\n" + this.X + "\n\nK:\n" + this.K + "\n\nW:\n" + this.W + "\n\nS:\n" + this.S;
        }
    }
}

