/*
 * Decompiled with CFR 0.152.
 */
package edu.pitt.csb.stability;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphSaveLoadUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.ForkJoinPoolInstance;
import edu.cmu.tetrad.util.MillisecondTimes;
import edu.cmu.tetrad.util.RandomUtil;
import edu.pitt.csb.mgm.Mgm;
import edu.pitt.csb.mgm.MixedUtils;
import edu.pitt.csb.stability.DataGraphSearch;
import edu.pitt.csb.stability.SearchWrappers;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveAction;

/**
 * Subsample-based edge-stability utilities: run a graph search on many row subsamples of a data set,
 * average the resulting skeleton matrices, and summarize the instability of that average.
 */
public class StabilityUtils {

    /**
     * Runs {@code gs} on {@code N} distinct subsamples of {@code b} rows drawn without replacement from
     * {@code data}, sequentially, and averages the resulting skeleton matrices.
     *
     * @param data the full data set; its column count sets the size of the result
     * @param gs   the graph search run on each subsample
     * @param N    number of subsamples; must be at least 1
     * @param b    number of rows per subsample
     * @return a {@code numVars x numVars} matrix holding, per entry, the mean over all subsamples of
     *         {@code MixedUtils.skeletonToMatrix(graph)}
     * @throws IllegalArgumentException if {@code N < 1} or the subsample parameters are invalid
     */
    public static DoubleMatrix2D StabilitySearch(DataSet data, DataGraphSearch gs, int N, int b) {
        requirePositiveSubsampleCount(N);
        int numVars = data.getNumColumns();
        DoubleMatrix2D thetaMat = DoubleFactory2D.dense.make(numVars, numVars, 0.0);
        int[][] samps = StabilityUtils.subSampleNoReplacement(data.getNumRows(), b, N);
        for (int[] samp : samps) {
            Graph g = gs.search(data.subsetRows(samp));
            thetaMat.assign(MixedUtils.skeletonToMatrix(g), Functions.plus);
        }
        thetaMat.assign(Functions.mult(1.0 / N));
        return thetaMat;
    }

    /**
     * Parallel version of {@link #StabilitySearch}: the subsamples are processed by fork/join tasks on the
     * shared Tetrad pool. Each task works on its own copy of the subsampled data and of the search object.
     *
     * @param data the full data set; its column count sets the size of the result
     * @param gs   the graph search; {@code gs.copy()} is called once per subsample
     * @param N    number of subsamples; must be at least 1
     * @param b    number of rows per subsample
     * @return the same kind of averaged matrix as {@link #StabilitySearch}
     * @throws IllegalArgumentException if {@code N < 1} or the subsample parameters are invalid
     */
    public static DoubleMatrix2D StabilitySearchPar(final DataSet data, final DataGraphSearch gs, int N, int b) {
        requirePositiveSubsampleCount(N);
        final int numVars = data.getNumColumns();
        final DoubleMatrix2D thetaMat = DoubleFactory2D.dense.make(numVars, numVars, 0.0);
        final int[][] samps = StabilityUtils.subSampleNoReplacement(data.getNumRows(), b, N);
        ForkJoinPool pool = ForkJoinPoolInstance.getInstance().getPool();
        // Maximum number of subsamples a leaf task processes sequentially.
        final int chunk = 2;

        class StabilityAction extends RecursiveAction {
            private final int from;
            private final int to;

            StabilityAction(int from, int to) {
                this.from = from;
                this.to = to;
            }

            @Override
            protected void compute() {
                if (this.to - this.from <= chunk) {
                    for (int s = this.from; s < this.to; ++s) {
                        DataSet dataSubSamp = data.subsetRows(samps[s]).copy();
                        DataGraphSearch curGs = gs.copy();
                        DoubleMatrix2D curAdj = MixedUtils.skeletonToMatrix(curGs.search(dataSubSamp));
                        // Lock the shared accumulator itself: each task is a distinct object, so a
                        // synchronized instance method would not exclude concurrent updates.
                        synchronized (thetaMat) {
                            thetaMat.assign(curAdj, Functions.plus);
                        }
                    }
                    return;
                }
                int mid = (this.from + this.to) >>> 1;
                ForkJoinTask.invokeAll(new StabilityAction(this.from, mid), new StabilityAction(mid, this.to));
            }
        }

        pool.invoke(new StabilityAction(0, N));
        thetaMat.assign(Functions.mult(1.0 / N));
        return thetaMat;
    }

    /** Rejects subsample counts that would make the averaged matrix undefined (division by zero). */
    private static void requirePositiveSubsampleCount(int N) {
        if (N < 1) {
            throw new IllegalArgumentException("Number of subsamples must be >= 1, got " + N);
        }
    }

