/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.work_in_progress;

import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.Vector;
import org.apache.commons.math3.util.FastMath;

/**
 * Graphical-lasso style estimator operating on a covariance matrix {@code ss}.
 *
 * <p>For each variable {@code m} the estimator solves an L1-penalized regression (coordinate-descent
 * lasso) of that variable on the remaining ones. It cycles over all variables until the working
 * covariance {@code ww} stops changing. It then assembles {@code wwi} from {@code ww} and the
 * regression coefficients. The Fortran-style names ({@code ia}, {@code is}, {@code itr},
 * {@code ipen}, {@code thr}, {@code maxit}) suggest a port of the original {@code glasso} Fortran
 * routine. NOTE(review): confirm the port against the reference implementation.
 *
 * <p>Flags:
 * <ul>
 *   <li>{@code ia}: use the approximate algorithm. It runs one lasso per variable against
 *       {@code ss} and returns the raw regression coefficients rather than an inverse.</li>
 *   <li>{@code is}: warm start (see the note in the exact solver).</li>
 *   <li>{@code itr}: print per-variable trace lines to standard output.</li>
 *   <li>{@code ipen}: add {@code rho(j, j)} to the diagonal of the working covariance.</li>
 * </ul>
 *
 * <p>Instances are mutable and not synchronized.
 */
public class Glasso {
    /** Lower bound applied to diagonal entries before inverting them in the diagonal-only case. */
    private static final double MIN_DIAGONAL = 1.0E-7;

    /**
     * If at most this fraction of the coefficients are nonzero, {@link #fatmul} forms the product
     * from the nonzero entries only. Otherwise it uses a dense matrix-vector product.
     */
    private static final double SPARSE_FRACTION = 0.2;

    /** Input covariance matrix (stored by reference). */
    private final Matrix ss;
    /** Number of variables processed; initialized to the row count of {@link #ss}. */
    private int n;
    /** Penalty for entry (i, j); defaults to no penalty. */
    private Rho rho = (i, j) -> 0.0;
    /** Maximum number of outer sweeps of the exact algorithm. */
    private int maxit = 10000;
    /** Use the approximate algorithm. */
    private boolean ia;
    /** Warm start. */
    private boolean is;
    /** Print trace output. */
    private boolean itr;
    /** Add the diagonal penalty {@code rho(j, j)} to the working covariance. */
    private boolean ipen;
    /** Relative convergence threshold. */
    private double thr = 1.0E-4;

    /**
     * Creates an estimator for the given covariance matrix.
     *
     * @param cov the covariance matrix. It must be square and is stored by reference, not copied.
     * @throws NullPointerException if {@code cov} is null
     * @throws IllegalArgumentException if {@code cov} is not square
     */
    public Glasso(Matrix cov) {
        if (cov == null) {
            throw new NullPointerException("cov");
        }
        if (cov.getNumRows() != cov.getNumColumns()) {
            throw new IllegalArgumentException("Covariance matrix must be square: "
                    + cov.getNumRows() + " x " + cov.getNumColumns());
        }
        this.n = cov.getNumRows();
        this.ss = cov;
    }

    /**
     * Runs the estimator with the current settings.
     *
     * <p>There are three cases:
     * <ul>
     *   <li>If every off-diagonal entry of {@code ss} is zero, the result is diagonal, with entries
     *       {@code 1 / max(ss(j, j) [+ rho(j, j)], 1e-7)}.</li>
     *   <li>If the approximate algorithm is selected, the result holds the per-column lasso
     *       coefficients.</li>
     *   <li>Otherwise the exact block coordinate-descent algorithm runs. It stops after
     *       {@code maxit} sweeps, once the largest per-column change drops below the scaled
     *       threshold, or when the thread is interrupted.</li>
     * </ul>
     *
     * @return the result wrapping the computed matrix {@code wwi}
     */
    public Result search() {
        int n = this.getN();
        Matrix ss = this.getSs();
        Rho rho = this.getRho();
        boolean penalizeDiagonal = this.isIpen();

        double offDiagonalMass = offDiagonalAbsSum(ss, n);
        if (offDiagonalMass == 0.0) {
            // No off-diagonal structure: the answer is diagonal and needs no lasso.
            return diagonalSolution(n, ss, rho, penalizeDiagonal);
        }

        // Absolute convergence threshold: thr times the off-diagonal L1 mass divided by (n - 1).
        double shr = this.getThr() * offDiagonalMass / (double) (n - 1);

        if (this.isIa()) {
            return this.approximateSolution(n, ss, rho, shr);
        }
        return this.exactSolution(n, ss, rho, penalizeDiagonal, shr);
    }

