/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.SparseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.dist.Normal;
import edu.cmu.tetrad.util.dist.Split;
import edu.cmu.tetrad.util.dist.Uniform;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Simulates continuous data from a linear structural equation model whose structure is
 * given by a {@link Graph}.
 *
 * <p>Model parameters are drawn once, at construction time:
 * <ul>
 *   <li>one coefficient per directed edge between non-error nodes, drawn from
 *       {@code new Split(0.5, 1.5)};</li>
 *   <li>one error variance per node, drawn from {@code new Uniform(1.0, 3.0)} and stored on
 *       the diagonal of a sparse error-covariance matrix;</li>
 *   <li>one mean per node, drawn from {@code new Normal(-1.0, 1.0)}.</li>
 * </ul>
 * Each call to a {@code simulateDataAcyclic} method then draws fresh samples using these
 * fixed parameters.
 *
 * <p>Column {@code i} of simulated data corresponds to the {@code i}-th node of
 * {@code graph.getNodes()} as it was at construction time. Columns are matched by position,
 * not by name.
 */
public final class LargeSemSimulator {
    // NOTE(review): this class does not implement Serializable, so this constant (and the
    // transient modifier below) currently has no effect; kept for compatibility.
    static final long serialVersionUID = 23L;

    /** Edge coefficients; entry (tail, head) is the coefficient of tail in head's equation. */
    private final DoubleMatrix2D edgeCoef;
    /** Error covariance matrix; only the diagonal (error variances) is populated. */
    private final DoubleMatrix2D errCovar;
    /** Additive mean of each node's equation, indexed like {@link #variableNodes}. */
    private final double[] variableMeans;
    /** Lazily created; see {@link #getAlgebra()}. */
    private transient Algebra algebra;
    /** Snapshot of the graph's nodes; a node's position here is its matrix index and column. */
    private final List<Node> variableNodes;
    private final Graph graph;

    /**
     * Creates a simulator for {@code graph} and draws its random model parameters.
     *
     * @param graph the model structure; every edge between non-error nodes must be directed
     * @throws NullPointerException if {@code graph} is null
     * @throws IllegalArgumentException if an edge between non-error nodes is not a directed
     *     edge between nodes of {@code graph}
     */
    public LargeSemSimulator(Graph graph) {
        if (graph == null) {
            throw new NullPointerException("Graph must not be null.");
        }
        this.graph = graph;
        // Snapshot so node indices stay aligned with the fixed-size matrices allocated below.
        this.variableNodes = new ArrayList<Node>(graph.getNodes());
        int size = this.variableNodes.size();
        this.edgeCoef = new SparseDoubleMatrix2D(size, size);
        this.errCovar = new SparseDoubleMatrix2D(size, size);
        this.variableMeans = new double[size];

        Split edgeCoefDist = new Split(0.5, 1.5);
        Uniform errorCovarDist = new Uniform(1.0, 3.0);
        Normal meanDist = new Normal(-1.0, 1.0);

        for (Edge edge : graph.getEdges()) {
            // Edges touching explicit error nodes carry no structural coefficient.
            if (isErrorNode(edge.getNode1()) || isErrorNode(edge.getNode2())) {
                continue;
            }
            Node tail = Edges.getDirectedEdgeTail(edge);
            Node head = Edges.getDirectedEdgeHead(edge);
            int tailIndex = this.variableNodes.indexOf(tail);
            int headIndex = this.variableNodes.indexOf(head);
            if (tailIndex < 0 || headIndex < 0) {
                throw new IllegalArgumentException(
                        "Expected a directed edge between nodes of the graph: " + edge);
            }
            this.edgeCoef.set(tailIndex, headIndex, edgeCoefDist.nextRandom());
        }

        // Draw order (variance, then mean, per node) matches the original implementation.
        for (int i = 0; i < size; ++i) {
            this.errCovar.set(i, i, errorCovarDist.nextRandom());
            this.variableMeans[i] = meanDist.nextRandom();
        }
    }

    /**
     * Simulates a new data set with one continuous column per graph node.
     *
     * @param sampleSize number of rows to simulate
     * @return a new data set whose column {@code i} is named after the {@code i}-th node
     */
    public DataSet simulateDataAcyclic(int sampleSize) {
        List<Node> variableNodes = this.getVariableNodes();
        List<Node> variables = new ArrayList<Node>(variableNodes.size());
        for (Node node : variableNodes) {
            variables.add(new ContinuousVariable(node.getName()));
        }
        ColtDataSet dataSet = new ColtDataSet(sampleSize, variables);
        this.constructSimulation(variableNodes, sampleSize, dataSet);
        return dataSet;
    }

