/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.algcomparison.simulation;

import edu.cmu.tetrad.algcomparison.graph.RandomGraph;
import edu.cmu.tetrad.algcomparison.simulation.Simulation;
import edu.cmu.tetrad.bayes.BayesPm;
import edu.cmu.tetrad.bayes.MlBayesIm;
import edu.cmu.tetrad.data.AbstractVariable;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.DataType;
import edu.cmu.tetrad.data.DiscreteVariable;
import edu.cmu.tetrad.data.Discretizer;
import edu.cmu.tetrad.data.MixedDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.EdgeListGraph;
import edu.cmu.tetrad.graph.Graph;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.Paths;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.RandomUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**
 * Simulates data from a conditional Gaussian model over randomly generated graphs.
 *
 * <p>In each run, a fraction {@code percentDiscrete} (in percent) of the variables is made
 * discrete, with between {@code minCategories} and {@code maxCategories} categories. The rest
 * are continuous.
 *
 * <p>Discrete variables are sampled from an {@link MlBayesIm}. In that model, every continuous
 * parent of a discrete variable is replaced by an "ersatz" discrete variable, obtained by
 * equal-frequency discretization of that parent's simulated column.
 *
 * <p>Continuous variables are sampled as a Gaussian error plus a linear function of their
 * continuous parents plus a mean. The variance, mean and coefficients are drawn uniformly at
 * random, separately for every combination of the discrete parents' values.
 */