    /**
     * Summarizes undirected edge instability of a stability matrix by variable-type block.
     * Each entry is transformed to {@code 2 * xi * (1 - xi)} and averaged over the entries above the
     * diagonal of each block.
     *
     * @param xi   square stability matrix, presumably selection frequencies in [0, 1] as produced by
     *             {@link #StabilitySearch} (confirm with callers)
     * @param vars variables indexing the rows/columns of {@code xi}
     * @return {@code D[0]} all pairs, {@code D[1]} continuous-continuous, {@code D[2]} continuous-discrete,
     *         {@code D[3]} discrete-discrete; blocks with no pairs yield NaN or infinity
     * @throws IllegalArgumentException if {@code xi} is not {@code vars.size()} square
     */
    public static double[] totalInstabilityUndir(DoubleMatrix2D xi, List<Node> vars) {
        if (vars.size() != xi.columns() || vars.size() != xi.rows()) {
            throw new IllegalArgumentException("stability mat must have same number of rows and columns as there are vars");
        }
        // xiu = xi * (xi - 1) * -2 = 2 * xi * (1 - xi)
        DoubleMatrix2D xiu = xi.copy().assign(xi.copy().assign(Functions.minus(1.0)), Functions.mult).assign(Functions.mult(-2.0));
        double[] D = new double[4];
        int[] discInds = MixedUtils.getDiscreteInds(vars);
        int[] contInds = MixedUtils.getContinuousInds(vars);
        int p = contInds.length;
        int q = discInds.length;
        // Mgm.upperTri(m, 1) is assumed to keep only the entries strictly above the diagonal (confirm),
        // matching the n(n-1)/2 pair counts used as denominators.
        double temp = Mgm.upperTri(xiu.copy(), 1).zSum();
        D[0] = temp / (((double) (p + q) - 1.0) * (double) (p + q) / 2.0);
        temp = Mgm.upperTri(xiu.viewSelection(contInds, contInds).copy(), 1).zSum();
        D[1] = temp / ((double) p * ((double) p - 1.0) / 2.0);
        // The continuous-discrete block has no diagonal, so every entry is one pair.
        temp = xiu.viewSelection(contInds, discInds).zSum();
        D[2] = temp / (double) (p * q);
        temp = Mgm.upperTri(xiu.viewSelection(discInds, discInds).copy(), 1).zSum();
        D[3] = temp / ((double) q * ((double) q - 1.0) / 2.0);
        return D;
    }

    /**
     * Summarizes a directed stability matrix by variable-type block: block sums divided by a pair count.
     *
     * @param xi   square stability matrix over {@code vars}
     * @param vars variables indexing the rows/columns of {@code xi}
     * @return {@code D[0]} whole matrix, {@code D[1]} continuous-continuous, {@code D[2]} continuous-row /
     *         discrete-column, {@code D[3]} discrete-discrete
     * @throws IllegalArgumentException if {@code xi} is not {@code vars.size()} square
     */
    public static double[] totalInstabilityDir(DoubleMatrix2D xi, List<Node> vars) {
        if (vars.size() != xi.columns() || vars.size() != xi.rows()) {
            throw new IllegalArgumentException("stability mat must have same number of rows and columns as there are vars");
        }
        double[] D = new double[4];
        int[] discInds = MixedUtils.getDiscreteInds(vars);
        int[] contInds = MixedUtils.getContinuousInds(vars);
        int p = contInds.length;
        int q = discInds.length;
        // NOTE(review): D[0] divides the full-matrix sum by n(n-1)/2 while D[1]/D[3] divide by p(p-1)/q(q-1),
        // and D[2] reads only the continuous-row/discrete-column block; confirm these choices are intended.
        D[0] = xi.zSum() / ((double) ((p + q - 1) * (p + q)) / 2.0);
        D[1] = xi.viewSelection(contInds, contInds).zSum() / (double) (p * (p - 1));
        D[2] = xi.viewSelection(contInds, discInds).zSum() / (double) (p * q);
        D[3] = xi.viewSelection(discInds, discInds).zSum() / (double) (q * (q - 1));
        return D;
    }

