/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.algcomparison.simulation;

import edu.cmu.tetrad.algcomparison.graph.RandomGraph;
import edu.cmu.tetrad.algcomparison.graph.SingleGraph;
import edu.cmu.tetrad.algcomparison.simulation.Simulation;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;
import org.apache.commons.math3.util.FastMath;

/**
 * Simulation that produces continuous data sets from graphs supplied by a
 * {@link RandomGraph}.
 *
 * <p>Every variable first receives additive noise drawn from the distribution
 * selected by the {@code simulationErrorType} parameter. A variable that has
 * parents in the graph additionally receives a linear term in a parent's
 * values plus a zero-mean Gaussian vector whose covariance is
 * {@code exp(-d)}, where {@code d} is the averaged squared distance between
 * the parents' values of two samples; the resulting column is then
 * standardized.
 */
public class NLSemSimulation implements Simulation {
    private static final long serialVersionUID = 23L;
    private final RandomGraph randomGraph;
    private List<DataSet> dataSets = new ArrayList<>();
    private List<Graph> graphs = new ArrayList<>();

    /**
     * Creates a simulation that draws its graphs from the given generator.
     *
     * @param graph the graph generator used by {@link #createData}
     */
    public NLSemSimulation(RandomGraph graph) {
        this.randomGraph = graph;
    }

    /**
     * Generates {@code numRuns} data sets (and their true graphs), replacing
     * any previously generated ones.
     *
     * <p>If the {@code seed} parameter is not {@code -1}, the shared
     * {@link RandomUtil} instance is re-seeded first. A fresh graph is
     * created for each run after the first when {@code differentGraphs} is
     * set; otherwise all runs share the first graph.
     *
     * @param parameters the simulation parameters (see {@link #getParameters()})
     * @param newModel   not used by this implementation
     */
    @Override
    public void createData(Parameters parameters, boolean newModel) {
        if (parameters.getLong("seed") != -1L) {
            RandomUtil.getInstance().setSeed(parameters.getLong("seed"));
        }
        Graph graph = this.randomGraph.createGraph(parameters);
        this.dataSets = new ArrayList<>();
        this.graphs = new ArrayList<>();
        int sampleSize = parameters.getInt("sampleSize");
        int numRuns = parameters.getInt("numRuns");
        for (int i = 0; i < numRuns; ++i) {
            System.out.println("Simulating dataset #" + (i + 1));
            if (parameters.getBoolean("differentGraphs") && i > 0) {
                graph = this.randomGraph.createGraph(parameters);
            }
            // Take the node list from the graph actually used for this run, and
            // size the data by it, so the matrix width always matches the
            // variable list handed to the data set.
            List<Node> variables = graph.getNodes();
            int numVars = variables.size();
            HashMap<Node, Integer> indices = new HashMap<>();
            for (int j = 0; j < numVars; ++j) {
                indices.put(variables.get(j), j);
            }
            this.graphs.add(graph);
            BlockRealMatrix data = new BlockRealMatrix(sampleSize, numVars);
            int errorType = parameters.getInt("simulationErrorType");
            // NOTE(review): columns are filled in node-list order, so a parent's
            // column is only populated if the parent precedes its child in
            // graph.getNodes() - confirm the graph generators guarantee that.
            for (int k = 0; k < numVars; ++k) {
                List<Node> parents = graph.getParents(variables.get(k));
                fillNoiseColumn(data, k, sampleSize, errorType, parameters);
                if (parents.isEmpty()) {
                    continue;
                }
                addParentEffect(data, k, parents, indices, sampleSize, parameters);
            }
            DataSet dataSet = toDataSet(data, variables);
            if (parameters.getBoolean("randomizeColumns")) {
                dataSet = DataTransforms.shuffleColumns(dataSet);
            }
            dataSet.setName(String.valueOf(i + 1));
            this.dataSets.add(dataSet);
        }
    }

    /**
     * Writes the additive noise for one variable into the given column.
     *
     * <p>Error types: 1 = normal with variance drawn uniformly from
     * [{@code varLow}, {@code varHigh}]; 2 = uniform on
     * [{@code simulationParam1}, {@code simulationParam2}]; 3 = exponential
     * with {@code simulationParam1}; 4 = Gumbel and 5 = gamma, both with
     * ({@code simulationParam1}, {@code simulationParam2}). Any other value
     * writes nothing, leaving the column untouched.
     */
    private static void fillNoiseColumn(BlockRealMatrix data, int column, int sampleSize,
            int errorType, Parameters parameters) {
        RandomUtil random = RandomUtil.getInstance();
        switch (errorType) {
            case 1: {
                double low = parameters.getDouble("varLow");
                double high = parameters.getDouble("varHigh");
                double std = FastMath.sqrt(random.nextUniform(low, high));
                for (int row = 0; row < sampleSize; ++row) {
                    data.setEntry(row, column, random.nextNormal(0.0, std));
                }
                break;
            }
            case 2: {
                double low = parameters.getDouble("simulationParam1");
                // The upper bound is the second parameter; reading the first one
                // twice made the interval degenerate.
                double high = parameters.getDouble("simulationParam2");
                for (int row = 0; row < sampleSize; ++row) {
                    data.setEntry(row, column, random.nextUniform(low, high));
                }
                break;
            }
            case 3: {
                double lambda = parameters.getDouble("simulationParam1");
                for (int row = 0; row < sampleSize; ++row) {
                    data.setEntry(row, column, random.nextExponential(lambda));
                }
                break;
            }
            case 4: {
                double mu = parameters.getDouble("simulationParam1");
                double beta = parameters.getDouble("simulationParam2");
                for (int row = 0; row < sampleSize; ++row) {
                    data.setEntry(row, column, random.nextGumbel(mu, beta));
                }
                break;
            }
            case 5: {
                double shape = parameters.getDouble("simulationParam1");
                double scale = parameters.getDouble("simulationParam2");
                for (int row = 0; row < sampleSize; ++row) {
                    data.setEntry(row, column, random.nextGamma(shape, scale));
                }
                break;
            }
            default:
                // NOTE(review): unknown error types silently produce no noise;
                // kept as-is for backward compatibility.
                break;
        }
    }