public class ConditionalGaussianSimulation
implements Simulation {
    private static final long serialVersionUID = 23L;

    /** Generator for the true graphs. */
    private final RandomGraph randomGraph;
    /** Data sets from the most recent {@link #createData} call, one per run. */
    private List<DataSet> dataSets = new ArrayList<>();
    /** True graph for each entry of {@link #dataSets}, at the same index. */
    private List<Graph> graphs = new ArrayList<>();
    /** Type of the simulated data; assigned by {@link #createData}. */
    private DataType dataType;
    /**
     * Random node order whose leading {@code percentDiscrete}% of entries become discrete.
     * It is reset at the start of each {@link #createData} call, chosen during that call's first
     * run, and reused (by node name) for the remaining runs.
     */
    private List<Node> shuffledOrder;
    private double varLow = 1.0;
    private double varHigh = 3.0;
    private double coefLow = 0.05;
    private double coefHigh = 1.5;
    private boolean coefSymmetric = true;
    private double meanLow = -1.0;
    private double meanHigh = 1.0;

    /**
     * Creates a simulation over graphs produced by the given generator.
     *
     * @param graph the random graph generator used for every run
     */
    public ConditionalGaussianSimulation(RandomGraph graph) {
        this.randomGraph = graph;
    }

    /**
     * Returns a copy of {@code g} in which each node is replaced by a discrete or continuous
     * variable of the same name and node type.
     *
     * @param g the source graph; it is not modified
     * @param m number of categories per node name; a value greater than 0 makes the node
     *          discrete with that many categories, and 0 makes it continuous
     * @return a new graph with the same edges over the replacement variables
     */
    private static Graph makeMixedGraph(Graph g, Map<String, Integer> m) {
        List<Node> mixedNodes = new ArrayList<>();
        for (Node n : g.getNodes()) {
            int numCategories = m.get(n.getName());
            AbstractVariable replacement = numCategories > 0
                    ? new DiscreteVariable(n.getName(), numCategories)
                    : new ContinuousVariable(n.getName());
            replacement.setNodeType(n.getNodeType());
            mixedNodes.add(replacement);
        }
        EdgeListGraph outG = new EdgeListGraph(mixedNodes);
        for (Edge e : g.getEdges()) {
            Node n1 = outG.getNode(e.getNode1().getName());
            Node n2 = outG.getNode(e.getNode2().getName());
            outG.addEdge(new Edge(n1, n2, e.getEndpoint1(), e.getEndpoint2()));
        }
        return outG;
    }

    /**
     * Simulates {@code numRuns} data sets and stores them together with their true graphs.
     *
     * @param parameters simulation parameters; see {@link #getParameters()}
     * @param newModel   currently ignored by this implementation
     * @throws IllegalArgumentException if {@code dataType} is "discrete" and
     *                                  {@code percentDiscrete} is not 100, or if it is
     *                                  "continuous" and {@code percentDiscrete} is not 0
     */
    @Override
    public void createData(Parameters parameters, boolean newModel) {
        long seed = parameters.getLong("seed");
        if (seed != -1L) {
            RandomUtil.getInstance().setSeed(seed);
        }
        this.setVarLow(parameters.getDouble("varLow"));
        this.setVarHigh(parameters.getDouble("varHigh"));
        this.setCoefLow(parameters.getDouble("coefLow"));
        this.setCoefHigh(parameters.getDouble("coefHigh"));
        this.setCoefSymmetric(parameters.getBoolean("covSymmetric"));
        this.setMeanLow(parameters.getDouble("meanLow"));
        this.setMeanHigh(parameters.getDouble("meanHigh"));

        double percentDiscrete = parameters.getDouble("percentDiscrete");
        String dataTypeName = parameters.getString("dataType");
        boolean discrete = "discrete".equals(dataTypeName);
        boolean continuous = "continuous".equals(dataTypeName);
        if (discrete && percentDiscrete != 100.0) {
            throw new IllegalArgumentException(
                    "To simulate discrete data, 'percentDiscrete' must be set to 100.0.");
        }
        if (continuous && percentDiscrete != 0.0) {
            throw new IllegalArgumentException(
                    "To simulate continuous data, 'percentDiscrete' must be set to 0.0.");
        }
        // Any other setting yields a mix of discrete and continuous columns. Assign Mixed
        // explicitly so a value left over from an earlier call is never reported.
        if (discrete) {
            this.dataType = DataType.Discrete;
        } else if (continuous) {
            this.dataType = DataType.Continuous;
        } else {
            this.dataType = DataType.Mixed;
        }

        this.shuffledOrder = null;
        int numRuns = parameters.getInt("numRuns");
        boolean differentGraphs = parameters.getBoolean("differentGraphs");
        boolean randomizeColumns = parameters.getBoolean("randomizeColumns");
        double probRemoveColumn = parameters.getDouble("probRemoveColumn");

        Graph graph = this.randomGraph.createGraph(parameters);
        this.dataSets = new ArrayList<>();
        this.graphs = new ArrayList<>();
        for (int i = 0; i < numRuns; ++i) {
            System.out.println("Simulating dataset #" + (i + 1));
            if (differentGraphs && i > 0) {
                graph = this.randomGraph.createGraph(parameters);
            }
            this.graphs.add(graph);
            DataSet dataSet = this.simulate(graph, parameters);
            dataSet.setName(String.valueOf(i + 1));
            if (randomizeColumns) {
                dataSet = DataTransforms.shuffleColumns(dataSet);
            }
            if (probRemoveColumn > 0.0) {
                dataSet = DataTransforms.removeRandomColumns(dataSet, probRemoveColumn);
            }
            this.dataSets.add(dataSet);
        }
    }

    /**
     * Returns the true graph for the given run.
     *
     * @param index run index, from 0 to {@link #getNumDataModels()} - 1
     * @return the graph the data set at {@code index} was simulated from
     */
    @Override
    public Graph getTrueGraph(int index) {
        return this.graphs.get(index);
    }

    /**
     * Returns the simulated data set for the given run.
     *
     * @param index run index, from 0 to {@link #getNumDataModels()} - 1
     * @return the data set produced for that run
     */
    @Override
    public DataModel getDataModel(int index) {
        return this.dataSets.get(index);
    }

    /**
     * Returns a one-line description of the simulation.
     *
     * @return a description naming the graph generator
     */
    @Override
    public String getDescription() {
        return "Conditional Gaussian simulation using " + this.randomGraph.getDescription();
    }

    /**
     * Returns the names of the parameters this simulation reads, including those of the graph
     * generator.
     *
     * @return a new, caller-owned list of parameter names
     */
    @Override
    public List<String> getParameters() {
        // Copy first: the generator's list may be shared or unmodifiable, and appending to it
        // directly would mutate it (or throw).
        List<String> parameters = new ArrayList<>(this.randomGraph.getParameters());
        parameters.add("minCategories");
        parameters.add("maxCategories");
        parameters.add("percentDiscrete");
        parameters.add("numRuns");
        parameters.add("probRemoveColumn");
        parameters.add("differentGraphs");
        parameters.add("sampleSize");
        parameters.add("varLow");
        parameters.add("varHigh");
        parameters.add("coefLow");
        parameters.add("coefHigh");
        parameters.add("covSymmetric");
        parameters.add("meanLow");
        parameters.add("meanHigh");
        parameters.add("saveLatentVars");
        parameters.add("randomizeColumns");
        parameters.add("seed");
        return parameters;
    }

    /**
     * Returns the number of simulated data sets.
     *
     * @return the number of data sets from the most recent {@link #createData} call
     */
    @Override
    public int getNumDataModels() {
        return this.dataSets.size();
    }

    /**
     * Returns the type of the simulated data.
     *
     * @return the data type set by the most recent {@link #createData} call, or {@code null}
     *         before the first call
     */
    @Override
    public DataType getDataType() {
        return this.dataType;
    }

    /**
     * Simulates one data set over the given graph.
     *
     * @param graph      the true graph
     * @param parameters simulation parameters
     * @return the simulated data, restricted to measured variables unless
     *         {@code saveLatentVars} is set
     */
    private DataSet simulate(Graph graph, Parameters parameters) {
        int sampleSize = parameters.getInt("sampleSize");

        Graph mixedGraph = makeMixedGraph(graph, this.assignNumCategories(graph, parameters));
        List<Node> nodes = mixedGraph.getNodes();
        BoxDataSet mixedData = new BoxDataSet(new MixedDataBox(nodes, sampleSize), nodes);

        List<Node> continuousNodes = new ArrayList<>();
        List<Node> discreteNodes = new ArrayList<>();
        for (Node node : nodes) {
            if (node instanceof ContinuousVariable) {
                continuousNodes.add(node);
            } else {
                discreteNodes.add(node);
            }
        }
        Graph discreteGraph = mixedGraph.subgraph(discreteNodes);
        Graph continuousGraph = mixedGraph.subgraph(continuousNodes);
        Map<String, ContinuousVariable> ersatzToOriginal =
                addErsatzParents(mixedGraph, discreteNodes, discreteGraph);

        BayesPm bayesPm = new BayesPm(discreteGraph);
        // NOTE(review): 1 is presumably MlBayesIm's "random" initialization constant, inlined by
        // the decompiler. Confirm this against the MlBayesIm version in use.
        MlBayesIm bayesIm = new MlBayesIm(bayesPm, 1);
        SemPm semPm = new SemPm(continuousGraph);
        Map<Combination, Double> paramValues = new HashMap<>();
        Map<Integer, double[]> breakpointsMap = new HashMap<>();

        // Each column is sampled in full before later columns read it, so this order must list
        // parents before their children (presumably what getValidOrder provides; verify).
        Paths paths = mixedGraph.paths();
        List<Node> order = paths.getValidOrder(mixedGraph.getNodes(), true);

        // With no rows there is nothing to sample, and no empty columns should be discretized.
        if (sampleSize > 0) {
            for (Node ordered : order) {
                int mixedIndex = nodes.indexOf(ordered);
                Node node = nodes.get(mixedIndex);
                if (node instanceof DiscreteVariable) {
                    this.sampleDiscreteColumn(mixedData, node, mixedIndex, bayesIm,
                            ersatzToOriginal, breakpointsMap, sampleSize);
                } else {
                    this.sampleContinuousColumn(mixedData, mixedGraph, nodes, node, mixedIndex,
                            semPm, paramValues, sampleSize);
                }
            }
        }

        boolean saveLatentVars = parameters.getBoolean("saveLatentVars");
        return saveLatentVars ? mixedData : DataTransforms.restrictToMeasured(mixedData);
    }

    /** Decides, by node name, how many categories each node gets (0 means continuous). */
    private Map<String, Integer> assignNumCategories(Graph graph, Parameters parameters) {
        List<Node> nodes = graph.getNodes();
        RandomUtil.shuffle(nodes);
        if (this.shuffledOrder == null) {
            List<Node> shuffledNodes = new ArrayList<>(nodes);
            RandomUtil.shuffle(shuffledNodes);
            this.shuffledOrder = shuffledNodes;
        }

        double discreteCutoff = nodes.size() * parameters.getDouble("percentDiscrete") * 0.01;
        Map<String, Integer> numCategories = new HashMap<>();
        for (int i = 0; i < nodes.size(); ++i) {
            int categories = 0;
            if (i < discreteCutoff) {
                int minNumCategories = parameters.getInt("minCategories");
                int maxNumCategories = parameters.getInt("maxCategories");
                categories = this.pickNumCategories(minNumCategories, maxNumCategories);
            }
            numCategories.put(this.shuffledOrder.get(i).getName(), categories);
        }
        return numCategories;
    }

    /**
     * Adds to {@code discreteGraph} a discrete stand-in ("ersatz") node for every continuous
     * parent of a discrete node, with an edge from the stand-in to that child.
     *
     * @return map from each stand-in's name to the continuous variable it represents
     */
    private static Map<String, ContinuousVariable> addErsatzParents(
            Graph mixedGraph, List<Node> discreteNodes, Graph discreteGraph) {
        Map<ContinuousVariable, DiscreteVariable> ersatzFor = new HashMap<>();
        Map<String, ContinuousVariable> originalFor = new HashMap<>();
        for (Node child : discreteNodes) {
            for (Node parent : mixedGraph.getParents(child)) {
                if (!(parent instanceof ContinuousVariable)) {
                    continue;
                }
                ContinuousVariable continuousParent = (ContinuousVariable) parent;
                DiscreteVariable ersatz = ersatzFor.get(continuousParent);
                if (ersatz == null) {
                    // nextInt(3) + 2 categories, i.e. 2 to 4 if nextInt(n) is in [0, n).
                    ersatz = new DiscreteVariable("Ersatz_" + continuousParent.getName(),
                            RandomUtil.getInstance().nextInt(3) + 2);
                    ersatzFor.put(continuousParent, ersatz);
                    originalFor.put(ersatz.getName(), continuousParent);
                    discreteGraph.addNode(ersatz);
                }
                discreteGraph.addDirectedEdge(ersatz, child);
            }
        }
        return originalFor;
    }

    /**
     * Fills column {@code mixedIndex} by sampling a discrete node from the Bayes net, row by row.
     * Continuous parents are read through their equal-frequency discretizations.
     */
    private void sampleDiscreteColumn(BoxDataSet mixedData, Node node, int mixedIndex,
            MlBayesIm bayesIm, Map<String, ContinuousVariable> ersatzToOriginal,
            Map<Integer, double[]> breakpointsMap, int sampleSize) {
        int bayesIndex = bayesIm.getNodeIndex(node);
        int[] bayesParents = bayesIm.getParents(bayesIndex);
        int numParents = bayesParents.length;
        int numCategories = bayesIm.getNumColumns(bayesIndex);

        // Per parent: its data column, plus breakpoints when it stands in for a continuous node.
        int[] parentColumns = new int[numParents];
        double[][] parentBreakpoints = new double[numParents][];
        for (int k = 0; k < numParents; ++k) {
            Node bayesParent = bayesIm.getVariables().get(bayesParents[k]);
            ContinuousVariable original = ersatzToOriginal.get(bayesParent.getName());
            if (original != null) {
                int column = mixedData.getColumn(original);
                double[] breakpoints = breakpointsMap.get(column);
                if (breakpoints == null) {
                    // The parent column is already fully sampled, so it can be discretized now.
                    breakpoints = this.getBreakpoints(mixedData, (DiscreteVariable) bayesParent, column);
                    breakpointsMap.put(column, breakpoints);
                }
                parentColumns[k] = column;
                parentBreakpoints[k] = breakpoints;
            } else {
                parentColumns[k] = mixedData.getColumn(bayesParent);
            }
        }

        for (int row = 0; row < sampleSize; ++row) {
            int[] parentValues = new int[numParents];
            for (int k = 0; k < numParents; ++k) {
                parentValues[k] = parentBreakpoints[k] != null
                        ? categoryOf(mixedData.getDouble(row, parentColumns[k]), parentBreakpoints[k])
                        : mixedData.getInt(row, parentColumns[k]);
            }
            int rowIndex = bayesIm.getRowIndex(bayesIndex, parentValues);
            mixedData.setInt(row, mixedIndex, drawCategory(bayesIm, bayesIndex, rowIndex, numCategories));
        }
    }

    /** Returns the index of the first breakpoint exceeding {@code value}, or the count if none. */
    private static int categoryOf(double value, double[] breakpoints) {
        for (int j = 0; j < breakpoints.length; ++j) {
            if (value < breakpoints[j]) {
                return j;
            }
        }
        return breakpoints.length;
    }

    /** Draws one category from the given row of the node's conditional probability table. */
    private static int drawCategory(MlBayesIm bayesIm, int bayesIndex, int rowIndex, int numCategories) {
        double r = RandomUtil.getInstance().nextDouble();
        double cumulative = 0.0;
        for (int k = 0; k < numCategories; ++k) {
            cumulative += bayesIm.getProbability(bayesIndex, rowIndex, k);
            if (r < cumulative) {
                return k;
            }
        }
        // Round-off can leave the cumulative sum just below r; that residual mass belongs to
        // the last category.
        return numCategories - 1;
    }

    /**
     * Fills column {@code mixedIndex} by sampling a continuous node. Each value is a Gaussian
     * error, plus coefficient-weighted continuous parents, plus a mean. All of these parameters
     * are keyed by the current row's discrete-parent values.
     */
    private void sampleContinuousColumn(BoxDataSet mixedData, Graph mixedGraph, List<Node> nodes,
            Node y, int mixedIndex, SemPm semPm, Map<Combination, Double> paramValues,
            int sampleSize) {
        Set<DiscreteVariable> discreteParentSet = new HashSet<>();
        Set<ContinuousVariable> continuousParentSet = new HashSet<>();
        for (Node parent : mixedGraph.getParents(y)) {
            if (parent instanceof DiscreteVariable) {
                discreteParentSet.add((DiscreteVariable) parent);
            } else {
                continuousParentSet.add((ContinuousVariable) parent);
            }
        }

        // Fix one iteration order per set so parameters are drawn in a consistent order.
        List<DiscreteVariable> discreteParents = new ArrayList<>(discreteParentSet);
        int[] discreteColumns = new int[discreteParents.size()];
        for (int d = 0; d < discreteColumns.length; ++d) {
            discreteColumns[d] = mixedData.getColumn(discreteParents.get(d));
        }
        List<ContinuousVariable> continuousParents = new ArrayList<>(continuousParentSet);
        Parameter[] coefParams = new Parameter[continuousParents.size()];
        int[] continuousColumns = new int[continuousParents.size()];
        for (int c = 0; c < coefParams.length; ++c) {
            ContinuousVariable parent = continuousParents.get(c);
            coefParams[c] = semPm.getParameter(parent, y);
            continuousColumns[c] = nodes.indexOf(parent);
        }
        Parameter varParam = semPm.getParameter(y, y);
        Parameter muParam = semPm.getMeanParameter(y);

        for (int row = 0; row < sampleSize; ++row) {
            int[] discreteValues = new int[discreteColumns.length];
            Combination varComb = new Combination(varParam);
            Combination muComb = new Combination(muParam);
            for (int d = 0; d < discreteColumns.length; ++d) {
                discreteValues[d] = mixedData.getInt(row, discreteColumns[d]);
                varComb.addParamValue(discreteParents.get(d), discreteValues[d]);
                muComb.addParamValue(discreteParents.get(d), discreteValues[d]);
            }

            // The VAR parameter is drawn from [varLow, varHigh] as a variance, while nextNormal
            // takes a standard deviation, hence the square root.
            double errorStdDev = Math.sqrt(this.getParamValue(varComb, paramValues));
            double value = RandomUtil.getInstance().nextNormal(0.0, errorStdDev);

            for (int c = 0; c < coefParams.length; ++c) {
                Combination coefComb = new Combination(coefParams[c]);
                for (int d = 0; d < discreteColumns.length; ++d) {
                    coefComb.addParamValue(discreteParents.get(d), discreteValues[d]);
                }
                double parentValue = mixedData.getDouble(row, continuousColumns[c]);
                value += parentValue * this.getParamValue(coefComb, paramValues);
            }

            value += this.getParamValue(muComb, paramValues);
            mixedData.setDouble(row, mixedIndex, value);
        }
    }

    /**
     * Computes equal-frequency breakpoints of a fully sampled continuous column.
     *
     * @param mixedData         the data being simulated
     * @param _parent           the ersatz variable whose category count sets the number of bins
     * @param mixedParentColumn the continuous column to discretize
     * @return breakpoints as returned by {@link Discretizer#getEqualFrequencyBreakPoints}
     */
    private double[] getBreakpoints(DataSet mixedData, DiscreteVariable _parent, int mixedParentColumn) {
        double[] data = new double[mixedData.getNumRows()];
        for (int r = 0; r < mixedData.getNumRows(); ++r) {
            data[r] = mixedData.getDouble(r, mixedParentColumn);
        }
        return Discretizer.getEqualFrequencyBreakPoints(data, _parent.getNumCategories());
    }

    /**
     * Returns the cached value for a parameter/discrete-context combination. On first use, the
     * value is drawn at random and cached.
     *
     * <p>VAR values are drawn from {@code [varLow, varHigh]} and MEAN values from
     * {@code [meanLow, meanHigh]}. COEF values are drawn from {@code [coefLow, coefHigh]}, and
     * their sign is flipped with probability 1/2 when {@code coefSymmetric} is set.
     *
     * @throws IllegalStateException if the parameter has any other type
     */
    private double getParamValue(Combination values, Map<Combination, Double> map) {
        Double cached = map.get(values);
        if (cached != null) {
            return cached;
        }
        ParamType type = values.getParameter().getType();
        double drawn;
        if (type == ParamType.VAR) {
            drawn = RandomUtil.getInstance().nextUniform(this.varLow, this.varHigh);
        } else if (type == ParamType.COEF) {
            double magnitude = RandomUtil.getInstance().nextUniform(this.coefLow, this.coefHigh);
            // The coin is tossed even when coefSymmetric is false, which keeps the random stream
            // identical either way.
            boolean negate = RandomUtil.getInstance().nextUniform(0.0, 1.0) < 0.5 && this.coefSymmetric;
            drawn = negate ? -magnitude : magnitude;
        } else if (type == ParamType.MEAN) {
            drawn = RandomUtil.getInstance().nextUniform(this.meanLow, this.meanHigh);
        } else {
            throw new IllegalStateException("Unexpected parameter type: " + type);
        }
        map.put(values, drawn);
        return drawn;
    }

    /** Sets the lower bound of the error-variance range. */
    public void setVarLow(double varLow) {
        this.varLow = varLow;
    }

    /** Sets the upper bound of the error-variance range. */
    public void setVarHigh(double varHigh) {
        this.varHigh = varHigh;
    }

    /** Sets the lower bound of the coefficient-magnitude range. */
    public void setCoefLow(double coefLow) {
        this.coefLow = coefLow;
    }

    /** Sets the upper bound of the coefficient-magnitude range. */
    public void setCoefHigh(double coefHigh) {
        this.coefHigh = coefHigh;
    }

    /** Sets whether coefficients are negated with probability 1/2. */
    public void setCoefSymmetric(boolean coefSymmetric) {
        this.coefSymmetric = coefSymmetric;
    }

    /** Sets the lower bound of the mean range. */
    public void setMeanLow(double meanLow) {
        this.meanLow = meanLow;
    }

    /** Sets the upper bound of the mean range. */
    public void setMeanHigh(double meanHigh) {
        this.meanHigh = meanHigh;
    }

    /**
     * Picks a category count uniformly from {@code [min, max]}.
     *
     * @throws IllegalArgumentException if {@code max < min}
     */
    private int pickNumCategories(int min, int max) {
        if (max < min) {
            throw new IllegalArgumentException(
                    "maxCategories (" + max + ") must be at least minCategories (" + min + ").");
        }
        return min + RandomUtil.getInstance().nextInt(max - min + 1);
    }

    /**
     * Cache key pairing a model parameter with the values of the discrete variables it is
     * conditioned on.
     */
    private static class Combination {
        private final Parameter parameter;
        private final Set<VariableValues> paramValues;

        public Combination(Parameter parameter) {
            this.parameter = parameter;
            this.paramValues = new HashSet<>();
        }

        public void addParamValue(DiscreteVariable variable, int value) {
            this.paramValues.add(new VariableValues(variable, value));
        }

        @Override
        public int hashCode() {
            return this.parameter.hashCode() + this.paramValues.hashCode();
        }

        /** Parameters are compared by identity; all of them come from the same SemPm. */
        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Combination)) {
                return false;
            }
            Combination v = (Combination) o;
            return v.parameter == this.parameter && v.paramValues.equals(this.paramValues);
        }

        public Parameter getParameter() {
            return this.parameter;
        }
    }

    /** An immutable (discrete variable, category index) pair. */
    private static class VariableValues {
        private final DiscreteVariable variable;
        private final int value;

        public VariableValues(DiscreteVariable variable, int value) {
            this.variable = variable;
            this.value = value;
        }

        public DiscreteVariable getVariable() {
            return this.variable;
        }

        public int getValue() {
            return this.value;
        }

        @Override
        public int hashCode() {
            return this.variable.hashCode() + this.value;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof VariableValues)) {
                return false;
            }
            VariableValues v = (VariableValues) o;
            return v.variable.equals(this.variable) && v.value == this.value;
        }
    }
}

