/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.algcomparison.simulation;

import edu.cmu.tetrad.algcomparison.graph.RandomGraph;
import edu.cmu.tetrad.algcomparison.simulation.Simulation;
import edu.cmu.tetrad.annotation.Experimental;
import edu.cmu.tetrad.data.AbstractVariable;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.MixedDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Paths;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;

@Experimental
public class LinearSineSimulation
implements Simulation {
    private static final long serialVersionUID = 23L;
    private final RandomGraph randomGraph;
    private List<DataSet> dataSets = new ArrayList<DataSet>();
    private List<Graph> graphs = new ArrayList<Graph>();
    private DataType dataType;
    private List<Node> shuffledOrder;
    private double interceptLow;
    private double interceptHigh = 1.0;
    private double linearLow = 0.5;
    private double linearHigh = 1.0;
    private double varLow = 0.5;
    private double varHigh = 0.5;
    private double betaLow = 1.0;
    private double betaHigh = 3.0;
    private double gammaLow = 0.5;
    private double gammaHigh = 1.5;

    public LinearSineSimulation(RandomGraph graph) {
        this.randomGraph = graph;
    }

    private static Graph makeMixedGraph(Graph g, Map<String, Integer> m) {
        List<Node> nodes = g.getNodes();
        for (int i = 0; i < nodes.size(); ++i) {
            AbstractVariable nNew;
            Node n = nodes.get(i);
            int nL = m.get(n.getName());
            if (nL > 0) {
                nNew = new DiscreteVariable(n.getName(), nL);
                nodes.set(i, nNew);
                continue;
            }
            nNew = new ContinuousVariable(n.getName());
            nodes.set(i, nNew);
        }
        EdgeListGraph outG = new EdgeListGraph(nodes);
        for (Edge e : g.getEdges()) {
            Node n1 = e.getNode1();
            Node n2 = e.getNode2();
            Edge eNew = new Edge(outG.getNode(n1.getName()), outG.getNode(n2.getName()), e.getEndpoint1(), e.getEndpoint2());
            outG.addEdge(eNew);
        }
        return outG;
    }

    @Override
    public void createData(Parameters parameters, boolean newModel) {
        if (parameters.getLong("seed") != -1L) {
            RandomUtil.getInstance().setSeed(parameters.getLong("seed"));
        }
        this.setInterceptLow(parameters.getDouble("interceptLow"));
        this.setInterceptHigh(parameters.getDouble("interceptHigh"));
        this.setLinearLow(parameters.getDouble("linearLow"));
        this.setLinearHigh(parameters.getDouble("linearHigh"));
        this.setVarLow(parameters.getDouble("varLow"));
        this.setVarHigh(parameters.getDouble("varHigh"));
        this.setBetaLow(parameters.getDouble("betaLow"));
        this.setBetaHigh(parameters.getDouble("betaHigh"));
        this.setGammaLow(parameters.getDouble("gammaLow"));
        this.setGammaHigh(parameters.getDouble("gammaHigh"));
        this.dataType = DataType.Continuous;
        this.shuffledOrder = null;
        Graph graph = this.randomGraph.createGraph(parameters);
        this.dataSets = new ArrayList<DataSet>();
        this.graphs = new ArrayList<Graph>();
        for (int i = 0; i < parameters.getInt("numRuns"); ++i) {
            System.out.println("Simulating dataset #" + (i + 1));
            if (parameters.getBoolean("differentGraphs") && i > 0) {
                graph = this.randomGraph.createGraph(parameters);
            }
            this.graphs.add(graph);
            DataSet dataSet = this.simulate(graph, parameters);
            if (parameters.getDouble("probRemoveColumn") > 0.0) {
                double aDouble = parameters.getDouble("probRemoveColumn");
                dataSet = DataTransforms.removeRandomColumns(dataSet, aDouble);
            }
            dataSet.setName("" + (i + 1));
            this.dataSets.add(dataSet);
        }
    }

    @Override
    public Graph getTrueGraph(int index) {
        return this.graphs.get(index);
    }

    @Override
    public DataModel getDataModel(int index) {
        return this.dataSets.get(index);
    }

    @Override
    public String getDescription() {
        return "Linear-sine simulation using " + this.randomGraph.getDescription();
    }

    @Override
    public List<String> getParameters() {
        List<String> parameters = this.randomGraph.getParameters();
        parameters.add("numRuns");
        parameters.add("probRemoveColumn");
        parameters.add("differentGraphs");
        parameters.add("sampleSize");
        parameters.add("interceptLow");
        parameters.add("interceptHigh");
        parameters.add("linearLow");
        parameters.add("linearHigh");
        parameters.add("varLow");
        parameters.add("varHigh");
        parameters.add("betaLow");
        parameters.add("betaHigh");
        parameters.add("gammaLow");
        parameters.add("gammaHigh");
        parameters.add("seed");
        return parameters;
    }

    @Override
    public int getNumDataModels() {
        return this.dataSets.size();
    }

    @Override
    public DataType getDataType() {
        return this.dataType;
    }

    private DataSet simulate(Graph G, Parameters parameters) {
        HashMap<String, Integer> nd = new HashMap<String, Integer>();
        List<Node> nodes = G.getNodes();
        RandomUtil.shuffle(nodes);
        if (this.shuffledOrder == null) {
            ArrayList<Node> shuffledNodes = new ArrayList<Node>(nodes);
            RandomUtil.shuffle(shuffledNodes);
            this.shuffledOrder = shuffledNodes;
        }
        for (int i = 0; i < nodes.size(); ++i) {
            nd.put(this.shuffledOrder.get(i).getName(), 0);
        }
        G = LinearSineSimulation.makeMixedGraph(G, nd);
        nodes = G.getNodes();
        BoxDataSet mixedData = new BoxDataSet(new MixedDataBox(nodes, parameters.getInt("sampleSize")), nodes);
        Paths paths = G.paths();
        List<Node> initialOrder = G.getNodes();
        List<Node> tierOrdering = paths.getValidOrder(initialOrder, true);
        int[] tiers = new int[tierOrdering.size()];
        for (int t = 0; t < tierOrdering.size(); ++t) {
            tiers[t] = nodes.indexOf(tierOrdering.get(t));
        }
        for (int mixedIndex : tiers) {
            ContinuousVariable child = (ContinuousVariable)nodes.get(mixedIndex);
            ArrayList<ContinuousVariable> continuousParents = new ArrayList<ContinuousVariable>();
            for (Node node : G.getParents(child)) {
                continuousParents.add((ContinuousVariable)node);
            }
            HashMap<String, double[]> intercept = new HashMap<String, double[]>();
            HashMap<String, double[]> linear = new HashMap<String, double[]>();
            HashMap<String, double[]> beta = new HashMap<String, double[]>();
            HashMap<String, double[]> gamma = new HashMap<String, double[]>();
            HashMap<String, double[]> bounds = new HashMap<String, double[]>();
            for (int j = 1; j <= continuousParents.size(); ++j) {
                String key = ((ContinuousVariable)continuousParents.get(j - 1)).toString();
                if (bounds.containsKey(key)) continue;
                double m0 = mixedData.getDouble(0, mixedData.getColumn((Node)continuousParents.get(j - 1)));
                double m1 = mixedData.getDouble(0, mixedData.getColumn((Node)continuousParents.get(j - 1)));
                for (int i = 1; i < parameters.getInt("sampleSize"); ++i) {
                    m0 = FastMath.min(m0, mixedData.getDouble(i, mixedData.getColumn((Node)continuousParents.get(j - 1))));
                    m1 = FastMath.max(m1, mixedData.getDouble(i, mixedData.getColumn((Node)continuousParents.get(j - 1))));
                }
                double[] temp = new double[]{m0, (m1 - m0) / 2.0, m1};
                bounds.put(key, temp);
            }
            double mean = 0.0;
            double var = 0.0;
            for (int i = 0; i < parameters.getInt("sampleSize"); ++i) {
                int j;
                double[] parents = new double[continuousParents.size()];
                double value = 0.0;
                String key = "";
                for (int j2 = 1; j2 <= continuousParents.size(); ++j2) {
                    parents[j2 - 1] = mixedData.getDouble(i, mixedData.getColumn((Node)continuousParents.get(j2 - 1)));
                }
                if (!intercept.containsKey("")) {
                    double[] interceptCoefficients = new double[]{(double)this.randSign() * RandomUtil.getInstance().nextUniform(this.interceptLow, this.interceptHigh)};
                    intercept.put("", interceptCoefficients);
                }
                if (!linear.containsKey("") && !continuousParents.isEmpty()) {
                    double[] linearCoefficients = new double[parents.length];
                    for (j = 0; j < parents.length; ++j) {
                        linearCoefficients[j] = (double)this.randSign() * RandomUtil.getInstance().nextUniform(this.linearLow, this.linearHigh);
                    }
                    linear.put("", linearCoefficients);
                }
                if (!beta.containsKey("") && !continuousParents.isEmpty()) {
                    double[] betaCoefficients = new double[parents.length];
                    for (j = 0; j < parents.length; ++j) {
                        betaCoefficients[j] = (double)this.randSign() * RandomUtil.getInstance().nextUniform(this.betaLow, this.betaHigh);
                    }
                    beta.put("", betaCoefficients);
                }
                if (!gamma.containsKey("") && !continuousParents.isEmpty()) {
                    double[] gammaCoefficients = new double[parents.length];
                    for (j = 0; j < parents.length; ++j) {
                        String key2 = ((ContinuousVariable)continuousParents.get(j)).toString();
                        gammaCoefficients[j] = (((double[])bounds.get(key2))[1] - ((double[])bounds.get(key2))[0]) / (Math.PI * 2 * RandomUtil.getInstance().nextUniform(this.gammaLow, this.gammaHigh));
                    }
                    gamma.put("", gammaCoefficients);
                }
                value += ((double[])intercept.get(""))[0];
                if (!continuousParents.isEmpty()) {
                    for (int x = 0; x < parents.length; ++x) {
                        value += ((double[])linear.get(""))[x] * parents[x] + ((double[])beta.get(""))[x] * FastMath.sin(parents[x] / ((double[])gamma.get(""))[x]);
                    }
                }
                mixedData.setDouble(i, mixedIndex, value);
                mean += value;
                var += FastMath.pow(value, 2);
            }
            if (continuousParents.size() == 0) {
                var = 1.0;
            } else {
                var /= (double)mixedData.getNumRows();
                var -= FastMath.pow(mean /= (double)mixedData.getNumRows(), 2);
                var = FastMath.sqrt(var);
            }
            double noiseVar = RandomUtil.getInstance().nextUniform(this.varLow, this.varHigh);
            for (int i = 0; i < parameters.getInt("sampleSize"); ++i) {
                mixedData.setDouble(i, mixedIndex, mixedData.getDouble(i, mixedIndex) + var * RandomUtil.getInstance().nextNormal(0.0, noiseVar));
            }
        }
        return mixedData;
    }

    public void setInterceptLow(double interceptLow) {
        this.interceptLow = interceptLow;
    }

    public void setInterceptHigh(double interceptHigh) {
        this.interceptHigh = interceptHigh;
    }

    public void setLinearLow(double linearLow) {
        this.linearLow = linearLow;
    }

    public void setLinearHigh(double linearHigh) {
        this.linearHigh = linearHigh;
    }

    public void setVarLow(double varLow) {
        this.varLow = varLow;
    }

    public void setVarHigh(double varHigh) {
        this.varHigh = varHigh;
    }

    public void setBetaLow(double betaLow) {
        this.betaLow = betaLow;
    }

    public void setBetaHigh(double betaHigh) {
        this.betaHigh = betaHigh;
    }

    public void setGammaLow(double gammaLow) {
        this.gammaLow = gammaLow;
    }

    public void setGammaHigh(double gammaHigh) {
        this.gammaHigh = gammaHigh;
    }

    private int randSign() {
        return RandomUtil.getInstance().nextInt(2) * 2 - 1;
    }
}