    /**
     * Fills an existing data set with simulated values, in place.
     *
     * <p>Columns {@code 0 .. n-1}, where {@code n} is the number of graph nodes, are
     * overwritten for every row. Columns are matched to nodes by position.
     *
     * @param dataSet data set to overwrite; every column must be continuous
     * @return the same {@code dataSet} instance
     * @throws IllegalArgumentException if {@code dataSet} has fewer columns than the graph has
     *     nodes, or contains a non-continuous column
     */
    public DataSet simulateDataAcyclic(DataSet dataSet) {
        List<Node> variableNodes = this.getVariableNodes();
        int numColumns = dataSet.getNumColumns();
        if (numColumns < variableNodes.size()) {
            throw new IllegalArgumentException("Data set has " + numColumns
                    + " columns but the graph has " + variableNodes.size() + " nodes.");
        }
        for (int i = 0; i < numColumns; ++i) {
            Object variable = dataSet.getVariable(i);
            if (!(variable instanceof ContinuousVariable)) {
                throw new IllegalArgumentException(
                        "Column " + i + " is not a continuous variable: " + variable);
            }
        }
        this.constructSimulation(variableNodes, dataSet.getNumRows(), dataSet);
        return dataSet;
    }

    /**
     * Writes {@code sampleSize} simulated rows into {@code dataSet}.
     *
     * <p>Nodes are visited in the graph's tier ordering, so each parent's value for a row is
     * written before any child reads it. Each value is
     * {@code sqrt(errorVariance) * N(0,1) + sum(coef * parentValue) + mean}.
     */
    private void constructSimulation(List<Node> variableNodes, int sampleSize, DataSet dataSet) {
        Graph graph = this.getGraph();
        List<Node> tierOrdering = graph.getTierOrdering();
        int numNodes = variableNodes.size();

        // Column index of each node, listed in tier order.
        int[] tierIndices = new int[numNodes];
        for (int i = 0; i < numNodes; ++i) {
            tierIndices[i] = variableNodes.indexOf(tierOrdering.get(i));
        }

        int[][] parentIndices = new int[numNodes][];
        for (int i = 0; i < numNodes; ++i) {
            parentIndices[i] = nonErrorParentIndices(graph, variableNodes, variableNodes.get(i));
        }

        // The diagonal of errCovar holds variances; a standard normal must be scaled by the
        // standard deviation to obtain noise with that variance.
        double[] errorStdDevs = new double[numNodes];
        for (int i = 0; i < numNodes; ++i) {
            errorStdDevs[i] = Math.sqrt(this.errCovar.get(i, i));
        }

        RandomUtil random = RandomUtil.getInstance();
        for (int row = 0; row < sampleSize; ++row) {
            for (int col : tierIndices) {
                double value = random.nextNormal(0.0, 1.0) * errorStdDevs[col];
                for (int parent : parentIndices[col]) {
                    value += dataSet.getDouble(row, parent) * this.edgeCoef.get(parent, col);
                }
                dataSet.setDouble(row, col, value + this.variableMeans[col]);
            }
        }
    }

    /**
     * Returns the indices in {@code variableNodes} of {@code node}'s parents, skipping
     * error-type nodes. Does not modify the list returned by {@code graph.getParents}.
     */
    private static int[] nonErrorParentIndices(Graph graph, List<Node> variableNodes, Node node) {
        List<Node> parents = graph.getParents(node);
        int[] indices = new int[parents.size()];
        int count = 0;
        for (Node parent : parents) {
            if (!isErrorNode(parent)) {
                indices[count++] = variableNodes.indexOf(parent);
            }
        }
        return Arrays.copyOf(indices, count);
    }

    /** Returns true if {@code node} is an explicit error term. */
    private static boolean isErrorNode(Node node) {
        return node.getNodeType() == NodeType.ERROR;
    }

    /** Returns a shared Colt {@link Algebra} instance, creating it on first use. */
    public Algebra getAlgebra() {
        if (this.algebra == null) {
            this.algebra = new Algebra();
        }
        return this.algebra;
    }

    /** Returns the node list whose positions define matrix indices and data columns. */
    private List<Node> getVariableNodes() {
        return this.variableNodes;
    }

    /** Returns the graph this simulator was built from. */
    public Graph getGraph() {
        return this.graph;
    }
}