    /**
     * Adds the parents' contribution to one column and standardizes it.
     *
     * <p>The contribution is {@code mu + X}, where {@code X = U * sqrt(S) * N}
     * is built from the singular value decomposition of the sample-by-sample
     * covariance matrix and a vector {@code N} of standard normal draws.
     */
    private static void addParentEffect(BlockRealMatrix data, int column, List<Node> parents,
            HashMap<Node, Integer> indices, int sampleSize, Parameters parameters) {
        RandomUtil random = RandomUtil.getInstance();
        double low = parameters.getDouble("coefLow");
        double high = parameters.getDouble("coefHigh");
        double beta = random.nextUniform(low, high);

        double[] mu = new double[sampleSize];
        RealMatrix kernel = MatrixUtils.createRealMatrix(sampleSize, sampleSize);
        for (Node parent : parents) {
            int w = indices.get(parent);
            for (int j = 0; j < sampleSize; ++j) {
                // NOTE(review): this assignment overwrites mu for every parent, so
                // only the last parent contributes to the linear term. Confirm
                // whether an accumulation (+=) was intended.
                mu[j] = beta * data.getEntry(j, w);
                for (int l = 0; l < sampleSize; ++l) {
                    double diff = data.getEntry(j, w) - data.getEntry(l, w);
                    kernel.addToEntry(j, l, FastMath.pow(diff, 2) / (double) parents.size());
                }
            }
        }

        RealMatrix cov = MatrixUtils.createRealMatrix(sampleSize, sampleSize);
        for (int j = 0; j < sampleSize; ++j) {
            for (int l = 0; l < sampleSize; ++l) {
                cov.setEntry(j, l, FastMath.exp(-1.0 * kernel.getEntry(j, l)));
            }
        }

        SingularValueDecomposition svd = new SingularValueDecomposition(cov);
        RealMatrix S = svd.getS();
        RealMatrix N = MatrixUtils.createRealMatrix(sampleSize, 1);
        for (int j = 0; j < sampleSize; ++j) {
            S.setEntry(j, j, FastMath.sqrt(S.getEntry(j, j)));
            N.setEntry(j, 0, random.nextNormal(0.0, 1.0));
        }
        // Multiplication order is kept as (U * S) * N so seeded runs reproduce
        // the same floating-point values as before.
        double[] X = svd.getU().multiply(S).multiply(N).getColumn(0);
        for (int j = 0; j < sampleSize; ++j) {
            data.addToEntry(j, column, mu[j] + X[j]);
        }
        data.setColumn(column, StatUtils.standardizeData(data.getColumn(column)));
    }

    /**
     * Wraps the simulated matrix in a data set whose columns are continuous
     * variables carrying the names and node types of the graph's nodes.
     */
    private static DataSet toDataSet(BlockRealMatrix data, List<Node> variables) {
        List<Node> continuousVars = new ArrayList<>();
        for (Node node : variables) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            var.setNodeType(node.getNodeType());
            continuousVars.add(var);
        }
        return new BoxDataSet(new DoubleDataBox(data.getData()), continuousVars);
    }

    /**
     * Returns the data set generated for the given run.
     *
     * @param index zero-based run index
     */
    @Override
    public DataModel getDataModel(int index) {
        return this.dataSets.get(index);
    }

    /**
     * Returns the graph that the data set of the given run was generated from.
     *
     * @param index zero-based run index
     */
    @Override
    public Graph getTrueGraph(int index) {
        return this.graphs.get(index);
    }

    /** Returns a one-line description that includes the graph generator's description. */
    @Override
    public String getDescription() {
        return "Non-Linear, SEM simulation using " + this.randomGraph.getDescription();
    }

    /**
     * Lists the parameter names this simulation reads. The graph generator's
     * own parameters are included unless it is a {@link SingleGraph}.
     */
    @Override
    public List<String> getParameters() {
        List<String> parameters = new ArrayList<>();
        if (!(this.randomGraph instanceof SingleGraph)) {
            parameters.addAll(this.randomGraph.getParameters());
        }
        parameters.add("numRuns");
        parameters.add("differentGraphs");
        parameters.add("randomizeColumns");
        parameters.add("sampleSize");
        parameters.add("coefLow");
        parameters.add("coefHigh");
        parameters.add("simulationErrorType");
        parameters.add("varLow");
        parameters.add("varHigh");
        parameters.add("simulationParam1");
        parameters.add("simulationParam2");
        parameters.add("seed");
        return parameters;
    }

    /** Returns the number of data sets produced by the last {@link #createData} call. */
    @Override
    public int getNumDataModels() {
        return this.dataSets.size();
    }

    /** Returns {@link DataType#Continuous}; all simulated columns are continuous. */
    @Override
    public DataType getDataType() {
        return DataType.Continuous;
    }
}