    /** Sums |ss(j, k)| over all j != k among the first {@code n} rows/columns. */
    private static double offDiagonalAbsSum(Matrix ss, int n) {
        double sum = 0.0;
        for (int j = 0; j < n; ++j) {
            for (int k = 0; k < n; ++k) {
                if (j != k) {
                    sum += FastMath.abs(ss.get(j, k));
                }
            }
        }
        return sum;
    }

    /** Builds the diagonal-only result used when {@code ss} has no off-diagonal mass. */
    private static Result diagonalSolution(int n, Matrix ss, Rho rho, boolean penalizeDiagonal) {
        Matrix wwi = new Matrix(n, n);
        for (int j = 0; j < n && !Thread.currentThread().isInterrupted(); ++j) {
            double wjj = penalizeDiagonal ? ss.get(j, j) + rho.get(j, j) : ss.get(j, j);
            wwi.set(j, j, 1.0 / FastMath.max(wjj, MIN_DIAGONAL));
        }
        return new Result(wwi);
    }

    /**
     * Approximate algorithm. Each variable is regressed once on the others, using {@code ss} as
     * the Gram matrix. The coefficients are stored in column {@code m} of the returned matrix
     * (diagonal left untouched).
     */
    private Result approximateSolution(int n, Matrix ss, Rho rho, double shr) {
        int nm1 = n - 1;
        boolean trace = this.isItr();
        Matrix wwi = new Matrix(n, n);
        Matrix vv = new Matrix(nm1, nm1);
        Vector s = new Vector(nm1);
        Vector ro = new Vector(nm1);
        Vector x = new Vector(nm1);
        int[] mm = new int[nm1];

        if (!this.isIs()) {
            zero(wwi);
        }
        for (int m = 0; m < n; ++m) {
            if (trace) {
                System.out.println("m = " + m);
            }
            setup(m, n, ss, rho, ss, vv, s, ro);

            // Starting coefficients come from the current column m (excluding the diagonal).
            int l = -1;
            for (int j = 0; j < n && !Thread.currentThread().isInterrupted(); ++j) {
                if (j != m) {
                    x.set(++l, wwi.get(j, m));
                }
            }

            lasso(ro, nm1, vv, s, shr / (double) n, x, mm);

            l = -1;
            for (int j = 0; j < n && !Thread.currentThread().isInterrupted(); ++j) {
                if (j != m) {
                    wwi.set(j, m, x.get(++l));
                }
            }
        }
        return new Result(wwi);
    }

