/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.Knowledge;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Paths;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.graph.TimeLagGraph;
import edu.cmu.tetrad.util.ForkJoinPoolInstance;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.Vector;
import edu.cmu.tetrad.util.dist.Split;
import edu.cmu.tetrad.util.dist.Uniform;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.concurrent.RecursiveTask;
import org.apache.commons.collections4.map.HashedMap;
import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.UniformRealDistribution;
import org.apache.commons.math3.random.Well1024a;
import org.apache.commons.math3.util.FastMath;

public final class LargeScaleSimulation {
    static final long serialVersionUID = 23L;
    private int[][] parents;
    private double[][] coefs;
    private double[] errorVars;
    private double[] means;
    private final List<Node> variableNodes;
    private final Graph graph;
    private double coefLow;
    private double coefHigh = 1.0;
    private double varLow = 1.0;
    private double varHigh = 3.0;
    private double meanLow;
    private double meanHigh;
    private PrintStream out = System.out;
    private int[] tierIndices;
    private boolean verbose;
    private long seed = new Date().getTime();
    private boolean alreadySetUp;
    private boolean includePositiveCoefs = true;
    private boolean includeNegativeCoefs = true;
    private boolean errorsNormal = true;
    private double selfLoopCoef;

    public LargeScaleSimulation(Graph graph) {
        this.graph = graph;
        this.variableNodes = graph.getNodes();
        if (graph instanceof SemGraph) {
            ((SemGraph)graph).setShowErrorTerms(false);
        }
        Paths paths = graph.paths();
        List<Node> initialOrder = graph.getNodes();
        List<Node> causalOrdering = paths.validOrder(initialOrder, true);
        this.tierIndices = new int[causalOrdering.size()];
        for (int i = 0; i < this.tierIndices.length; ++i) {
            this.tierIndices[i] = this.variableNodes.indexOf(causalOrdering.get(i));
        }
    }

    public LargeScaleSimulation(Graph graph, List<Node> nodes, int[] tierIndices) {
        if (graph == null) {
            throw new NullPointerException("Graph must not be null.");
        }
        this.graph = GraphUtils.replaceNodes(graph, nodes);
        this.variableNodes = nodes;
        this.tierIndices = tierIndices;
        if (graph instanceof SemGraph) {
            ((SemGraph)graph).setShowErrorTerms(false);
        }
    }

    public DataSet simulateDataRecursive(int sampleSize) {
        if (this.tierIndices == null) {
            List<Node> nodes = this.graph.getNodes();
            this.tierIndices = new int[nodes.size()];
            for (int j = 0; j < nodes.size(); ++j) {
                this.tierIndices[j] = j;
            }
        }
        int size = this.variableNodes.size();
        this.setupModel(size);
        if (this.graph instanceof TimeLagGraph) {
            sampleSize += 200;
        }
        double[][] all = new double[this.variableNodes.size()][sampleSize];
        int chunk = sampleSize / Runtime.getRuntime().availableProcessors();
        class SimulateTask
        extends RecursiveTask<Boolean> {
            private final int from;
            private final int to;
            private final double[][] all;
            private final int chunk;

            public SimulateTask(int from, int to, double[][] all, int chunk) {
                this.from = from;
                this.to = to;
                this.all = all;
                this.chunk = chunk;
            }

            @Override
            protected Boolean compute() {
                if (this.from - this.to > this.chunk) {
                    int mid = this.from + this.to / 2;
                    SimulateTask left = new SimulateTask(this.from, mid, this.all, this.chunk);
                    SimulateTask right = new SimulateTask(mid, this.to, this.all, this.chunk);
                    left.fork();
                    right.compute();
                    left.join();
                } else {
                    for (int i = this.from; i < this.to; ++i) {
                        NormalDistribution normal = new NormalDistribution(new Well1024a(++LargeScaleSimulation.this.seed), 0.0, 1.0);
                        normal.sample();
                        if (LargeScaleSimulation.this.verbose && (i + 1) % 50 == 0) {
                            System.out.println("Simulating " + (i + 1));
                        }
                        for (int col : LargeScaleSimulation.this.tierIndices) {
                            double value = normal.sample() * FastMath.sqrt(LargeScaleSimulation.this.errorVars[col]);
                            for (int j = 0; j < LargeScaleSimulation.this.parents[col].length; ++j) {
                                value += this.all[LargeScaleSimulation.this.parents[col][j]][i] * LargeScaleSimulation.this.coefs[col][j];
                            }
                            this.all[col][i] = value += LargeScaleSimulation.this.means[col];
                        }
                    }
                }
                return true;
            }
        }
        ForkJoinPoolInstance.getInstance().getPool().invoke(new SimulateTask(0, sampleSize, all, chunk));
        if (this.graph instanceof TimeLagGraph) {
            int[] rem = new int[200];
            for (int i = 0; i < 200; ++i) {
                rem[i] = i;
            }
            BoxDataSet dat = new BoxDataSet(new VerticalDoubleDataBox(all), this.variableNodes);
            dat.removeRows(rem);
            return dat;
        }
        return new BoxDataSet(new VerticalDoubleDataBox(all), this.variableNodes);
    }

