/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.csb.mgm;

import cern.colt.matrix.DoubleFactory1D;
import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IGraphSearch;
import edu.cmu.tetrad.sem.GeneralizedSemIm;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.StatUtils;
import edu.pitt.csb.mgm.ConvexProximal;
import edu.pitt.csb.mgm.MixedUtils;
import edu.pitt.csb.mgm.ProximalGradient;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

public class Mgm
extends ConvexProximal
implements IGraphSearch {
    /** Factory for the dense 2-D Colt matrices used throughout this class. */
    private final DoubleFactory2D factory2D = DoubleFactory2D.dense;
    /** Factory for the dense 1-D Colt vectors used throughout this class. */
    private final DoubleFactory1D factory1D = DoubleFactory1D.dense;
    /** Continuous data (n rows, p columns); standardized column-by-column in place by fixData(). */
    private final DoubleMatrix2D xDat;
    /** Discrete data (n rows, q columns); shifted in place to one-based level codes by fixData() when zero-based. */
    private final DoubleMatrix2D yDat;
    /** Three penalty weights: continuous-continuous, continuous-discrete and discrete-discrete edges, in that order. */
    private final DoubleMatrix1D lambda;
    /** Shared Colt linear-algebra helper. */
    private final Algebra alg = new Algebra();
    /** Number of levels of each discrete variable (one entry per column of yDat). */
    private final int[] l;
    /** Number of continuous variables (columns of xDat). */
    int p;
    /** Number of discrete variables (columns of yDat). */
    int q;
    /** Number of samples (rows of xDat / yDat). */
    int n;
    /** Model variables, indexed as continuous 0..p-1 followed by discrete p..p+q-1 (see graphFromMGM). */
    private List<Node> variables;
    /** Variables in their original DataSet order; only assigned by the DataSet constructor. */
    private List<Node> initVariables;
    /** One-hot (dummy) encoding of yDat, n rows by lsum columns, built by makeDummy(). */
    private DoubleMatrix2D dDat;
    /** NOTE(review): not read or written in the portion of this file reviewed here; confirm its use further down. */
    private long elapsedTime;
    /** Total number of discrete levels (sum of l). */
    private int lsum;
    /** Prefix sums of l: lcumsum[i] is the first dummy column of discrete variable i; lcumsum[q] == lsum. */
    private int[] lcumsum;
    /** Per-variable penalty weights: StatUtils.sd of each continuous column, then sqrt(sum_k p_k(1-p_k)) per discrete column. */
    private DoubleMatrix1D weights;
    /** Current parameter estimates (beta, betad, theta, phi, alpha1, alpha2). */
    private MGMParams params;

    /**
     * Builds a model directly from pre-split continuous and discrete data matrices.
     *
     * <p>The matrices are stored by reference, not copied. The in-place standardization and
     * re-indexing performed during construction is therefore visible to the caller.
     *
     * @param x         continuous data, one row per sample
     * @param y         discrete data, one row per sample, with zero- or one-based level codes
     * @param variables model variables; expected to list the continuous variables first and then
     *                  the discrete ones (graphFromMGM indexes them that way)
     * @param l         number of levels of each column of {@code y}
     * @param lambda    penalties for continuous-continuous, continuous-discrete and discrete-discrete edges
     * @throws IllegalArgumentException if {@code l}, {@code x}, {@code y} or {@code lambda} have inconsistent sizes,
     *                                  or if the discrete data is invalid
     */
    public Mgm(DoubleMatrix2D x, DoubleMatrix2D y, List<Node> variables, int[] l, double[] lambda) {
        // Validate shapes before any state is set.
        if (y.columns() != l.length) {
            throw new IllegalArgumentException("length of l doesn't match number of variables in Y");
        }
        if (x.rows() != y.rows()) {
            throw new IllegalArgumentException("different number of samples for x and y");
        }
        if (lambda.length != 3) {
            throw new IllegalArgumentException("Lambda should have three values for cc, cd, and dd edges respectively");
        }

        this.xDat = x;
        this.yDat = y;
        this.l = l;
        this.variables = variables;
        this.lambda = this.factory1D.make(lambda);

        this.n = x.rows();
        this.p = x.columns();
        this.q = y.columns();

        // Order matters: fixData re-indexes yDat before the weights and dummy matrix are derived
        // from it, and makeDummy relies on lcumsum/lsum computed by initParameters.
        this.fixData();
        this.initParameters();
        this.calcWeights();
        this.makeDummy();
    }

    /**
     * Builds a model from a mixed data set.
     *
     * <p>The data set is split into its continuous and discrete columns. The model's variable list
     * puts the continuous variables first and then the discrete ones. The data set's original
     * variable order is kept in {@code initVariables}.
     *
     * @param ds     data containing at least one continuous and one discrete variable
     * @param lambda penalties for continuous-continuous, continuous-discrete and discrete-discrete edges
     * @throws IllegalArgumentException if {@code lambda} does not have exactly three entries, if
     *                                  {@code ds} lacks a continuous or a discrete variable, or if
     *                                  the discrete data is invalid
     */
    public Mgm(DataSet ds, double[] lambda) {
        // Same contract as the matrix constructor. Fail fast here rather than with an index
        // error deep inside the penalty computations.
        if (lambda.length != 3) {
            throw new IllegalArgumentException("Lambda should have three values for cc, cd, and dd edges respectively");
        }
        boolean hasContinuous = false;
        boolean hasDiscrete = false;
        for (Node node : ds.getVariables()) {
            hasContinuous |= node instanceof ContinuousVariable;
            hasDiscrete |= node instanceof DiscreteVariable;
        }
        if (!hasContinuous || !hasDiscrete) {
            throw new IllegalArgumentException("Please give data with at least one discrete and one continuous variable to run MGM.");
        }
        DataSet dsCont = MixedUtils.getContinousData(ds);
        DataSet dsDisc = MixedUtils.getDiscreteData(ds);
        this.xDat = this.factory2D.make(dsCont.getDoubleData().toArray());
        this.yDat = this.factory2D.make(dsDisc.getDoubleData().toArray());
        this.l = MixedUtils.getDiscLevels(ds);
        this.p = this.xDat.columns();
        this.q = this.yDat.columns();
        this.n = this.xDat.rows();
        // Continuous variables first, then discrete, matching the column layout of xDat/yDat.
        this.variables = new ArrayList<Node>();
        this.variables.addAll(dsCont.getVariables());
        this.variables.addAll(dsDisc.getVariables());
        this.initVariables = ds.getVariables();
        this.lambda = this.factory1D.make(lambda);
        // Order matters: see the matrix constructor.
        this.fixData();
        this.initParameters();
        this.calcWeights();
        this.makeDummy();
    }

    /**
     * Concatenates the columns of a matrix, in column order, into one dense vector.
     *
     * @param m matrix to flatten
     * @return a dense vector of length {@code m.rows() * m.columns()}
     */
    public static DoubleMatrix1D flatten(DoubleMatrix2D m) {
        final DoubleMatrix1D[] columns = new DoubleMatrix1D[m.columns()];
        int c = 0;
        while (c < m.columns()) {
            columns[c] = m.viewColumn(c);
            c++;
        }
        return DoubleFactory1D.dense.make(columns);
    }

    /**
     * Sums a matrix along one margin.
     *
     * @param mat  matrix to sum
     * @param marg 1 for column sums (one entry per column), 2 for row sums (one entry per row)
     * @return the marginal sums; if the current thread is interrupted, summation stops early and
     *         the partial sums are returned
     * @throws IllegalArgumentException if {@code marg} is neither 1 nor 2
     */
    private static DoubleMatrix1D margSum(DoubleMatrix2D mat, int marg) {
        switch (marg) {
            case 1: {
                int cols = mat.columns();
                DoubleMatrix1D colSums = DoubleFactory1D.dense.make(cols);
                for (int r = 0; r < mat.rows() && !Thread.currentThread().isInterrupted(); ++r) {
                    for (int c = 0; c < cols; ++c) {
                        colSums.setQuick(c, colSums.getQuick(c) + mat.getQuick(r, c));
                    }
                }
                return colSums;
            }
            case 2: {
                int rows = mat.rows();
                DoubleMatrix1D rowSums = DoubleFactory1D.dense.make(rows);
                for (int r = 0; r < rows && !Thread.currentThread().isInterrupted(); ++r) {
                    rowSums.setQuick(r, mat.viewRow(r).zSum());
                }
                return rowSums;
            }
            default:
                // Fail fast instead of returning null and letting a caller hit a NullPointerException later.
                throw new IllegalArgumentException("marg must be 1 (column sums) or 2 (row sums): " + marg);
        }
    }

    /**
     * Zeroes, in place, every entry (i, j) with {@code j < i + di}. This keeps the upper triangle
     * starting at the {@code di}-th diagonal, analogous to NumPy's {@code triu(mat, di)}. Works
     * for rectangular as well as square matrices.
     *
     * @param mat matrix to modify in place
     * @param di  diagonal offset: 0 keeps the main diagonal, 1 also zeroes it, negative values keep
     *            that many sub-diagonals
     * @return {@code mat}, for call chaining
     */
    public static DoubleMatrix2D upperTri(DoubleMatrix2D mat, int di) {
        final int rows = mat.rows();
        final int cols = mat.columns();
        // Rows before max(1 - di, 0) have i + di <= 0, so they have nothing to zero.
        for (int i = FastMath.max(-di + 1, 0); i < rows && !Thread.currentThread().isInterrupted(); ++i) {
            // Bound by the column count (not the row count), so a wide matrix is fully processed
            // and a tall matrix is never indexed past its last column.
            final int end = FastMath.min(i + di, cols);
            for (int j = 0; j < end && !Thread.currentThread().isInterrupted(); ++j) {
                mat.set(i, j, 0.0);
            }
        }
        return mat;
    }

    /**
     * Zeroes, in place, every entry (i, j) with {@code j > i + di}. This keeps the lower triangle
     * up to the {@code di}-th diagonal, analogous to NumPy's {@code tril(mat, di)}. Works for
     * rectangular as well as square matrices.
     *
     * @param mat matrix to modify in place
     * @param di  diagonal offset: 0 keeps the main diagonal, -1 also zeroes it, positive values keep
     *            that many super-diagonals
     * @return {@code mat}, for call chaining
     */
    private static DoubleMatrix2D lowerTri(DoubleMatrix2D mat, int di) {
        final int rows = mat.rows();
        final int cols = mat.columns();
        // Row i has something to zero only while i + di + 1 < cols. Use the column count, not the
        // row count, so non-square matrices are handled correctly.
        final int lastRow = FastMath.min(rows, cols - (di + 1));
        for (int i = 0; i < lastRow && !Thread.currentThread().isInterrupted(); ++i) {
            for (int j = FastMath.max(i + di + 1, 0); j < cols && !Thread.currentThread().isInterrupted(); ++j) {
                mat.set(i, j, 0.0);
            }
        }
        return mat;
    }

    /**
     * Spectral norm (largest singular value) of a matrix, computed with Colt's SVD-based
     * {@code Algebra.norm2}.
     *
     * <p>A matrix wider than it is tall is transposed first. The spectral norm is invariant under
     * transposition; presumably this works around Colt's SVD expecting rows >= columns.
     */
    private static double norm2(DoubleMatrix2D mat) {
        final Algebra algebra = new Algebra();
        final DoubleMatrix2D tall = (mat.rows() < mat.columns()) ? algebra.transpose(mat) : mat;
        return algebra.norm2(tall);
    }

    /**
     * Euclidean length of a vector.
     *
     * <p>Colt's {@code Algebra.norm2(DoubleMatrix1D)} returns the dot product of the vector with
     * itself (the squared length), hence the square root.
     */
    private static double norm2(DoubleMatrix1D vec) {
        final double squaredLength = new Algebra().norm2(vec);
        return FastMath.sqrt(squaredLength);
    }

    /**
     * Developer benchmark on the "med_test" data read from a hard-coded local directory.
     *
     * <p>Prints the variable weights, times two matrix-copy loops, fits the model with
     * {@code learnEdges(700)}, and prints the loss terms, parameters and adjacency matrix. An
     * {@link IOException} while loading the data is printed and otherwise ignored.
     */
    private static void runTests1() {
        try {
            // NOTE(review): machine-specific path; this only runs on the original author's machine.
            String path = "/Users/ajsedgewick/tetrad_master/tetrad/tetrad-lib/src/main/java/edu/pitt/csb/mgm/test_data";
            System.out.println(path);
            DoubleMatrix2D xIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_C.txt").getDoubleData().toArray());
            DoubleMatrix2D yIn = DoubleFactory2D.dense.make(MixedUtils.loadDelim(path, "med_test_D.txt").getDoubleData().toArray());
            // The code assumes the test data has this many continuous and this many binary variables.
            int numPerType = 24;
            int[] L = new int[numPerType];
            Node[] vars = new Node[2 * numPerType];
            for (int i = 0; i < numPerType; ++i) {
                L[i] = 2;
                vars[i] = new ContinuousVariable("X" + i);
                vars[i + numPerType] = new DiscreteVariable("Y" + i);
            }
            double lam = 0.2;
            Mgm model = new Mgm(xIn, yIn, new ArrayList<Node>(Arrays.asList(vars)), L, new double[]{lam, lam, lam});
            System.out.println("Weights: " + Arrays.toString(model.weights.toArray()));

            // Micro-benchmark: copy-then-assign versus copy-then-rebind.
            DoubleMatrix2D test = xIn.copy();
            DoubleMatrix2D test2 = xIn.copy();
            long t = MillisecondTimes.timeMillis();
            for (int i = 0; i < 50000; ++i) {
                test2 = xIn.copy();
                test.assign(test2);
            }
            System.out.println("assign Time: " + (MillisecondTimes.timeMillis() - t));
            t = MillisecondTimes.timeMillis();
            for (int i = 0; i < 50000 && !Thread.currentThread().isInterrupted(); ++i) {
                test = test2 = xIn.copy();
            }
            System.out.println("equals Time: " + (MillisecondTimes.timeMillis() - t));

            System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
            System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
            t = MillisecondTimes.timeMillis();
            model.learnEdges(700);
            System.out.println("Orig Time: " + (MillisecondTimes.timeMillis() - t));
            System.out.println("nll: " + model.smoothValue(model.params.toMatrix1D()));
            System.out.println("reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
            System.out.println("params:\n" + model.params);
            System.out.println("adjMat:\n" + model.adjMatFromMGM());
        }
        catch (IOException ex) {
            ex.printStackTrace();
        }
    }

    /**
     * Developer check on simulated data.
     *
     * <p>Builds a small mixed graph and simulates {@code samps} samples from a Gaussian/categorical
     * SEM. It then fits an unpenalized model (all lambdas equal to {@code lambda} = 0) with
     * {@code learn(1e-8, 1000)} and prints the fit.
     */
    private static void runTests2() {
        Graph g = GraphUtils.convert("X1-->X2,X3-->X2,X4-->X5");
        // NOTE(review): presumably 0 marks a continuous node and n > 0 a discrete node with n
        // categories; confirm against MixedUtils.makeMixedGraph.
        HashMap<String, Integer> nd = new HashMap<String, Integer>();
        nd.put("X1", 0);
        nd.put("X2", 0);
        nd.put("X3", 4);
        nd.put("X4", 4);
        nd.put("X5", 4);
        g = MixedUtils.makeMixedGraph(g, nd);
        GeneralizedSemPm pm = MixedUtils.GaussianCategoricalPm(g, "Split(-1.5,-.5,.5,1.5)");
        System.out.println(pm);
        GeneralizedSemIm im = MixedUtils.GaussianCategoricalIm(pm);
        System.out.println(im);
        int samps = 1000;
        DataSet ds = im.simulateDataFisher(samps);
        ds = MixedUtils.makeMixedData(ds, nd);
        double lambda = 0.0;
        Mgm model = new Mgm(ds, new double[]{lambda, lambda, lambda});
        System.out.println("Init nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Init reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        model.learn(1.0E-8, 1000);
        System.out.println("Learned nll: " + model.smoothValue(model.params.toMatrix1D()));
        System.out.println("Learned reg term: " + model.nonSmoothValue(model.params.toMatrix1D()));
        System.out.println("params:\n" + model.params);
        System.out.println("adjMat:\n" + model.adjMatFromMGM());
    }

    /**
     * Developer entry point: runs the {@code runTests1} benchmark ({@code runTests2} is not invoked).
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        runTests1();
    }

    /**
     * Replaces the current parameter estimates.
     *
     * @param newParams parameters to install; stored by reference, not copied
     */
    public void setParams(MGMParams newParams) {
        params = newParams;
    }

    /**
     * Computes the discrete-level offsets ({@code lcumsum}, {@code lsum}) and installs the starting
     * parameters.
     *
     * <p>All interaction and intercept terms start at zero. {@code betad}, which acts as the
     * per-variable precision in the Gaussian loss of smoothValue, starts at one.
     */
    private void initParameters() {
        final int numDiscrete = this.l.length;
        // lcumsum[0] stays 0 from array initialization; lcumsum[v + 1] = l[0] + ... + l[v].
        this.lcumsum = new int[numDiscrete + 1];
        int running = 0;
        for (int v = 0; v < numDiscrete; v++) {
            running += this.l[v];
            this.lcumsum[v + 1] = running;
        }
        this.lsum = running;

        final int numContinuous = this.xDat.columns();
        DoubleMatrix2D beta = this.factory2D.make(numContinuous, numContinuous);
        DoubleMatrix1D betad = this.factory1D.make(numContinuous, 1.0);
        DoubleMatrix2D theta = this.factory2D.make(this.lsum, numContinuous);
        DoubleMatrix2D phi = this.factory2D.make(this.lsum, this.lsum);
        DoubleMatrix1D alpha1 = this.factory1D.make(numContinuous);
        DoubleMatrix1D alpha2 = this.factory1D.make(this.lsum);
        this.params = new MGMParams(beta, betad, theta, phi, alpha1, alpha2);
    }

    /**
     * Computes log(sum(exp(x))) in a numerically stable way by shifting by the maximum entry
     * before exponentiating. {@code x} is not modified.
     *
     * @param x input vector
     * @return log of the sum of the exponentials of the entries of {@code x}
     */
    private double logsumexp(DoubleMatrix1D x) {
        final double shift = StatUtils.max(x.copy().toArray());
        final DoubleMatrix1D shifted = x.copy().assign(Functions.minus(shift));
        final double sumOfExp = shifted.assign(Functions.exp).zSum();
        return FastMath.log(sumOfExp) + shift;
    }

    /**
     * Fills {@code weights} with one entry per variable.
     *
     * <p>Continuous variables get the standard deviation of their column, as computed by
     * {@code StatUtils.sd}. Discrete variables get sqrt(sum_k p_k * (1 - p_k)), where p_k is the
     * fraction of samples at level k (levels coded 1..l[j]).
     */
    private void calcWeights() {
        this.weights = this.factory1D.make(this.p + this.q);
        for (int c = 0; c < this.p; ++c) {
            final double[] column = this.xDat.viewColumn(c).toArray();
            this.weights.set(c, StatUtils.sd(column));
        }
        for (int d = 0; d < this.q; ++d) {
            final DoubleMatrix1D column = this.yDat.viewColumn(d);
            double spread = 0.0;
            for (int level = 1; level <= this.l[d]; ++level) {
                // Copy before assign: the indicator transform must not overwrite yDat.
                final double frac = column.copy().assign(Functions.equals(level)).zSum() / (double) this.n;
                spread += frac * (1.0 - frac);
            }
            this.weights.set(this.p + d, FastMath.sqrt(spread));
        }
    }

    /**
     * Builds {@code dDat}, the n-by-lsum one-hot encoding of the discrete data.
     *
     * <p>Column {@code lcumsum[v] + k} holds 1 where discrete variable v takes level k+1, and 0
     * elsewhere.
     *
     * @throws IllegalArgumentException if some level of a discrete variable never occurs
     */
    private void makeDummy() {
        this.dDat = this.factory2D.make(this.n, this.lsum);
        for (int v = 0; v < this.q; ++v) {
            final int offset = this.lcumsum[v];
            for (int k = 0; k < this.l[v]; ++k) {
                final DoubleMatrix1D indicator = this.yDat.viewColumn(v).copy().assign(Functions.equals(k + 1));
                if (indicator.zSum() == 0.0) {
                    throw new IllegalArgumentException("Discrete data is missing a level: variable " + v + " level " + k);
                }
                this.dDat.viewColumn(offset + k).assign(indicator);
            }
        }
    }

    /**
     * Normalizes the data in place.
     *
     * <p>Zero-based discrete codes are shifted to one-based, and every continuous column is
     * standardized with {@code StatUtils.standardizeData}.
     *
     * @throws IllegalArgumentException if the smallest discrete code is below 0 or above 1
     */
    private void fixData() {
        final double smallestCode = StatUtils.min(Mgm.flatten(this.yDat).toArray());
        if (smallestCode < 0.0 || smallestCode > 1.0) {
            throw new IllegalArgumentException("Discrete data must be either zero or one indexed. Found min index: " + smallestCode);
        }
        if (smallestCode == 0.0) {
            this.yDat.assign(Functions.plus(1.0));
        }
        for (int c = 0; c < this.p; ++c) {
            // viewColumn is a live view, so assigning into it writes back into xDat.
            final DoubleMatrix1D column = this.xDat.viewColumn(c);
            column.assign(StatUtils.standardizeData(column.toArray()));
        }
    }

    /**
     * Smooth part of the objective at the packed parameters {@code parIn}.
     *
     * <p>This is the Gaussian loss of the continuous columns plus the multinomial-logit
     * (log-sum-exp) loss of the discrete columns, summed over samples and divided by n.
     *
     * <p>Only the strict upper triangle of beta and the upper off-diagonal blocks of phi are used;
     * they are mirrored below so both matrices are symmetric. NOTE(review): that mirroring
     * modifies {@code par} in place. Whether it writes through to {@code parIn} depends on how
     * MGMParams(DoubleMatrix1D, int, int) stores its blocks; confirm.
     *
     * @param parIn packed parameter vector
     * @return the averaged loss, or {@code Double.POSITIVE_INFINITY} if any entry of betad is negative
     */
    @Override
    public double smoothValue(DoubleMatrix1D parIn) {
        int i;
        MGMParams par = new MGMParams(parIn, this.p, this.lsum);
        // betad goes through log() and sqrt() below; negative entries make the loss undefined.
        for (i = 0; i < par.betad.size(); ++i) {
            if (!(par.betad.get(i) < 0.0)) continue;
            return Double.POSITIVE_INFINITY;
        }
        // beta <- triu(beta, 1) + triu(beta, 1)^T: symmetric with a zero diagonal.
        Mgm.upperTri(par.beta, 1);
        par.beta.assign(this.alg.transpose(par.beta), Functions.plus);
        // Zero the diagonal (same-variable) blocks of phi, then symmetrize from its upper triangle.
        for (i = 0; i < this.q; ++i) {
            par.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0.0);
        }
        Mgm.upperTri(par.phi, 0);
        par.phi.assign(this.alg.transpose(par.phi), Functions.plus);
        // divBetaD = diag(1 / betad).
        DoubleMatrix2D divBetaD = this.factory2D.diagonal(this.factory1D.make(this.p, 1.0).assign(par.betad, Functions.div));
        // Contributions of the other continuous (X * beta) and discrete (D * theta) variables,
        // each scaled by 1 / betad.
        DoubleMatrix2D xBeta = this.alg.mult(this.xDat, this.alg.mult(par.beta, divBetaD));
        DoubleMatrix2D dTheta = this.alg.mult(this.alg.mult(this.dDat, par.theta), divBetaD);
        DoubleMatrix2D tempLoss = this.factory2D.make(this.n, this.xDat.columns());
        // Linear predictor for every discrete level: X * theta^T + D * phi (alpha2 is added below).
        DoubleMatrix2D wxProd = this.alg.mult(this.xDat, this.alg.transpose(par.theta));
        wxProd.assign(this.alg.mult(this.dDat, par.phi), Functions.plus);
        for (int i2 = 0; i2 < this.n; ++i2) {
            int j;
            // Gaussian residuals: x - alpha1 - xBeta - dTheta.
            for (j = 0; j < this.xDat.columns(); ++j) {
                tempLoss.set(i2, j, this.xDat.get(i2, j) - par.alpha1.get(j) - xBeta.get(i2, j) - dTheta.get(i2, j));
            }
            for (j = 0; j < this.dDat.columns(); ++j) {
                wxProd.set(i2, j, wxProd.get(i2, j) + par.alpha2.get(j));
            }
        }
        // -n/2 * sum(log betad) + 1/2 * || residuals * diag(sqrt(betad)) ||_F^2
        double sqloss = (double)(-this.n) / 2.0 * par.betad.copy().assign(Functions.log).zSum() + 0.5 * FastMath.pow(this.alg.normF(this.alg.mult(tempLoss, this.factory2D.diagonal(par.betad.copy().assign(Functions.sqrt)))), 2);
        double catloss = 0.0;
        for (int i3 = 0; i3 < this.yDat.columns(); ++i3) {
            // Predictor columns belonging to discrete variable i3.
            DoubleMatrix2D wxTemp = wxProd.viewPart(0, this.lcumsum[i3], this.n, this.l[i3]);
            for (int k = 0; k < this.n; ++k) {
                DoubleMatrix1D curRow = wxTemp.viewRow(k);
                // Minus the predictor of the observed (one-based) level, plus log-sum-exp over all levels.
                catloss -= curRow.get((int)this.yDat.get(k, i3) - 1);
                catloss += this.logsumexp(curRow);
            }
        }
        return (sqloss + catloss) / (double)this.n;
    }

    /**
     * Evaluates the smooth part of the objective (same loss as smoothValue) and its gradient in
     * one pass.
     *
     * <p>The gradient is written into {@code gradOutVec}. Gradients for beta and phi are folded onto
     * their upper triangles, matching the symmetric parameterization used by the loss. If the
     * current thread is interrupted, some loops stop early and the results are incomplete.
     *
     * @param parIn      packed parameter vector
     * @param gradOutVec receives the packed gradient (left untouched when infinity is returned)
     * @return the averaged loss, or {@code Double.POSITIVE_INFINITY} if any entry of betad is negative
     */
    @Override
    public double smooth(DoubleMatrix1D parIn, DoubleMatrix1D gradOutVec) {
        int i;
        MGMParams par = new MGMParams(parIn, this.p, this.lsum);
        MGMParams gradOut = new MGMParams();
        // betad goes through log() and sqrt() below; negative entries make the loss undefined.
        for (i = 0; i < par.betad.size(); ++i) {
            if (!(par.betad.get(i) < 0.0)) continue;
            return Double.POSITIVE_INFINITY;
        }
        // Symmetrize beta from its strict upper triangle, and phi from its upper off-diagonal blocks.
        Mgm.upperTri(par.beta, 1);
        par.beta.assign(this.alg.transpose(par.beta), Functions.plus);
        for (i = 0; i < this.q; ++i) {
            par.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0.0);
        }
        Mgm.upperTri(par.phi, 0);
        par.phi.assign(this.alg.transpose(par.phi), Functions.plus);
        // divBetaD = diag(1 / betad).
        DoubleMatrix2D divBetaD = this.factory2D.diagonal(this.factory1D.make(this.p, 1.0).assign(par.betad, Functions.div));
        DoubleMatrix2D xBeta = this.alg.mult(this.xDat, this.alg.mult(par.beta, divBetaD));
        DoubleMatrix2D dTheta = this.alg.mult(this.alg.mult(this.dDat, par.theta), divBetaD);
        DoubleMatrix2D tempLoss = this.factory2D.make(this.n, this.xDat.columns());
        // Linear predictor for every discrete level: X * theta^T + D * phi (+ alpha2 below).
        DoubleMatrix2D wxProd = this.alg.mult(this.xDat, this.alg.transpose(par.theta));
        wxProd.assign(this.alg.mult(this.dDat, par.phi), Functions.plus);
        for (int i2 = 0; i2 < this.n && !Thread.currentThread().isInterrupted(); ++i2) {
            int j;
            // Gaussian residuals: x - alpha1 - xBeta - dTheta.
            for (j = 0; j < this.xDat.columns(); ++j) {
                tempLoss.set(i2, j, this.xDat.get(i2, j) - par.alpha1.get(j) - xBeta.get(i2, j) - dTheta.get(i2, j));
            }
            for (j = 0; j < this.dDat.columns(); ++j) {
                wxProd.set(i2, j, wxProd.get(i2, j) + par.alpha2.get(j));
            }
        }
        double sqloss = (double)(-this.n) / 2.0 * par.betad.copy().assign(Functions.log).zSum() + 0.5 * FastMath.pow(this.alg.normF(this.alg.mult(tempLoss, this.factory2D.diagonal(par.betad.copy().assign(Functions.sqrt)))), 2);
        // From here on tempLoss holds the negated residuals.
        tempLoss.assign(Functions.mult(-1.0));
        // Fold the beta gradient onto the strict upper triangle: G_ij + G_ji for j > i.
        gradOut.beta = this.alg.mult(this.alg.transpose(this.xDat), tempLoss);
        DoubleMatrix2D lowerBeta = this.alg.transpose(Mgm.lowerTri(gradOut.beta.copy(), -1));
        Mgm.upperTri(gradOut.beta, 1).assign(lowerBeta, Functions.plus);
        gradOut.alpha1 = this.alg.mult(this.factory2D.diagonal(par.betad), Mgm.margSum(tempLoss, 1));
        gradOut.theta = this.alg.mult(this.alg.transpose(this.dDat), tempLoss);
        double catloss = 0.0;
        for (int i3 = 0; i3 < this.yDat.columns() && !Thread.currentThread().isInterrupted(); ++i3) {
            // wxTemp is a view into wxProd: it is turned in place into softmax(predictor) - onehot(observed).
            DoubleMatrix2D wxTemp = wxProd.viewPart(0, this.lcumsum[i3], this.n, this.l[i3]);
            DoubleMatrix2D wxTemp0 = wxTemp.copy();
            // NOTE(review): exp is applied without subtracting the row maximum (unlike logsumexp),
            // so very large predictors could overflow to Infinity and give NaN probabilities.
            wxTemp.assign(Functions.exp);
            DoubleMatrix1D invDenom = this.factory1D.make(this.n, 1.0).assign(Mgm.margSum(wxTemp, 2), Functions.div);
            wxTemp.assign(this.alg.mult(this.factory2D.diagonal(invDenom), wxTemp));
            for (int k = 0; k < this.n && !Thread.currentThread().isInterrupted(); ++k) {
                DoubleMatrix1D curRow = wxTemp.viewRow(k);
                DoubleMatrix1D curRow0 = wxTemp0.viewRow(k);
                catloss -= curRow0.get((int)this.yDat.get(k, i3) - 1);
                catloss += this.logsumexp(curRow0);
                curRow.set((int)this.yDat.get(k, i3) - 1, curRow.get((int)this.yDat.get(k, i3) - 1) - 1.0);
            }
        }
        gradOut.alpha2 = Mgm.margSum(wxProd, 1);
        DoubleMatrix2D gradW = this.alg.mult(this.alg.transpose(this.xDat), wxProd);
        gradOut.theta.assign(this.alg.transpose(gradW), Functions.plus);
        gradOut.phi = this.alg.mult(this.alg.transpose(this.dDat), wxProd);
        // No same-variable interactions: zero the diagonal blocks, then fold onto the upper triangle.
        for (int i4 = 0; i4 < this.q; ++i4) {
            gradOut.phi.viewPart(this.lcumsum[i4], this.lcumsum[i4], this.l[i4], this.l[i4]).assign(0.0);
        }
        DoubleMatrix2D lowerPhi = this.alg.transpose(Mgm.lowerTri(gradOut.phi.copy(), 0));
        Mgm.upperTri(gradOut.phi, 0).assign(lowerPhi, Functions.plus);
        gradOut.betad = this.factory1D.make(this.xDat.columns());
        for (int i5 = 0; i5 < this.p; ++i5) {
            // Colt's Algebra.norm2 on a vector returns the squared length, so the middle term is ||r||^2 / 2.
            gradOut.betad.set(i5, (double)(-this.n) / (2.0 * par.betad.get(i5)) + this.alg.norm2(tempLoss.viewColumn(i5)) / 2.0 - this.alg.mult(tempLoss.viewColumn(i5), xBeta.viewColumn(i5).copy().assign(dTheta.viewColumn(i5), Functions.plus)));
        }
        // Average over samples, matching the returned loss.
        gradOut.alpha1.assign(Functions.div(this.n));
        gradOut.alpha2.assign(Functions.div(this.n));
        gradOut.betad.assign(Functions.div(this.n));
        gradOut.beta.assign(Functions.div(this.n));
        gradOut.theta.assign(Functions.div(this.n));
        gradOut.phi.assign(Functions.div(this.n));
        gradOutVec.assign(gradOut.toMatrix1D());
        return (sqloss + catloss) / (double)this.n;
    }

    /**
     * Non-smooth (penalty) part of the objective at the packed parameters {@code parIn}.
     *
     * <p>The penalty is the sum of three terms. The weight of a variable pair is the product of
     * the two variables' entries in {@code weights}.
     * <ul>
     *   <li>lambda[0] times the weighted L1 norm of all of beta;</li>
     *   <li>lambda[1] times the weighted L2 norms of each (continuous, discrete) slice of theta;</li>
     *   <li>lambda[2] times the weighted Frobenius norms of each upper off-diagonal block of phi.</li>
     * </ul>
     * If the current thread is interrupted, the group sums stop early and a partial value is returned.
     *
     * @param parIn packed parameter vector
     * @return the penalty value
     */
    @Override
    public double nonSmoothValue(DoubleMatrix1D parIn) {
        final MGMParams par = new MGMParams(parIn, this.p, this.lsum);
        final DoubleMatrix2D pairWeights = this.alg.multOuter(this.weights, this.weights, null);

        // Continuous-continuous: element-wise weighted L1 norm.
        final DoubleMatrix2D weightedAbsBeta = par.beta.copy().assign(Functions.abs);
        weightedAbsBeta.assign(pairWeights.viewPart(0, 0, this.p, this.p), Functions.mult);
        final double betaPenalty = weightedAbsBeta.zSum();

        final int numDiscrete = this.lcumsum.length - 1;

        // Continuous-discrete: one L2 group per (continuous, discrete) pair. Colt's vector norm2
        // is the squared length, hence the sqrt.
        double thetaPenalty = 0.0;
        for (int cont = 0; cont < this.p && !Thread.currentThread().isInterrupted(); ++cont) {
            for (int disc = 0; disc < numDiscrete && !Thread.currentThread().isInterrupted(); ++disc) {
                final DoubleMatrix1D group = par.theta.viewColumn(cont).viewPart(this.lcumsum[disc], this.l[disc]);
                thetaPenalty += pairWeights.get(cont, this.p + disc) * FastMath.sqrt(this.alg.norm2(group));
            }
        }

        // Discrete-discrete: Frobenius norm of each upper off-diagonal block.
        double phiPenalty = 0.0;
        for (int a = 0; a < numDiscrete && !Thread.currentThread().isInterrupted(); ++a) {
            for (int b = a + 1; b < numDiscrete && !Thread.currentThread().isInterrupted(); ++b) {
                final DoubleMatrix2D block = par.phi.viewPart(this.lcumsum[a], this.lcumsum[b], this.l[a], this.l[b]);
                phiPenalty += pairWeights.get(this.p + a, this.p + b) * this.alg.normF(block);
            }
        }

        return this.lambda.get(0) * betaPenalty + this.lambda.get(1) * thetaPenalty + this.lambda.get(2) * phiPenalty;
    }

    /**
     * Gradient of the smooth part of the objective (the loss of smoothValue) at {@code parIn}.
     *
     * <p>The gradient is computed as in {@code smooth}, but without evaluating the loss. Unlike
     * smooth/smoothValue, no check for negative betad is made here. If the current thread is
     * interrupted, some loops stop early and the result is incomplete.
     *
     * @param parIn packed parameter vector
     * @return packed gradient, averaged over the n samples
     */
    @Override
    public DoubleMatrix1D smoothGradient(DoubleMatrix1D parIn) {
        // Local n shadows the field of the same name; both hold the number of rows of xDat.
        int n = this.xDat.rows();
        MGMParams grad = new MGMParams();
        MGMParams par = new MGMParams(parIn, this.p, this.lsum);
        // Symmetrize beta from its strict upper triangle, and phi from its upper off-diagonal blocks.
        Mgm.upperTri(par.beta, 1);
        par.beta.assign(this.alg.transpose(par.beta), Functions.plus);
        for (int i = 0; i < this.q; ++i) {
            par.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0.0);
        }
        Mgm.upperTri(par.phi, 0);
        par.phi.assign(this.alg.transpose(par.phi), Functions.plus);
        // divBetaD = diag(1 / betad).
        DoubleMatrix2D divBetaD = this.factory2D.diagonal(this.factory1D.make(this.p, 1.0).assign(par.betad, Functions.div));
        DoubleMatrix2D xBeta = this.alg.mult(this.alg.mult(this.xDat, par.beta), divBetaD);
        DoubleMatrix2D dTheta = this.alg.mult(this.alg.mult(this.dDat, par.theta), divBetaD);
        DoubleMatrix2D negLoss = this.factory2D.make(n, this.xDat.columns());
        // Linear predictor for every discrete level: X * theta^T + D * phi (+ alpha2 below).
        DoubleMatrix2D wxProd = this.alg.mult(this.xDat, this.alg.transpose(par.theta));
        wxProd.assign(this.alg.mult(this.dDat, par.phi), Functions.plus);
        for (int i = 0; i < n && !Thread.currentThread().isInterrupted(); ++i) {
            int j;
            // Negated Gaussian residuals: xBeta - x + alpha1 + dTheta.
            for (j = 0; j < this.p && !Thread.currentThread().isInterrupted(); ++j) {
                negLoss.set(i, j, xBeta.get(i, j) - this.xDat.get(i, j) + par.alpha1.get(j) + dTheta.get(i, j));
            }
            for (j = 0; j < this.dDat.columns() && !Thread.currentThread().isInterrupted(); ++j) {
                wxProd.set(i, j, wxProd.get(i, j) + par.alpha2.get(j));
            }
        }
        // Fold the beta gradient onto the strict upper triangle: G_ij + G_ji for j > i.
        grad.beta = this.alg.mult(this.alg.transpose(this.xDat), negLoss);
        DoubleMatrix2D lowerBeta = this.alg.transpose(Mgm.lowerTri(grad.beta.copy(), -1));
        Mgm.upperTri(grad.beta, 1).assign(lowerBeta, Functions.plus);
        grad.alpha1 = this.alg.mult(this.factory2D.diagonal(par.betad), Mgm.margSum(negLoss, 1));
        grad.theta = this.alg.mult(this.alg.transpose(this.dDat), negLoss);
        for (int i = 0; i < this.yDat.columns(); ++i) {
            // wxTemp is a view into wxProd: it is turned in place into softmax(predictor) - onehot(observed).
            DoubleMatrix2D wxTemp = wxProd.viewPart(0, this.lcumsum[i], n, this.l[i]);
            // NOTE(review): exp without subtracting the row maximum can overflow for large predictors.
            wxTemp.assign(Functions.exp);
            DoubleMatrix1D invDenom = this.factory1D.make(n, 1.0).assign(Mgm.margSum(wxTemp, 2), Functions.div);
            wxTemp.assign(this.alg.mult(this.factory2D.diagonal(invDenom), wxTemp));
            for (int k = 0; k < n; ++k) {
                DoubleMatrix1D curRow = wxTemp.viewRow(k);
                curRow.set((int)this.yDat.get(k, i) - 1, curRow.get((int)this.yDat.get(k, i) - 1) - 1.0);
            }
        }
        grad.alpha2 = Mgm.margSum(wxProd, 1);
        DoubleMatrix2D gradW = this.alg.mult(this.alg.transpose(this.xDat), wxProd);
        grad.theta.assign(this.alg.transpose(gradW), Functions.plus);
        grad.phi = this.alg.mult(this.alg.transpose(this.dDat), wxProd);
        // No same-variable interactions: zero the diagonal blocks, then fold onto the upper triangle.
        for (int i = 0; i < this.q && !Thread.currentThread().isInterrupted(); ++i) {
            grad.phi.viewPart(this.lcumsum[i], this.lcumsum[i], this.l[i], this.l[i]).assign(0.0);
        }
        DoubleMatrix2D lowerPhi = this.alg.transpose(Mgm.lowerTri(grad.phi.copy(), 0));
        Mgm.upperTri(grad.phi, 0).assign(lowerPhi, Functions.plus);
        grad.betad = this.factory1D.make(this.xDat.columns());
        for (int i = 0; i < this.p && !Thread.currentThread().isInterrupted(); ++i) {
            // Colt's Algebra.norm2 on a vector returns the squared length, so the middle term is ||r||^2 / 2.
            grad.betad.set(i, (double)(-n) / (2.0 * par.betad.get(i)) + this.alg.norm2(negLoss.viewColumn(i)) / 2.0 - this.alg.mult(negLoss.viewColumn(i), xBeta.viewColumn(i).copy().assign(dTheta.viewColumn(i), Functions.plus)));
        }
        // Average over samples, matching smoothValue.
        grad.alpha1.assign(Functions.div(n));
        grad.alpha2.assign(Functions.div(n));
        grad.betad.assign(Functions.div(n));
        grad.beta.assign(Functions.div(n));
        grad.theta.assign(Functions.div(n));
        grad.phi.assign(Functions.div(n));
        return grad.toMatrix1D();
    }

    /**
     * Proximal operator of {@code t} times the penalty, applied to the packed parameters {@code X}.
     *
     * <p>The shrinkage is applied in three parts:
     * <ul>
     *   <li>Entries of beta are soft-thresholded element-wise with threshold t*lambda[0]*w_ij.</li>
     *   <li>Each (continuous, discrete) slice of theta is scaled by max(0, 1 - t*lambda[1]*w / ||slice||).</li>
     *   <li>Each upper off-diagonal block of phi is scaled by max(0, 1 - t*lambda[2]*w / norm2(block)).</li>
     * </ul>
     * All-zero groups are left at zero. {@code X} itself is not modified; the shrinkage works on a copy.
     *
     * @param t step size; must be positive
     * @param X packed parameter vector
     * @return the proximal point, packed via MGMParams.toMatrix1D()
     * @throws IllegalArgumentException if {@code t <= 0}
     */
    @Override
    public DoubleMatrix1D proximalOperator(double t, DoubleMatrix1D X) {
        if (t <= 0.0) {
            throw new IllegalArgumentException("t must be positive: " + t);
        }
        DoubleMatrix1D tlam = this.lambda.copy().assign(Functions.mult(t));
        MGMParams par = new MGMParams(X.copy(), this.p, this.lsum);
        DoubleMatrix2D weightMat = this.alg.multOuter(this.weights, this.weights, null);
        DoubleMatrix2D betaWeight = weightMat.viewPart(0, 0, this.p, this.p);

        // beta: element-wise soft-thresholding, scale_ij = max(0, 1 - t*lambda[0]*w_ij / |beta_ij|).
        DoubleMatrix2D betascale = betaWeight.copy().assign(Functions.mult(-tlam.get(0)));
        betascale.assign(par.beta.copy().assign(Functions.abs), Functions.div);
        betascale.assign(Functions.plus(1.0));
        betascale.assign(Functions.max(0.0));
        for (int i = 0; i < this.p && !Thread.currentThread().isInterrupted(); ++i) {
            for (int j = 0; j < this.p && !Thread.currentThread().isInterrupted(); ++j) {
                double curVal = par.beta.get(i, j);
                // Zero entries are already at their proximal point, and their scale may be NaN (0/0).
                if (curVal == 0.0) {
                    continue;
                }
                par.beta.set(i, j, curVal * betascale.get(i, j));
            }
        }

        int numDiscrete = this.lcumsum.length - 1;

        // theta: group soft-thresholding of each (continuous i, discrete j) slice.
        for (int i = 0; i < this.p && !Thread.currentThread().isInterrupted(); ++i) {
            for (int j = 0; j < numDiscrete && !Thread.currentThread().isInterrupted(); ++j) {
                DoubleMatrix1D tempVec = par.theta.viewColumn(i).viewPart(this.lcumsum[j], this.l[j]);
                double groupNorm = Mgm.norm2(tempVec);
                // An all-zero group is already its proximal point. Skipping it avoids 0/0 = NaN
                // when the threshold is zero (e.g. lambda = 0), which would fill the group with NaN.
                if (groupNorm == 0.0) {
                    continue;
                }
                double thetaScale = FastMath.max(0.0, 1.0 - tlam.get(1) * weightMat.get(i, this.p + j) / groupNorm);
                tempVec.assign(Functions.mult(thetaScale));
            }
        }

        // phi: group soft-thresholding of each upper off-diagonal block.
        for (int i = 0; i < numDiscrete && !Thread.currentThread().isInterrupted(); ++i) {
            for (int j = i + 1; j < numDiscrete && !Thread.currentThread().isInterrupted(); ++j) {
                DoubleMatrix2D tempMat = par.phi.viewPart(this.lcumsum[i], this.lcumsum[j], this.l[i], this.l[j]);
                // NOTE(review): the block is scaled using its spectral norm, whereas nonSmoothValue
                // penalizes its Frobenius norm (normF); confirm which is intended.
                double blockNorm = Mgm.norm2(tempMat);
                // All-zero block: already its proximal point (see the theta loop above for why it is skipped).
                if (blockNorm == 0.0) {
                    continue;
                }
                double phiScale = FastMath.max(0.0, 1.0 - tlam.get(2) * weightMat.get(this.p + i, this.p + j) / blockNorm);
                tempMat.assign(Functions.mult(phiScale));
            }
        }
        return par.toMatrix1D();
    }

    /**
     * Proximal step and penalty evaluation in one pass.
     *
     * <p>Applies to {@code X} the same element-wise (beta) and group (theta, phi) soft-thresholding
     * as proximalOperator, and writes the result into {@code pX}. It returns the penalty at that
     * result, computed with the unscaled lambda. All-zero groups are left at zero.
     *
     * <p>NOTE(review): unlike proximalOperator, {@code X} is not copied and {@code t} is not
     * validated. Whether {@code X} is modified depends on how MGMParams(DoubleMatrix1D, int, int)
     * stores its blocks; confirm.
     *
     * @param t  step size
     * @param X  packed parameter vector to shrink
     * @param pX receives the packed proximal point
     * @return lambda[0] * weighted L1 norm of beta + lambda[1] * weighted theta group norms
     *         + lambda[2] * weighted Frobenius norms of the phi blocks, at the proximal point
     */
    @Override
    public double nonSmooth(double t, DoubleMatrix1D X, DoubleMatrix1D pX) {
        DoubleMatrix1D tlam = this.lambda.copy().assign(Functions.mult(t));
        MGMParams par = new MGMParams(X, this.p, this.lsum);
        DoubleMatrix2D weightMat = this.alg.multOuter(this.weights, this.weights, null);
        DoubleMatrix2D betaWeight = weightMat.viewPart(0, 0, this.p, this.p);

        // beta: element-wise soft-thresholding, scale_ij = max(0, 1 - t*lambda[0]*w_ij / |beta_ij|).
        DoubleMatrix2D betascale = betaWeight.copy().assign(Functions.mult(-tlam.get(0)));
        DoubleMatrix2D absBeta = par.beta.copy().assign(Functions.abs);
        betascale.assign(absBeta, Functions.div);
        betascale.assign(Functions.plus(1.0));
        betascale.assign(Functions.max(0.0));
        double betaNorms = 0.0;
        for (int i = 0; i < this.p && !Thread.currentThread().isInterrupted(); ++i) {
            for (int j = 0; j < this.p && !Thread.currentThread().isInterrupted(); ++j) {
                double curVal = par.beta.get(i, j);
                // Zero entries stay zero and contribute nothing; their scale may be NaN (0/0).
                if (curVal == 0.0) {
                    continue;
                }
                curVal *= betascale.get(i, j);
                par.beta.set(i, j, curVal);
                betaNorms += FastMath.abs(betaWeight.get(i, j) * curVal);
            }
        }

        int numDiscrete = this.lcumsum.length - 1;

        // theta: group soft-thresholding; Colt's vector norm2 is the squared length, hence the sqrt.
        double thetaNorms = 0.0;
        for (int i = 0; i < this.p && !Thread.currentThread().isInterrupted(); ++i) {
            for (int j = 0; j < numDiscrete && !Thread.currentThread().isInterrupted(); ++j) {
                DoubleMatrix1D tempVec = par.theta.viewColumn(i).viewPart(this.lcumsum[j], this.l[j]);
                double groupNorm = Mgm.norm2(tempVec);
                // An all-zero group is already its proximal point. Scaling it would compute 0/0 = NaN
                // when the threshold is zero (e.g. lambda = 0).
                if (groupNorm != 0.0) {
                    double thetaScale = FastMath.max(0.0, 1.0 - tlam.get(1) * weightMat.get(i, this.p + j) / groupNorm);
                    tempVec.assign(Functions.mult(thetaScale));
                }
                thetaNorms += weightMat.get(i, this.p + j) * FastMath.sqrt(this.alg.norm2(tempVec));
            }
        }

        // phi: group soft-thresholding of each upper off-diagonal block.
        double phiNorms = 0.0;
        for (int i = 0; i < numDiscrete && !Thread.currentThread().isInterrupted(); ++i) {
            for (int j = i + 1; j < numDiscrete && !Thread.currentThread().isInterrupted(); ++j) {
                DoubleMatrix2D tempMat = par.phi.viewPart(this.lcumsum[i], this.lcumsum[j], this.l[i], this.l[j]);
                // NOTE(review): scaled by the spectral norm but penalized by the Frobenius norm; confirm intent.
                double blockNorm = Mgm.norm2(tempMat);
                // All-zero block: already its proximal point (see the theta loop above).
                if (blockNorm != 0.0) {
                    double phiScale = FastMath.max(0.0, 1.0 - tlam.get(2) * weightMat.get(this.p + i, this.p + j) / blockNorm);
                    tempMat.assign(Functions.mult(phiScale));
                }
                phiNorms += weightMat.get(this.p + i, this.p + j) * this.alg.normF(tempMat);
            }
        }
        pX.assign(par.toMatrix1D());
        return this.lambda.get(0) * betaNorms + this.lambda.get(1) * thetaNorms + this.lambda.get(2) * phiNorms;
    }

    /**
     * Fits the model by backtracking proximal gradient descent and stores the result.
     *
     * <p>Fitting starts from the current parameters and uses a default-constructed ProximalGradient.
     * The result is installed via setParams.
     *
     * @param epsilon   tolerance passed to {@code ProximalGradient.learnBackTrack}
     * @param iterLimit iteration limit passed to {@code learnBackTrack}
     */
    public void learn(double epsilon, int iterLimit) {
        final ProximalGradient solver = new ProximalGradient();
        final MGMParams fitted = new MGMParams(
                solver.learnBackTrack(this, this.params.toMatrix1D(), epsilon, iterLimit),
                this.p,
                this.lsum);
        setParams(fitted);
    }

    /**
     * Fits the model and stores the result.
     *
     * <p>Uses {@code new ProximalGradient(0.5, 0.9, true)} with a tolerance of 0, starting from the
     * current parameters. NOTE(review): the meaning of the three constructor arguments is defined
     * in ProximalGradient; confirm there.
     *
     * @param iterLimit iteration limit passed to {@code learnBackTrack}
     */
    public void learnEdges(int iterLimit) {
        final ProximalGradient solver = new ProximalGradient(0.5, 0.9, true);
        final MGMParams fitted = new MGMParams(
                solver.learnBackTrack(this, this.params.toMatrix1D(), 0.0, iterLimit),
                this.p,
                this.lsum);
        setParams(fitted);
    }

    /**
     * Same as {@code learnEdges(int)}, but also configures the solver's edge-change tolerance.
     *
     * @param iterLimit     iteration limit passed to {@code learnBackTrack}
     * @param edgeChangeTol value passed to {@code ProximalGradient.setEdgeChangeTol}
     */
    public void learnEdges(int iterLimit, int edgeChangeTol) {
        final ProximalGradient solver = new ProximalGradient(0.5, 0.9, true);
        solver.setEdgeChangeTol(edgeChangeTol);
        final MGMParams fitted = new MGMParams(
                solver.learnBackTrack(this, this.params.toMatrix1D(), 0.0, iterLimit),
                this.p,
                this.lsum);
        setParams(fitted);
    }

    /**
     * Converts the current parameter estimates into an undirected graph over {@code variables}.
     *
     * <p>Indices {@code 0..p-1} of {@code variables} are the variables coupled by {@code beta};
     * index {@code p + j} is the j-th of the q variables whose parameter blocks are sized by
     * {@code l}. An edge is added when:
     * <ul>
     *   <li>{@code beta(i, j)} is nonzero (only the upper triangle, {@code i < j}, is read);</li>
     *   <li>any entry of the theta block for column i, rows {@code lcumsum[j] .. lcumsum[j]+l[j]-1},
     *       is nonzero;</li>
     *   <li>any entry of the phi block for the pair (i, j), {@code i < j}, is nonzero.</li>
     * </ul>
     * Each scan stops as soon as the current thread is interrupted, so an interrupted call may
     * return a partial graph.
     *
     * @return a new graph containing the selected undirected edges
     */
    public Graph graphFromMGM() {
        final EdgeListGraph graph = new EdgeListGraph(this.variables);
        final Thread self = Thread.currentThread();

        // Pairs among the first p variables: nonzero entries in the upper triangle of beta.
        for (int a = 0; a < this.p && !self.isInterrupted(); a++) {
            for (int b = a + 1; b < this.p && !self.isInterrupted(); b++) {
                if (FastMath.abs(this.params.beta.get(a, b)) > 0.0) {
                    final Node left = this.variables.get(a);
                    final Node right = this.variables.get(b);
                    if (!graph.isAdjacentTo(left, right)) {
                        graph.addUndirectedEdge(left, right);
                    }
                }
            }
        }

        // Mixed pairs: any nonzero entry in the theta block linking variable a to variable p + b.
        for (int a = 0; a < this.p && !self.isInterrupted(); a++) {
            for (int b = 0; b < this.q && !self.isInterrupted(); b++) {
                final double mass = this.params.theta.viewColumn(a)
                        .viewPart(this.lcumsum[b], this.l[b])
                        .copy()
                        .assign(Functions.abs)
                        .zSum();
                if (mass > 0.0) {
                    final Node left = this.variables.get(a);
                    final Node right = this.variables.get(this.p + b);
                    if (!graph.isAdjacentTo(left, right)) {
                        graph.addUndirectedEdge(left, right);
                    }
                }
            }
        }

        // Pairs among the last q variables: any nonzero entry in the corresponding phi block.
        for (int a = 0; a < this.q && !self.isInterrupted(); a++) {
            for (int b = a + 1; b < this.q && !self.isInterrupted(); b++) {
                final double mass = this.params.phi
                        .viewPart(this.lcumsum[a], this.lcumsum[b], this.l[a], this.l[b])
                        .copy()
                        .assign(Functions.abs)
                        .zSum();
                if (mass > 0.0) {
                    final Node left = this.variables.get(this.p + a);
                    final Node right = this.variables.get(this.p + b);
                    if (!graph.isAdjacentTo(left, right)) {
                        graph.addUndirectedEdge(left, right);
                    }
                }
            }
        }
        return graph;
    }

    /**
     * Builds a symmetric {@code (p+q) x (p+q)} matrix of edge strengths from the current parameters.
     *
     * <ul>
     *   <li>The top-left {@code p x p} block is {@code beta + beta'}.</li>
     *   <li>Entries {@code (i, p+j)} and {@code (p+j, i)} hold {@code Mgm.norm2} of the theta block
     *       for column i, rows {@code lcumsum[j] .. lcumsum[j]+l[j]-1}.</li>
     *   <li>Entries {@code (p+i, p+j)} and {@code (p+j, p+i)}, {@code i < j}, hold the Frobenius
     *       norm of the corresponding phi block.</li>
     * </ul>
     * The theta and phi scans stop early if the current thread is interrupted. When
     * {@code initVariables} is non-null, rows and columns are reordered to follow it via
     * {@code viewSelection}, so the returned matrix is then a view over the computed one.
     *
     * @return the (possibly reordered) edge-strength matrix
     */
    public DoubleMatrix2D adjMatFromMGM() {
        final int dim = this.p + this.q;
        final DoubleMatrix2D strengths = this.factory2D.make(dim, dim);

        final DoubleMatrix2D beta = this.params.beta;
        final DoubleMatrix2D betaSym = beta.copy().assign(this.alg.transpose(beta), Functions.plus);
        strengths.viewPart(0, 0, this.p, this.p).assign(betaSym);

        final Thread self = Thread.currentThread();
        for (int row = 0; row < this.p && !self.isInterrupted(); row++) {
            for (int cat = 0; cat < this.q && !self.isInterrupted(); cat++) {
                final DoubleMatrix1D block = this.params.theta.viewColumn(row).viewPart(this.lcumsum[cat], this.l[cat]);
                final double strength = Mgm.norm2(block);
                strengths.set(row, this.p + cat, strength);
                strengths.set(this.p + cat, row, strength);
            }
        }
        for (int a = 0; a < this.q && !self.isInterrupted(); a++) {
            for (int b = a + 1; b < this.q && !self.isInterrupted(); b++) {
                final DoubleMatrix2D block = this.params.phi.viewPart(this.lcumsum[a], this.lcumsum[b], this.l[a], this.l[b]);
                final double strength = this.alg.normF(block);
                strengths.set(this.p + a, this.p + b, strength);
                strengths.set(this.p + b, this.p + a, strength);
            }
        }

        if (this.initVariables == null) {
            return strengths;
        }
        // order[k] = position in `variables` of the k-th entry of `initVariables`.
        final int[] order = new int[dim];
        for (int k = 0; k < order.length; k++) {
            order[k] = this.variables.indexOf(this.initVariables.get(k));
        }
        return strengths.viewSelection(order, order);
    }

    /**
     * Learns edges with an iteration limit of 1000, records how long that took (the difference of
     * two {@code MillisecondTimes.timeMillis()} readings, excluding graph construction), and
     * returns the graph built from the fitted parameters.
     *
     * @return the learned undirected graph
     */
    @Override
    public Graph search() {
        final long began = MillisecondTimes.timeMillis();
        learnEdges(1000);
        final long finished = MillisecondTimes.timeMillis();
        this.elapsedTime = finished - began;
        return graphFromMGM();
    }

    /**
     * Returns the {@code elapsedTime} field, which {@link #search()} sets to the time spent in
     * {@code learnEdges} as measured by {@code MillisecondTimes.timeMillis()}.
     *
     * @return the recorded elapsed time
     */
    public long getElapsedTime() {
        return elapsedTime;
    }

    public static class MGMParams {
        private DoubleMatrix2D beta;
        private DoubleMatrix1D betad;
        private DoubleMatrix2D theta;
        private DoubleMatrix2D phi;
        private DoubleMatrix1D alpha1;
        private DoubleMatrix1D alpha2;

        public MGMParams() {
        }

        public MGMParams(DoubleMatrix2D beta, DoubleMatrix1D betad, DoubleMatrix2D theta, DoubleMatrix2D phi, DoubleMatrix1D alpha1, DoubleMatrix1D alpha2) {
            this.beta = beta;
            this.betad = betad;
            this.theta = theta;
            this.phi = phi;
            this.alpha1 = alpha1;
            this.alpha2 = alpha2;
        }

        public MGMParams(MGMParams parIn) {
            this.beta = parIn.beta.copy();
            this.betad = parIn.betad.copy();
            this.theta = parIn.theta.copy();
            this.phi = parIn.phi.copy();
            this.alpha1 = parIn.alpha1.copy();
            this.alpha2 = parIn.alpha2.copy();
        }

        public MGMParams(DoubleMatrix1D vec, int p, int ltot) {
            int[] lens = new int[]{p * p, p, p * ltot, ltot * ltot, p, ltot};
            int[] lenSums = new int[lens.length];
            lenSums[0] = lens[0];
            for (int i = 1; i < lenSums.length; ++i) {
                lenSums[i] = lens[i] + lenSums[i - 1];
            }
            if (vec.size() != lenSums[5]) {
                throw new IllegalArgumentException("Param vector dimension doesn't match: Found " + vec.size() + " need " + lenSums[5]);
            }
            this.beta = DoubleFactory2D.dense.make(vec.viewPart(0, lens[0]).toArray(), p);
            this.betad = vec.viewPart(lenSums[0], lens[1]).copy();
            this.theta = DoubleFactory2D.dense.make(vec.viewPart(lenSums[1], lens[2]).toArray(), ltot);
            this.phi = DoubleFactory2D.dense.make(vec.viewPart(lenSums[2], lens[3]).toArray(), ltot);
            this.alpha1 = vec.viewPart(lenSums[3], lens[4]).copy();
            this.alpha2 = vec.viewPart(lenSums[4], lens[5]).copy();
        }

        public String toString() {
            String outStr = "alpha1: " + this.alpha1.toString();
            outStr = outStr + "\nalpha2: " + this.alpha2.toString();
            outStr = outStr + "\nbeta: " + this.beta.toString();
            outStr = outStr + "\nbetad: " + this.betad.toString();
            outStr = outStr + "\ntheta: " + this.theta.toString();
            outStr = outStr + "\nphi: " + this.phi.toString();
            return outStr;
        }

        public DoubleMatrix1D getAlpha1() {
            return this.alpha1;
        }

        public void setAlpha1(DoubleMatrix1D alpha1) {
            this.alpha1 = alpha1;
        }

        public DoubleMatrix1D getAlpha2() {
            return this.alpha2;
        }

        public void setAlpha2(DoubleMatrix1D alpha2) {
            this.alpha2 = alpha2;
        }

        public DoubleMatrix1D getBetad() {
            return this.betad;
        }

        public void setBetad(DoubleMatrix1D betad) {
            this.betad = betad;
        }

        public DoubleMatrix2D getBeta() {
            return this.beta;
        }

        public void setBeta(DoubleMatrix2D beta) {
            this.beta = beta;
        }

        public DoubleMatrix2D getPhi() {
            return this.phi;
        }

        public void setPhi(DoubleMatrix2D phi) {
            this.phi = phi;
        }

        public DoubleMatrix2D getTheta() {
            return this.theta;
        }

        public void setTheta(DoubleMatrix2D theta) {
            this.theta = theta;
        }

        public DoubleMatrix1D toMatrix1D() {
            DoubleFactory1D fac = DoubleFactory1D.dense;
            int p = this.alpha1.size();
            int ltot = this.alpha2.size();
            int[] lens = new int[]{p * p, p, p * ltot, ltot * ltot, p, ltot};
            int[] lenSums = new int[lens.length];
            lenSums[0] = lens[0];
            for (int i = 1; i < lenSums.length; ++i) {
                lenSums[i] = lens[i] + lenSums[i - 1];
            }
            DoubleMatrix1D outVec = fac.make(p * p + p + p * ltot + ltot * ltot + p + ltot);
            outVec.viewPart(0, lens[0]).assign(Mgm.flatten(this.beta));
            outVec.viewPart(lenSums[0], lens[1]).assign(this.betad);
            outVec.viewPart(lenSums[1], lens[2]).assign(Mgm.flatten(this.theta));
            outVec.viewPart(lenSums[2], lens[3]).assign(Mgm.flatten(this.phi));
            outVec.viewPart(lenSums[3], lens[4]).assign(this.alpha1);
            outVec.viewPart(lenSums[4], lens[5]).assign(this.alpha2);
            return outVec;
        }

        public double[][] toVector() {
            double[][] outArr = new double[][]{this.toMatrix1D().toArray()};
            return outArr;
        }
    }
}