    /**
     * Exact block coordinate-descent algorithm.
     *
     * <p>{@code xs} holds the lasso coefficients for each column. {@code ww} is the working
     * covariance. After each column's lasso, the off-diagonal part of column {@code m} of
     * {@code ww} is set to {@code vv * x}, because {@code so - s} equals that product once
     * {@link #lasso} returns.
     */
    private Result exactSolution(int n, Matrix ss, Rho rho, boolean penalizeDiagonal, double shr) {
        int nm1 = n - 1;
        boolean trace = this.isItr();
        int maxIterations = this.getMaxit();

        Matrix ww = new Matrix(n, n);
        Matrix wwi = new Matrix(n, n);
        Matrix xs = new Matrix(nm1, n);
        Matrix vv = new Matrix(nm1, nm1);
        Vector s = new Vector(nm1);
        Vector ro = new Vector(nm1);
        int[] mm = new int[nm1];

        if (!this.isIs()) {
            ww.assign(ss);
            zero(xs);
        } else {
            // NOTE(review): wwi was allocated just above and nothing has written to it, so
            // (assuming new Matrix(n, n) is zero-filled) xjj is -0.0 and the division yields
            // non-finite coefficients. There is currently no API to supply an initial estimate,
            // so warm start cannot work as written.
            for (int j = 0; j < n; ++j) {
                double xjj = -wwi.get(j, j);
                int l = -1;
                for (int k = 0; k < n && !Thread.currentThread().isInterrupted(); ++k) {
                    if (k != j) {
                        xs.set(++l, j, wwi.get(k, j) / xjj);
                    }
                }
            }
        }

        for (int j = 0; j < n; ++j) {
            ww.set(j, j, penalizeDiagonal ? ss.get(j, j) + rho.get(j, j) : ss.get(j, j));
        }

        int niter = 0;
        double dlx;
        do {
            dlx = 0.0;
            for (int m = 0; m < n; ++m) {
                if (trace) {
                    System.out.println("Outer loop = " + m);
                }
                Vector x = xs.getColumn(m);
                Vector previousColumn = ww.getColumn(m);
                setup(m, n, ss, rho, ww, vv, s, ro);
                Vector so = s.copy();
                lasso(ro, nm1, vv, s, shr / sumAbs(vv), x, mm);

                // Write the fitted covariances back, keeping ww symmetric.
                int l = -1;
                for (int j = 0; j < n && !Thread.currentThread().isInterrupted(); ++j) {
                    if (j == m) {
                        continue;
                    }
                    ++l;
                    double wjm = so.get(l) - s.get(l);
                    ww.set(j, m, wjm);
                    ww.set(m, j, wjm);
                }
                dlx = FastMath.max(dlx, sumAbsDiff(ww.getColumn(m), previousColumn));
                xs.assignColumn(m, x);
            }
            // Continue while under the sweep limit and not yet converged. !(dlx < shr) keeps
            // iterating on NaN, bounded by maxit.
        } while (++niter < maxIterations && !(dlx < shr)
                && !Thread.currentThread().isInterrupted());

        inv(n, ww, xs, wwi);
        return new Result(wwi);
    }

    /** Returns the sum of absolute values of all entries of {@code m}. */
    private static double sumAbs(Matrix m) {
        double sum = 0.0;
        for (int i = 0; i < m.getNumRows(); ++i) {
            for (int j = 0; j < m.getNumColumns(); ++j) {
                sum += FastMath.abs(m.get(i, j));
            }
        }
        return sum;
    }

    /** Returns the L1 distance between {@code x} and {@code y} over the indices of {@code x}. */
    private static double sumAbsDiff(Vector x, Vector y) {
        double sum = 0.0;
        for (int i = 0; i < x.size(); ++i) {
            sum += FastMath.abs(x.get(i) - y.get(i));
        }
        return sum;
    }

    /**
     * Builds the sub-problem for column {@code m}. The sub-problem consists of:
     * <ul>
     *   <li>{@code vv}: {@code ww} with row and column {@code m} removed;</li>
     *   <li>{@code s}: column {@code m} of {@code ss} without entry {@code m};</li>
     *   <li>{@code r}: the matching penalties {@code rho(j, m)}.</li>
     * </ul>
     */
    private static void setup(int m, int n, Matrix ss, Rho rho, Matrix ww, Matrix vv, Vector s, Vector r) {
        int l = -1;
        for (int j = 0; j < n; ++j) {
            if (j == m) {
                continue;
            }
            r.set(++l, rho.get(j, m));
            s.set(l, ss.get(j, m));
            int i = -1;
            for (int k = 0; k < n && !Thread.currentThread().isInterrupted(); ++k) {
                if (k != m) {
                    vv.set(++i, l, ww.get(k, j));
                }
            }
        }
    }

