/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.algcomparison.simulation;

import edu.cmu.tetrad.algcomparison.graph.RandomGraph;
import edu.cmu.tetrad.algcomparison.graph.SingleGraph;
import edu.cmu.tetrad.algcomparison.simulation.Simulation;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.GeneralizedSemIm;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.sem.TemplateExpander;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.RandomUtil;
import java.text.ParseException;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Simulation that draws continuous data from a generalized structural equation model
 * ({@link GeneralizedSemPm} / {@link GeneralizedSemIm}).
 *
 * <p>The model used for each run comes from one of three sources, in order of precedence:
 * <ol>
 *   <li>a caller-supplied {@link GeneralizedSemIm}, which is used as-is for every run;</li>
 *   <li>a caller-supplied {@link GeneralizedSemPm}, from which a new IM is built for every run;</li>
 *   <li>otherwise, a new PM is built from each generated graph using the
 *       {@code generalSemFunctionTemplateMeasured}, {@code generalSemErrorTemplate} and
 *       {@code generalSemParameterTemplate} parameters, and a new IM is built from it.</li>
 * </ol>
 */
public class GeneralSemSimulation implements Simulation {
    private static final long serialVersionUID = 23L;

    /** Source of the true graph(s); a {@link SingleGraph} when a PM or IM was supplied. */
    private final RandomGraph randomGraph;
    /** Caller-supplied PM, or {@code null} to build a PM from each simulated graph. */
    private final GeneralizedSemPm pm;
    /** Caller-supplied IM, or {@code null} to build a fresh IM for each run. */
    private final GeneralizedSemIm im;
    private List<DataSet> dataSets = new ArrayList<>();
    private List<Graph> graphs = new ArrayList<>();
    private List<GeneralizedSemIm> ims = new ArrayList<>();

    /**
     * Creates a simulation whose graphs come from {@code graph} and whose models are built
     * from the generalized-SEM template parameters.
     *
     * @param graph the graph generator used for each run.
     */
    public GeneralSemSimulation(RandomGraph graph) {
        this.randomGraph = graph;
        this.pm = null;
        this.im = null;
    }

    /**
     * Creates a simulation over the fixed graph of {@code pm}; a new IM is built from
     * {@code pm} for every run.
     *
     * @param pm the generalized SEM PM to simulate from.
     */
    public GeneralSemSimulation(GeneralizedSemPm pm) {
        SemGraph graph = pm.getGraph();
        graph.setShowErrorTerms(false);
        this.randomGraph = new SingleGraph(graph);
        this.pm = pm;
        this.im = null;
    }

    /**
     * Creates a simulation that draws every run from the given IM.
     *
     * @param im the generalized SEM IM to simulate from.
     */
    public GeneralSemSimulation(GeneralizedSemIm im) {
        SemGraph graph = im.getSemPm().getGraph();
        graph.setShowErrorTerms(false);
        this.randomGraph = new SingleGraph(graph);
        this.im = im;
        this.pm = im.getGeneralizedSemPm();
        this.ims.add(im);
    }

    /**
     * Simulates {@code numRuns} data sets, replacing any previously simulated data, graphs and IMs.
     *
     * <p>Reads {@code seed} (reseeds {@link RandomUtil} unless it is -1), {@code numRuns},
     * {@code differentGraphs}, {@code sampleSize}, {@code guaranteeIid}, {@code standardize},
     * {@code measurementVariance}, {@code randomizeColumns} and {@code probRemoveColumn}.
     * Data sets are named "1".."numRuns" and restricted to measured variables.
     *
     * @param parameters simulation parameters.
     * @param newModel   not used by this implementation.
     */
    @Override
    public void createData(Parameters parameters, boolean newModel) {
        long seed = parameters.getLong("seed");
        if (seed != -1L) {
            RandomUtil.getInstance().setSeed(seed);
        }

        Graph graph = this.randomGraph.createGraph(parameters);
        this.dataSets = new ArrayList<>();
        this.graphs = new ArrayList<>();
        this.ims = new ArrayList<>();

        for (int i = 0; i < parameters.getInt("numRuns"); ++i) {
            System.out.println("Simulating dataset #" + (i + 1));

            if (parameters.getBoolean("differentGraphs") && i > 0) {
                graph = this.randomGraph.createGraph(parameters);
            }
            this.graphs.add(graph);

            DataSet dataSet = this.simulate(graph, parameters);

            if (parameters.getBoolean("standardize")) {
                dataSet = DataTransforms.standardizeData(dataSet);
            }

            double variance = parameters.getDouble("measurementVariance");
            if (variance > 0.0) {
                addMeasurementNoise(dataSet, variance);
            }

            if (parameters.getBoolean("randomizeColumns")) {
                dataSet = DataTransforms.shuffleColumns(dataSet);
            }

            double probRemoveColumn = parameters.getDouble("probRemoveColumn");
            if (probRemoveColumn > 0.0) {
                dataSet = DataTransforms.removeRandomColumns(dataSet, probRemoveColumn);
            }

            dataSet.setName("" + (i + 1));
            this.dataSets.add(DataTransforms.restrictToMeasured(dataSet));
        }
    }

