/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.utils.MeekRules;
import java.util.List;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.FastMath;

public class Dagma {
    private RealMatrix cov;
    private List<Node> variables;
    private RealMatrix I;
    private int d;
    private double lambda1;
    private double wThreshold;
    private boolean cpdag;
    private final double[] T;
    private final double muInit;
    private final double muFactor;
    private final int warmIter;
    private final int maxIter;
    private final double lr;
    private final int checkpoint;
    private final double b1;
    private final double b2;
    private final double tol;

    public Dagma(DataSet dataset) {
        this.variables = dataset.getVariables();
        this.cov = dataset.getCorrelationMatrix().getApacheData();
        this.d = this.cov.getRowDimension();
        this.I = MatrixUtils.createRealIdentityMatrix(this.d);
        this.lambda1 = 0.05;
        this.wThreshold = 0.1;
        this.cpdag = true;
        this.T = new double[]{1.0, 0.9, 0.8, 0.7};
        this.muInit = 1.0;
        this.muFactor = 0.1;
        this.warmIter = 20000;
        this.maxIter = 70000;
        this.lr = 3.0E-4;
        this.checkpoint = 1000;
        this.b1 = 0.99;
        this.b2 = 0.999;
        this.tol = 1.0E-6;
    }

    public Graph search() {
        RealMatrix W = MatrixUtils.createRealMatrix(this.d, this.d);
        double mu = this.muInit;
        int outerIters = this.T.length;
        int innerIters = this.warmIter;
        for (double s : this.T) {
            double lrAdam = this.lr;
            if (outerIters-- == 1) {
                innerIters = this.maxIter;
            }
            while (this.minimize(W, mu, innerIters, s, lrAdam)) {
                lrAdam *= 0.5;
                s += 0.1;
            }
            mu *= this.muFactor;
        }
        return this.toGraph(W);
    }

    public double getLambda1() {
        return this.lambda1;
    }

    public void setLambda1(double lambda1) {
        this.lambda1 = lambda1;
    }

    public double getWThreshold() {
        return this.wThreshold;
    }

    public void setWThreshold(double wThreshold) {
        this.wThreshold = wThreshold;
    }

    public boolean getCpdag() {
        return this.cpdag;
    }

    public void setCpdag(boolean cpdag) {
        this.cpdag = cpdag;
    }

    private double _score(RealMatrix W) {
        RealMatrix dif = this.I.subtract(W);
        RealMatrix rhs = this.cov.multiply(dif);
        return 0.5 * dif.transpose().multiply(rhs).getTrace();
    }

    private double _h(RealMatrix W, double s) {
        RealMatrix M = this.getMMatrix(W, s);
        return (double)this.d * FastMath.log(s) - this.logDet(M);
    }

    private double _func(RealMatrix W, double mu, double s) {
        double score = this._score(W);
        double h = this._h(W, s);
        return mu * (score + this.lambda1 * this.absSum(W)) + h;
    }

    private void adamUpdate(RealMatrix grad, int iter, RealMatrix optM, RealMatrix optV) {
        double b1_ = 1.0 - this.b1;
        double b2_ = 1.0 - this.b2;
        for (int i = 0; i < this.d; ++i) {
            for (int j = 0; j < this.d; ++j) {
                double g = grad.getEntry(i, j);
                double m = optM.getEntry(i, j);
                double v = optV.getEntry(i, j);
                double a = this.b1 * m + b1_ * g;
                double b = this.b2 * v + b2_ * FastMath.pow(g, 2);
                optM.setEntry(i, j, a);
                optV.setEntry(i, j, b);
                grad.setEntry(i, j, (a /= 1.0 - FastMath.pow(this.b1, iter)) / (FastMath.sqrt(b /= 1.0 - FastMath.pow(this.b2, iter)) + 1.0E-8));
            }
        }
    }