    /**
     * Coordinate-descent lasso with soft-thresholding.
     *
     * <p>On entry {@code s} is the target vector. It is first reduced to {@code s - vv * x}, and
     * that relation is maintained as each coordinate of {@code x} is updated. Each update is
     * {@code x_j = sign(t) * (|t| - ro_j)_+ / vv_jj}, where {@code t = s_j + vv_jj * x_j}.
     * Iteration stops when the largest coordinate change in a sweep is below {@code thr}, or the
     * thread is interrupted.
     *
     * @param ro per-coordinate penalties
     * @param n number of coordinates
     * @param vv Gram matrix, assumed symmetric (it is built from a symmetric {@code ww} or
     *     {@code ss})
     * @param s target vector; updated in place
     * @param thr convergence threshold on the maximum coordinate change
     * @param x coefficients; used as the starting point and updated in place
     * @param mm scratch array of length at least {@code n}
     */
    private static void lasso(Vector ro, int n, Matrix vv, Vector s, double thr, Vector x, int[] mm) {
        fatmul(n, vv, x, s, mm);
        double dlx;
        do {
            dlx = 0.0;
            for (int j = 0; j < n; ++j) {
                double xj = x.get(j);
                double vjj = vv.get(j, j);
                double t = s.get(j) + vjj * xj;
                double excess = FastMath.abs(t) - ro.get(j);
                double updated = excess > 0.0 ? FastMath.signum(t) * excess / vjj : 0.0;
                x.set(j, updated);
                if (updated == xj) {
                    continue;
                }
                double del = updated - xj;
                dlx = FastMath.max(dlx, FastMath.abs(del));
                for (int i = 0; i < s.size() && !Thread.currentThread().isInterrupted(); ++i) {
                    s.set(i, s.get(i) - del * vv.get(i, j));
                }
            }
        } while (!(dlx < thr) && !Thread.currentThread().isInterrupted());
    }

    /**
     * Computes {@code s <- s - vv * x}.
     *
     * <p>When few coefficients are nonzero, it sums only over the nonzero indices, which it
     * collects in {@code m}.
     *
     * @param n number of coordinates
     * @param vv Gram matrix, {@code n x n}
     * @param x coefficient vector of length {@code n}
     * @param s vector updated in place
     * @param m scratch array of length at least {@code n}
     */
    private static void fatmul(int n, Matrix vv, Vector x, Vector s, int[] m) {
        int nonzero = 0;
        for (int j = 0; j < n; ++j) {
            if (x.get(j) != 0.0) {
                m[nonzero++] = j;
            }
        }
        if (nonzero <= (int) (SPARSE_FRACTION * (double) n)) {
            for (int j = 0; j < n && !Thread.currentThread().isInterrupted(); ++j) {
                double dotProduct = 0.0;
                for (int i = 0; i < nonzero; ++i) {
                    dotProduct += vv.get(j, m[i]) * x.get(m[i]);
                }
                s.set(j, s.get(j) - dotProduct);
            }
        } else {
            s.assign(s.minus(vv.times(x)));
        }
    }

    /**
     * Assembles {@code wwi} column by column from {@code ww} and the coefficient matrix
     * {@code xs}, which is negated on a copy.
     *
     * <p>Column {@code j} gets diagonal {@code 1 / (ww(j, j) + sum_k (-xs)(k, j) * ww(k', j))}
     * and off-diagonal entries equal to that diagonal times the negated coefficients. Here
     * {@code k'} skips index {@code j}. NOTE(review): presumably this is the inverse of
     * {@code ww} at convergence; confirm.
     */
    private static void inv(int n, Matrix ww, Matrix xs, Matrix wwi) {
        xs = xs.scalarMult(-1.0);
        int nm1 = n - 1;

        // First column: coefficient row k corresponds to variable k + 1.
        double dp3 = 0.0;
        for (int k = 0; k < nm1; ++k) {
            dp3 += xs.get(k, 0) * ww.get(k + 1, 0);
        }
        wwi.set(0, 0, 1.0 / (ww.get(0, 0) + dp3));
        for (int i = 1; i < n; ++i) {
            wwi.set(i, 0, wwi.get(0, 0) * xs.get(i - 1, 0));
        }

        // Last column: coefficient row k corresponds to variable k.
        double dp4 = 0.0;
        for (int k = 0; k < nm1; ++k) {
            dp4 += xs.get(k, n - 1) * ww.get(k, n - 1);
        }
        wwi.set(n - 1, n - 1, 1.0 / (ww.get(n - 1, n - 1) + dp4));
        for (int i = 0; i < nm1; ++i) {
            wwi.set(i, n - 1, wwi.get(n - 1, n - 1) * xs.get(i, n - 1));
        }

        // Interior columns: rows above j map directly, rows at/below j are shifted by one.
        for (int j = 1; j < nm1 && !Thread.currentThread().isInterrupted(); ++j) {
            double dp1 = 0.0;
            for (int k = 0; k < j; ++k) {
                dp1 += xs.get(k, j) * ww.get(k, j);
            }
            double dp2 = 0.0;
            for (int k = j; k <= n - 2 && !Thread.currentThread().isInterrupted(); ++k) {
                dp2 += xs.get(k, j) * ww.get(k + 1, j);
            }
            wwi.set(j, j, 1.0 / (ww.get(j, j) + dp1 + dp2));
            for (int p = 0; p < j; ++p) {
                wwi.set(p, j, wwi.get(j, j) * xs.get(p, j));
            }
            for (int p = j + 1; p < n; ++p) {
                wwi.set(p, j, wwi.get(j, j) * xs.get(p - 1, j));
            }
        }
    }