    /**
     * Adds noise drawn from {@code RandomUtil.nextNormal(0, sqrt(variance))} to every cell of
     * {@code dataSet}, in place.
     */
    private static void addMeasurementNoise(DataSet dataSet, double variance) {
        // Loop-invariant: compute the standard deviation once rather than per cell.
        double sd = FastMath.sqrt(variance);
        for (int row = 0; row < dataSet.getNumRows(); ++row) {
            for (int col = 0; col < dataSet.getNumColumns(); ++col) {
                double noise = RandomUtil.getInstance().nextNormal(0.0, sd);
                dataSet.setDouble(row, col, dataSet.getDouble(row, col) + noise);
            }
        }
    }

    /**
     * Simulates one data set for {@code graph} and records the IM used in {@link #getIms()}.
     *
     * <p>A caller-supplied IM is used unchanged. Otherwise a new IM is built from the
     * caller-supplied PM, or from a PM built for this particular {@code graph}. The PM is never
     * cached, so each graph is simulated from its own structure.
     */
    private synchronized DataSet simulate(Graph graph, Parameters parameters) {
        GeneralizedSemIm runIm;
        if (this.im != null) {
            System.out.println(this.pm);
            runIm = this.im;
        } else {
            GeneralizedSemPm runPm = (this.pm != null) ? this.pm : this.getPm(graph, parameters);
            System.out.println(runPm);
            runIm = new GeneralizedSemIm(runPm);
        }
        runIm.setGuaranteeIid(parameters.getBoolean("guaranteeIid"));
        System.out.println(runIm);
        this.ims.add(runIm);
        return runIm.simulateData(parameters.getInt("sampleSize"), true);
    }

    @Override
    public Graph getTrueGraph(int index) {
        return this.graphs.get(index);
    }

    @Override
    public int getNumDataModels() {
        return this.dataSets.size();
    }

    @Override
    public DataModel getDataModel(int index) {
        return this.dataSets.get(index);
    }

    @Override
    public DataType getDataType() {
        return DataType.Continuous;
    }

    @Override
    public String getDescription() {
        return "Nonlinear, non-Gaussian SEM simulation using " + this.randomGraph.getDescription();
    }

    /**
     * Returns the names of the parameters this simulation reads. The graph generator's
     * parameters are included only when the graph is not fixed. The generalized-SEM template
     * parameters are included only when no PM was supplied.
     */
    @Override
    public List<String> getParameters() {
        List<String> parameters = new ArrayList<>();
        if (!(this.randomGraph instanceof SingleGraph)) {
            parameters.addAll(this.randomGraph.getParameters());
        }
        if (this.pm == null) {
            parameters.addAll(GeneralizedSemPm.getParameterNames());
        }
        parameters.add("numRuns");
        parameters.add("probRemoveColumn");
        parameters.add("differentGraphs");
        parameters.add("sampleSize");
        parameters.add("guaranteeIid");
        parameters.add("seed");
        // Also read by createData, so they must be advertised too.
        parameters.add("standardize");
        parameters.add("measurementVariance");
        parameters.add("randomizeColumns");
        return parameters;
    }

    /**
     * Builds a PM for {@code graph} by expanding the generalized-SEM templates from
     * {@code parameters} for every variable node, error node and model parameter.
     *
     * @throws IllegalArgumentException if a template cannot be parsed; the
     *         {@link ParseException} is attached as the cause.
     */
    private GeneralizedSemPm getPm(Graph graph, Parameters parameters) {
        GeneralizedSemPm newPm = new GeneralizedSemPm(graph);
        List<Node> variableNodes = newPm.getVariableNodes();
        List<Node> errorNodes = newPm.getErrorNodes();

        String variablesTemplate = parameters.getString("generalSemFunctionTemplateMeasured");
        String errorsTemplate = parameters.getString("generalSemErrorTemplate");
        String parametersTemplate = parameters.getString("generalSemParameterTemplate");
        TemplateExpander expander = TemplateExpander.getInstance();

        try {
            for (Node node : variableNodes) {
                newPm.setNodeExpression(node, expander.expandTemplate(variablesTemplate, newPm, node));
            }
            for (Node node : errorNodes) {
                newPm.setNodeExpression(node, expander.expandTemplate(errorsTemplate, newPm, node));
            }
            for (String parameter : newPm.getParameters()) {
                newPm.setParameterExpression(parameter, parametersTemplate);
            }
            newPm.setVariablesTemplate(variablesTemplate);
            newPm.setErrorsTemplate(errorsTemplate);
            newPm.setParametersTemplate(parametersTemplate);
        } catch (ParseException e) {
            // A half-configured PM would silently simulate a different model than requested.
            throw new IllegalArgumentException(
                    "Could not parse generalized SEM template: " + e.getMessage(), e);
        }
        return newPm;
    }

    /**
     * Returns the IMs used by the most recent {@link #createData} call, one per run, or the
     * supplied IM if {@code createData} has not been called yet.
     */
    public List<GeneralizedSemIm> getIms() {
        return this.ims;
    }
}