    private boolean minimize(RealMatrix W, double mu, int innerIter, double s, double lrAdam) {
        RealMatrix optM = MatrixUtils.createRealMatrix(this.d, this.d);
        RealMatrix optV = MatrixUtils.createRealMatrix(this.d, this.d);
        double objPrev = 1.0E16;
        RealMatrix W_old = W.copy();
        RealMatrix grad = null;
        for (int iter = 1; iter <= innerIter; ++iter) {
            RealMatrix M = MatrixUtils.inverse(this.getMMatrix(W, s));
            this.addToEntries(M, 1.0E-16);
            while (this.notMMatrix(M)) {
                if (iter == 1 || s <= 0.9) {
                    this.setEntries(W, W_old);
                    return true;
                }
                if (lrAdam <= 2.0E-16) {
                    this.addToEntries(W, grad, lrAdam);
                    return false;
                }
                this.addToEntries(W, grad, lrAdam *= 0.5);
                M = MatrixUtils.inverse(this.getMMatrix(W, s));
                this.addToEntries(M, 1.0E-16);
            }
            grad = this.cov.multiply(W);
            for (int i = 0; i < this.d; ++i) {
                for (int j = 0; j < this.d; ++j) {
                    double g = grad.getEntry(i, j);
                    double c = this.cov.getEntry(i, j);
                    double w = W.getEntry(i, j);
                    double mt = M.getEntry(j, i);
                    double sign = 0.0;
                    if (w > 0.0) {
                        sign = 1.0;
                    }
                    if (w < 0.0) {
                        sign = -1.0;
                    }
                    grad.setEntry(i, j, mu * (g - c + this.lambda1 * sign) + 2.0 * w * mt);
                }
            }
            this.adamUpdate(grad, iter, optM, optV);
            this.addToEntries(W, grad, -lrAdam);
            if (iter % this.checkpoint != 0) continue;
            double objNew = this._func(W, mu, s);
            if (FastMath.abs((objPrev - objNew) / objPrev) <= this.tol) break;
            objPrev = objNew;
        }
        return false;
    }

    private RealMatrix getMMatrix(RealMatrix W, double s) {
        RealMatrix M = this.I.scalarMultiply(s);
        for (int i = 0; i < this.d; ++i) {
            for (int j = 0; j < this.d; ++j) {
                M.addToEntry(i, j, -W.getEntry(i, j) * W.getEntry(i, j));
            }
        }
        return M;
    }

    private void setEntries(RealMatrix A, RealMatrix B) {
        for (int i = 0; i < this.d; ++i) {
            for (int j = 0; j < this.d; ++j) {
                A.setEntry(i, j, B.getEntry(i, j));
            }
        }
    }

    private void addToEntries(RealMatrix A, double c) {
        for (int i = 0; i < this.d; ++i) {
            for (int j = 0; j < this.d; ++j) {
                A.addToEntry(i, j, c);
            }
        }
    }

    private void addToEntries(RealMatrix A, RealMatrix B, double c) {
        for (int i = 0; i < this.d; ++i) {
            for (int j = 0; j < this.d; ++j) {
                A.addToEntry(i, j, c * B.getEntry(i, j));
            }
        }
    }

    private double logDet(RealMatrix M) {
        assert (M.isSquare());
        int d = M.getRowDimension();
        LUDecomposition lud = new LUDecomposition(M);
        RealMatrix P = lud.getP();
        RealMatrix L = lud.getL();
        RealMatrix U = lud.getU();
        double logDet = FastMath.log(FastMath.abs((double)d - P.getTrace() - 1.0));
        for (int i = 0; i < d; ++i) {
            logDet += FastMath.log(FastMath.abs(L.getEntry(i, i)));
            logDet += FastMath.log(FastMath.abs(U.getEntry(i, i)));
        }
        return logDet;
    }

    private boolean notMMatrix(RealMatrix M) {
        assert (M.isSquare());
        int d = M.getRowDimension();
        for (int i = 0; i < d; ++i) {
            for (int j = 0; j < d; ++j) {
                if (!(M.getEntry(i, j) < 0.0)) continue;
                return true;
            }
        }
        return false;
    }

    private double absSum(RealMatrix M) {
        assert (M.isSquare());
        int d = M.getRowDimension();
        double s = 0.0;
        for (int i = 0; i < d; ++i) {
            for (int j = 0; j < d; ++j) {
                s += FastMath.abs(M.getEntry(i, j));
            }
        }
        return s;
    }

    private Graph toGraph(RealMatrix W) {
        RealMatrix W_ = MatrixUtils.createRealMatrix(this.d, this.d);
        for (int i = 0; i < this.d; ++i) {
            for (int j = 0; j < this.d; ++j) {
                W_.setEntry(i, j, FastMath.abs(W.getEntry(i, j)));
            }
        }
        double wThreshold = this.wThreshold;
        do {
            double wMin = Double.MAX_VALUE;
            for (int i = 0; i < this.d; ++i) {
                for (int j = 0; j < this.d; ++j) {
                    double w_ = W_.getEntry(i, j);
                    if (w_ < wThreshold) {
                        W_.setEntry(i, j, 0.0);
                        continue;
                    }
                    if (!(w_ < wMin)) continue;
                    wMin = w_;
                }
            }
            wThreshold = wMin + 1.0E-6;
        } while (W_.power(this.d).getTrace() > 0.0);
        EdgeListGraph graph = new EdgeListGraph(this.variables);
        for (int i = 0; i < this.d; ++i) {
            for (int j = 0; j < this.d; ++j) {
                if (W_.getEntry(i, j) == 0.0) continue;
                graph.addDirectedEdge(this.variables.get(i), this.variables.get(j));
            }
        }
        if (this.cpdag) {
            MeekRules rules = new MeekRules();
            rules.orientImplied(graph);
        }
        return graph;
    }
}