    public DataSet simulateDataReducedForm(int sampleSize) {
        if (sampleSize < 1) {
            throw new IllegalArgumentException("Sample size must be >= 1: " + sampleSize);
        }
        int size = this.variableNodes.size();
        this.setupModel(size);
        NormalDistribution normal = new NormalDistribution(new Well1024a(++this.seed), 0.0, 1.0);
        Matrix B = new Matrix(this.getCoefficientMatrix());
        Matrix iMinusBInv = Matrix.identity(B.rows()).minus(B).inverse();
        double[][] all = new double[this.variableNodes.size()][sampleSize];
        for (int row = 0; row < sampleSize; ++row) {
            Vector e = new Vector(B.rows());
            for (int j = 0; j < e.size(); ++j) {
                e.set(j, normal.sample() * FastMath.sqrt(this.errorVars[j]));
            }
            Vector x = iMinusBInv.times(e);
            for (int j = 0; j < x.size(); ++j) {
                all[j][row] = x.get(j);
            }
        }
        ArrayList<Node> continuousVars = new ArrayList<Node>();
        for (Node node : this.getVariableNodes()) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            var.setNodeType(node.getNodeType());
            continuousVars.add(var);
        }
        BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars);
        return DataUtils.restrictToMeasured(boxDataSet);
    }

    public DataSet simulateDataFisher(int sampleSize) {
        return this.simulateDataFisher(this.getSoCalledPoissonShocks(sampleSize), 50, 1.0E-5);
    }

    public DataSet simulateDataFisher(double[][] shocks, int intervalBetweenShocks, double epsilon) {
        if (intervalBetweenShocks < 1) {
            throw new IllegalArgumentException("Interval between shocks must be >= 1: " + intervalBetweenShocks);
        }
        if (epsilon <= 0.0) {
            throw new IllegalArgumentException("Epsilon must be > 0: " + epsilon);
        }
        int size = this.variableNodes.size();
        if (shocks[0].length != size) {
            throw new IllegalArgumentException("The number of columns in the shocks matrix does not equal the number of variables.");
        }
        this.setupModel(size);
        double[] t1 = new double[this.variableNodes.size()];
        double[] t2 = new double[this.variableNodes.size()];
        double[][] all = new double[this.variableNodes.size()][shocks.length];
        for (int row = 0; row < shocks.length; ++row) {
            System.arraycopy(shocks[row], 0, t2, 0, t1.length);
            for (int i = 0; i < intervalBetweenShocks; ++i) {
                for (int j = 0; j < t1.length; ++j) {
                    for (int k = 0; k < this.parents[j].length; ++k) {
                        int n = j;
                        t2[n] = t2[n] + t1[this.parents[j][k]] * this.coefs[j][k];
                    }
                }
                boolean converged = true;
                for (int j = 0; j < t1.length; ++j) {
                    if (!(FastMath.abs(t2[j] - t1[j]) > epsilon)) continue;
                    converged = false;
                    break;
                }
                double[] t3 = t1;
                t1 = t2;
                t2 = t3;
                if (converged) break;
            }
            for (int j = 0; j < t1.length; ++j) {
                all[j][row] = t2[j];
            }
        }
        ArrayList<Node> continuousVars = new ArrayList<Node>();
        for (Node node : this.getVariableNodes()) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            var.setNodeType(node.getNodeType());
            continuousVars.add(var);
        }
        BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars);
        return DataUtils.restrictToMeasured(boxDataSet);
    }

    public DataSet simulateDataFisher(int intervalBetweenShocks, int intervalBetweenRecordings, int sampleSize, double epsilon, boolean saveLatentVars) {
        if (intervalBetweenShocks < 1) {
            throw new IllegalArgumentException("Interval between shocks must be >= 1: " + intervalBetweenShocks);
        }
        if (epsilon <= 0.0) {
            throw new IllegalArgumentException("Epsilon must be > 0: " + epsilon);
        }
        int size = this.variableNodes.size();
        this.setupModel(size);
        double[] t1 = new double[this.variableNodes.size()];
        double[] t2 = new double[this.variableNodes.size()];
        double[][] all = new double[this.variableNodes.size()][sampleSize];
        int s = 0;
        int shockIndex = 0;
        int recordingIndex = 0;
        double[] shock = this.getUncorrelatedShocks(1)[0];
        System.arraycopy(shock, 0, t1, 0, t1.length);
        while (s < sampleSize) {
            int j;
            if (++recordingIndex % intervalBetweenRecordings == 0) {
                for (j = 0; j < t1.length; ++j) {
                    double[] dArray = all[j];
                    int n = s;
                    dArray[n] = dArray[n] + t1[j];
                }
                ++s;
            }
            if (++shockIndex % intervalBetweenShocks == 0) {
                shock = this.getUncorrelatedShocks(1)[0];
                for (j = 0; j < t1.length; ++j) {
                    int n = j;
                    t1[n] = t1[n] + shock[j];
                }
            }
            for (j = 0; j < t1.length; ++j) {
                t2[j] = shock[j];
                int n = j;
                t2[n] = t2[n] + this.getSelfLoopCoef() * t1[j];
                for (int k = 0; k < this.parents[j].length; ++k) {
                    int n2 = j;
                    t2[n2] = t2[n2] + t1[this.parents[j][k]] * this.coefs[j][k];
                }
            }
            double[] t3 = t1;
            t1 = t2;
            t2 = t3;
        }
        ArrayList<Node> continuousVars = new ArrayList<Node>();
        for (Node node : this.getVariableNodes()) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            var.setNodeType(node.getNodeType());
            continuousVars.add(var);
        }
        BoxDataSet boxDataSet = new BoxDataSet(new VerticalDoubleDataBox(all), continuousVars);
        return saveLatentVars ? boxDataSet : DataUtils.restrictToMeasured(boxDataSet);
    }

    private void setupModel(int size) {
        int i;
        if (this.alreadySetUp) {
            return;
        }
        HashedMap<Node, Integer> nodesHash = new HashedMap<Node, Integer>();
        for (i = 0; i < this.variableNodes.size(); ++i) {
            nodesHash.put(this.variableNodes.get(i), i);
        }
        this.parents = new int[size][];
        this.coefs = new double[size][];
        this.errorVars = new double[size];
        this.means = new double[size];
        for (i = 0; i < size; ++i) {
            this.parents[i] = new int[0];
            this.coefs[i] = new double[0];
        }
        Split edgeCoefDist = new Split(this.coefLow, this.coefHigh);
        Uniform errorCovarDist = new Uniform(this.varLow, this.varHigh);
        Uniform meanDist = new Uniform(this.meanLow, this.meanHigh);
        for (Edge edge : this.graph.getEdges()) {
            Node tail = Edges.getDirectedEdgeTail(edge);
            Node head = Edges.getDirectedEdgeHead(edge);
            int _tail = (Integer)nodesHash.get(tail);
            int _head = (Integer)nodesHash.get(head);
            int[] parents = this.parents[_head];
            int[] newParents = new int[parents.length + 1];
            System.arraycopy(parents, 0, newParents, 0, parents.length);
            newParents[newParents.length - 1] = _tail;
            double[] coefs = this.coefs[_head];
            double[] newCoefs = new double[coefs.length + 1];
            System.arraycopy(coefs, 0, newCoefs, 0, coefs.length);
            double coef = edgeCoefDist.nextRandom();
            if (this.includePositiveCoefs && !this.includeNegativeCoefs) {
                coef = FastMath.abs(coef);
            } else if (!this.includePositiveCoefs && this.includeNegativeCoefs) {
                coef = -FastMath.abs(coef);
            } else if (!this.includePositiveCoefs) {
                coef = 0.0;
            }
            newCoefs[newCoefs.length - 1] = coef;
            this.parents[_head] = newParents;
            this.coefs[_head] = newCoefs;
        }
        if (this.graph instanceof TimeLagGraph) {
            TimeLagGraph lagGraph = (TimeLagGraph)this.graph;
            Knowledge knowledge = this.getKnowledge(lagGraph);
            List<Node> lag0 = lagGraph.getLag0Nodes();
            for (Node y : lag0) {
                List<Node> _parents = lagGraph.getParents(y);
                for (Node x : _parents) {
                    List<List<Node>> similar = this.returnSimilarPairs(x, y, knowledge);
                    int _x = this.variableNodes.indexOf(x);
                    int _y = this.variableNodes.indexOf(y);
                    double first = Double.NaN;
                    for (int i2 = 0; i2 < this.parents[_y].length; ++i2) {
                        if (_x != this.parents[_y][i2]) continue;
                        first = this.coefs[_y][i2];
                    }
                    for (int j = 0; j < similar.get(0).size(); ++j) {
                        int _xx = this.variableNodes.indexOf(similar.get(0).get(j));
                        int _yy = this.variableNodes.indexOf(similar.get(1).get(j));
                        for (int i3 = 0; i3 < this.parents[_yy].length; ++i3) {
                            if (_xx != this.parents[_yy][i3]) continue;
                            this.coefs[_yy][i3] = first;
                        }
                    }
                }
            }
        }
        for (int i4 = 0; i4 < size; ++i4) {
            this.errorVars[i4] = errorCovarDist.nextRandom();
            this.means[i4] = meanDist.nextRandom();
        }
        this.alreadySetUp = true;
    }

    public Graph getGraph() {
        return this.graph;
    }

    public void setCoefRange(double coefLow, double coefHigh) {
        this.coefLow = coefLow;
        this.coefHigh = coefHigh;
    }

    public void setVarRange(double varLow, double varHigh) {
        this.varLow = varLow;
        this.varHigh = varHigh;
    }

    public void setMeanRange(double meanLow, double meanHigh) {
        this.meanLow = meanLow;
        this.meanHigh = meanHigh;
    }

    public void setOut(PrintStream out) {
        this.out = out;
    }

    public PrintStream getOut() {
        return this.out;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    public double[][] getCoefficientMatrix() {
        double[][] c = new double[this.coefs.length][this.coefs.length];
        for (int i = 0; i < this.coefs.length; ++i) {
            for (int j = 0; j < this.coefs[i].length; ++j) {
                c[i][this.parents[i][j]] = this.coefs[i][j];
            }
        }
        return c;
    }

    public List<Node> getVariableNodes() {
        return this.variableNodes;
    }

    private List<List<Node>> returnSimilarPairs(Node x, Node y, Knowledge knowledge) {
        int i;
        System.out.println("$$$$$ Entering returnSimilarPairs method with x,y = " + x + ", " + y);
        if (x.getName().equals("time") || y.getName().equals("time")) {
            return new ArrayList<List<Node>>();
        }
        int ntiers = knowledge.getNumTiers();
        int indx_tier = knowledge.isInWhichTier(x);
        int indy_tier = knowledge.isInWhichTier(y);
        int tier_diff = FastMath.max(indx_tier, indy_tier) - FastMath.min(indx_tier, indy_tier);
        int indx_comp = -1;
        int indy_comp = -1;
        List<String> tier_x = knowledge.getTier(indx_tier);
        List<String> tier_y = knowledge.getTier(indy_tier);
        for (i = 0; i < tier_x.size(); ++i) {
            if (!this.getNameNoLag(x.getName()).equals(this.getNameNoLag(tier_x.get(i)))) continue;
            indx_comp = i;
            break;
        }
        for (i = 0; i < tier_y.size(); ++i) {
            if (!this.getNameNoLag(y.getName()).equals(this.getNameNoLag(tier_y.get(i)))) continue;
            indy_comp = i;
            break;
        }
        System.out.println("original independence: " + x + " and " + y);
        if (indx_comp == -1) {
            System.out.println("WARNING: indx_comp = -1!!!! ");
        }
        if (indy_comp == -1) {
            System.out.println("WARNING: indy_comp = -1!!!! ");
        }
        ArrayList<Node> simListX = new ArrayList<Node>();
        ArrayList<Node> simListY = new ArrayList<Node>();
        for (i = 0; i < ntiers - tier_diff; ++i) {
            List<String> tmp_tier2;
            List<String> tmp_tier1;
            if (knowledge.getTier(i).size() == 1) continue;
            if (indx_tier >= indy_tier) {
                tmp_tier1 = knowledge.getTier(i + tier_diff);
                tmp_tier2 = knowledge.getTier(i);
            } else {
                tmp_tier1 = knowledge.getTier(i);
                tmp_tier2 = knowledge.getTier(i + tier_diff);
            }
            String A = tmp_tier1.get(indx_comp);
            String B = tmp_tier2.get(indy_comp);
            if (A.equals(B) || A.equals(tier_x.get(indx_comp)) && B.equals(tier_y.get(indy_comp)) || B.equals(tier_x.get(indx_comp)) && A.equals(tier_y.get(indy_comp))) continue;
            Node x1 = this.graph.getNode(A);
            Node y1 = this.graph.getNode(B);
            System.out.println("Adding pair to simList = " + x1 + " and " + y1);
            simListX.add(x1);
            simListY.add(y1);
        }
        ArrayList<List<Node>> pairList = new ArrayList<List<Node>>();
        pairList.add(simListX);
        pairList.add(simListY);
        return pairList;
    }

    public String getNameNoLag(Object obj) {
        String tempS = obj.toString();
        if (tempS.indexOf(58) == -1) {
            return tempS;
        }
        return tempS.substring(0, tempS.indexOf(58));
    }

    public Knowledge getKnowledge(Graph graph) {
        String tmp;
        int lag;
        String varName;
        List<Node> variables = graph.getNodes();
        ArrayList<Integer> laglist = new ArrayList<Integer>();
        Knowledge knowledge = new Knowledge();
        for (Node node : variables) {
            varName = node.getName();
            if (varName.indexOf(58) == -1) {
                lag = 0;
            } else {
                tmp = varName.substring(varName.indexOf(58) + 1);
                lag = Integer.parseInt(tmp);
            }
            laglist.add(lag);
        }
        int numLags = (Integer)Collections.max(laglist);
        variables.sort((o1, o2) -> {
            String name1 = this.getNameNoLag(o1);
            String name2 = this.getNameNoLag(o2);
            String prefix1 = LargeScaleSimulation.getPrefix(name1);
            String prefix2 = LargeScaleSimulation.getPrefix(name2);
            int index1 = LargeScaleSimulation.getIndex(name1);
            int index2 = LargeScaleSimulation.getIndex(name2);
            if (LargeScaleSimulation.getLag(o1.getName()) == LargeScaleSimulation.getLag(o2.getName())) {
                if (prefix1.compareTo(prefix2) == 0) {
                    return Integer.compare(index1, index2);
                }
                return prefix1.compareTo(prefix2);
            }
            return LargeScaleSimulation.getLag(o1.getName()) - LargeScaleSimulation.getLag(o2.getName());
        });
        for (Node node : variables) {
            varName = node.getName();
            if (varName.indexOf(58) == -1) {
                lag = 0;
            } else {
                tmp = varName.substring(varName.indexOf(58) + 1);
                lag = Integer.parseInt(tmp);
            }
            knowledge.addToTier(numLags - lag, node.getName());
        }
        return knowledge;
    }

    public static String getPrefix(String s) {
        return s.substring(0, 1);
    }

    public static int getIndex(String s) {
        int y = 0;
        for (int i = s.length() - 1; i >= 0; --i) {
            try {
                y = Integer.parseInt(s.substring(i));
                continue;
            }
            catch (NumberFormatException e) {
                return y;
            }
        }
        throw new IllegalArgumentException("Not integer suffix.");
    }

    public static int getLag(String s) {
        if (s.indexOf(58) == -1) {
            return 0;
        }
        String tmp = s.substring(s.indexOf(58) + 1);
        return Integer.parseInt(tmp);
    }

    public double[][] getUncorrelatedShocks(int sampleSize) {
        NormalDistribution distribution = new NormalDistribution(new Well1024a(++this.seed), 0.0, 1.0);
        UniformRealDistribution varDist = new UniformRealDistribution(this.varLow, this.varHigh);
        int numVars = this.variableNodes.size();
        this.setupModel(numVars);
        double[][] shocks = new double[sampleSize][numVars];
        for (int j = 0; j < numVars; ++j) {
            double sd = FastMath.sqrt(((AbstractRealDistribution)varDist).sample());
            for (int i = 0; i < sampleSize; ++i) {
                double sample = ((AbstractRealDistribution)distribution).sample();
                sample *= sd;
                if (!this.errorsNormal) {
                    sample *= sample;
                }
                shocks[i][j] = sample;
            }
        }
        return shocks;
    }

    public double[][] getSoCalledPoissonShocks(int sampleSize) {
        int numVars = this.variableNodes.size();
        this.setupModel(numVars);
        double[][] shocks = new double[sampleSize][numVars];
        for (int j = 0; j < numVars; ++j) {
            int v = 0;
            for (int i = 0; i < sampleSize; ++i) {
                if (RandomUtil.getInstance().nextDouble() < 0.3) {
                    v = 1 - v;
                }
                shocks[i][j] = (double)v + RandomUtil.getInstance().nextNormal(0.0, 0.1);
            }
        }
        return shocks;
    }

    public void setIncludePositiveCoefs(boolean includePositiveCoefs) {
        this.includePositiveCoefs = includePositiveCoefs;
    }

    public void setIncludeNegativeCoefs(boolean includeNegativeCoefs) {
        this.includeNegativeCoefs = includeNegativeCoefs;
    }

    public void setErrorsNormal(boolean errorsNormal) {
        this.errorsNormal = errorsNormal;
    }

    public double getSelfLoopCoef() {
        return this.selfLoopCoef;
    }

    public void setSelfLoopCoef(double selfLoopCoef) {
        this.selfLoopCoef = selfLoopCoef;
    }
}