    /** Sets every entry of {@code matrix} to 0. */
    private static void zero(Matrix matrix) {
        for (int i = 0; i < matrix.getNumRows(); ++i) {
            for (int j = 0; j < matrix.getNumColumns(); ++j) {
                matrix.set(i, j, 0.0);
            }
        }
    }

    /** @return whether the approximate algorithm is used */
    public boolean isIa() {
        return this.ia;
    }

    /** @param ia whether to use the approximate algorithm */
    public void setIa(boolean ia) {
        this.ia = ia;
    }

    /** @return whether warm start is requested */
    public boolean isIs() {
        return this.is;
    }

    /** @param is whether to request warm start */
    public void setIs(boolean is) {
        this.is = is;
    }

    /** @return whether trace output is printed to standard output */
    public boolean isItr() {
        return this.itr;
    }

    /** @param itr whether to print trace output to standard output */
    public void setItr(boolean itr) {
        this.itr = itr;
    }

    /** @return whether {@code rho(j, j)} is added to the diagonal of the working covariance */
    public boolean isIpen() {
        return this.ipen;
    }

    /** @param ipen whether to add {@code rho(j, j)} to the diagonal of the working covariance */
    public void setIpen(boolean ipen) {
        this.ipen = ipen;
    }

    /** @return the relative convergence threshold */
    public double getThr() {
        return this.thr;
    }

    /**
     * Sets the relative convergence threshold.
     *
     * @param thr the threshold; must be {@code >= 0} (NaN is rejected)
     * @throws IllegalArgumentException if {@code thr} is negative or NaN
     */
    public void setThr(double thr) {
        if (!(thr >= 0.0)) {
            throw new IllegalArgumentException("Threshold must be >= 0: " + thr);
        }
        this.thr = thr;
    }

    /** @return the number of variables processed */
    public int getN() {
        return this.n;
    }

    /**
     * Sets the number of variables processed. The covariance matrix itself is not changed.
     *
     * @param n the dimension; must be {@code >= 0}
     * @throws IllegalArgumentException if {@code n} is negative
     */
    public void setN(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("Dimension >= 0: " + n);
        }
        this.n = n;
    }

    /** @return the covariance matrix supplied to the constructor (not a copy) */
    public Matrix getSs() {
        return this.ss;
    }

    /** @return the penalty function */
    public Rho getRho() {
        return this.rho;
    }

    /**
     * Uses the same penalty for every entry (i, j).
     *
     * @param rho the common penalty
     */
    public void setRhoAllEqual(double rho) {
        this.rho = (i, j) -> rho;
    }

    /** @return the maximum number of outer sweeps of the exact algorithm */
    public int getMaxit() {
        return this.maxit;
    }

    /**
     * Sets the maximum number of outer sweeps of the exact algorithm.
     *
     * @param maxit the limit; must be {@code > 0}
     * @throws IllegalArgumentException if {@code maxit <= 0}
     */
    public void setMaxit(int maxit) {
        if (maxit <= 0) {
            throw new IllegalArgumentException("Max iterations must be > 0: " + maxit);
        }
        this.maxit = maxit;
    }

    /** Penalty for entry (i, j). */
    @FunctionalInterface
    private interface Rho {
        double get(int var1, int var2);
    }

    /** Holds the matrix computed by {@link Glasso#search()}. */
    public static class Result {
        private final Matrix wwi;

        /** @param wwi the computed matrix (stored by reference) */
        public Result(Matrix wwi) {
            this.wwi = wwi;
        }

        /** @return the computed matrix (not a copy) */
        public Matrix getWwi() {
            return this.wwi;
        }
    }
}