    /**
     * Draws {@code numSub} pairwise-distinct subsamples of {@code subSize} row indices from
     * {@code 0 .. sampSize-1}, each without replacement. Two subsamples count as the same when they contain
     * the same rows, regardless of order.
     *
     * @param sampSize number of rows available
     * @param subSize  rows per subsample; must be in {@code [1, sampSize]}
     * @param numSub   number of subsamples; must be non-negative and at most C(sampSize, subSize)
     * @return {@code numSub} index arrays of length {@code subSize}
     * @throws IllegalArgumentException if any argument is out of range (previously some of these cases
     *                                  threw IndexOutOfBoundsException or looped forever)
     */
    public static int[][] subSampleNoReplacement(int sampSize, int subSize, int numSub) {
        if (subSize < 1) {
            throw new IllegalArgumentException("Sample size must be > 0.");
        }
        if (subSize > sampSize) {
            throw new IllegalArgumentException(
                    "Subsample size " + subSize + " exceeds sample size " + sampSize + ".");
        }
        if (numSub < 0) {
            throw new IllegalArgumentException("Number of subsamples must be >= 0, got " + numSub);
        }
        // Without this check the rejection loop below would never terminate.
        if (countSubsetsCapped(sampSize, subSize, numSub) < numSub) {
            throw new IllegalArgumentException("Cannot draw " + numSub + " distinct subsamples of size "
                    + subSize + " from " + sampSize + " rows.");
        }
        int[][] sampMat = new int[numSub][];
        Set<List<Integer>> seen = new HashSet<>();
        for (int i = 0; i < numSub; ++i) {
            int[] curSamp;
            do {
                curSamp = StabilityUtils.subSampleIndices(sampSize, subSize);
            } while (!seen.add(sortedKey(curSamp)));
            sampMat[i] = curSamp;
        }
        return sampMat;
    }

    /** Returns an order-independent key so permutations of the same row set compare equal. */
    private static List<Integer> sortedKey(int[] samp) {
        int[] sorted = samp.clone();
        Arrays.sort(sorted);
        List<Integer> key = new ArrayList<>(sorted.length);
        for (int v : sorted) {
            key.add(v);
        }
        return key;
    }

    /**
     * Returns min(C(n, k), cap) without overflow, for 0 <= k <= n and cap >= 0.
     * The partial products C(n-k'+i, i) never decrease as i grows, so stopping early at the cap is exact.
     */
    private static long countSubsetsCapped(int n, int k, long cap) {
        int kk = Math.min(k, n - k);
        long c = 1;
        for (int i = 1; i <= kk; ++i) {
            // c < cap <= Integer.MAX_VALUE and the factor is <= n, so the product fits in a long;
            // the division is exact because c * (n - kk + i) == i * C(n - kk + i, i).
            c = c * (n - kk + i) / i;
            if (c >= cap) {
                return cap;
            }
        }
        return Math.min(c, cap);
    }

    /** Returns the first {@code subSize} entries of a random permutation of {@code 0 .. N-1}. */
    private static int[] subSampleIndices(int N, int subSize) {
        List<Integer> indices = new ArrayList<>(N);
        for (int i = 0; i < N; ++i) {
            indices.add(i);
        }
        RandomUtil.shuffle(indices);
        int[] samp = new int[subSize];
        for (int i = 0; i < subSize; ++i) {
            samp[i] = indices.get(i);
        }
        return samp;
    }

    /** Ad-hoc timing comparison of the sequential and parallel searches on a hard-coded local data set. */
    public static void main(String[] args) {
        String fn = "/Users/ajsedgewick/tetrad_mgm_runs/run2/networks/DAG_0_graph.txt";
        // Loaded but currently unused below.
        Graph trueGraph = GraphSaveLoadUtils.loadGraphTxt(new File(fn));
        DataSet ds;
        try {
            ds = MixedUtils.loadData("/Users/ajsedgewick/tetrad_mgm_runs/run2/data/", "DAG_0_data.txt");
        } catch (Exception e) {
            // Without data there is nothing to time; previously this fell through to a NullPointerException.
            System.err.println("Failed to load data: " + e);
            return;
        }
        double lambda = 0.1;
        SearchWrappers.MGMWrapper mgm = new SearchWrappers.MGMWrapper(lambda, lambda, lambda);
        long start = MillisecondTimes.timeMillis();
        DoubleMatrix2D xi = StabilityUtils.StabilitySearch(ds, mgm, 8, 200);
        long end = MillisecondTimes.timeMillis();
        System.out.println("Not parallel: " + (double) (end - start) / 1000.0);
        start = MillisecondTimes.timeMillis();
        DoubleMatrix2D xi2 = StabilityUtils.StabilitySearchPar(ds, mgm, 8, 200);
        end = MillisecondTimes.timeMillis();
        System.out.println("Parallel: " + (double) (end - start) / 1000.0);
        System.out.println(xi);
        System.out.println(xi2);
    }
}

