/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.GeneralizedSemPm;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetradapp.model.calculator.expression.Context;
import edu.cmu.tetradapp.model.calculator.expression.Expression;
import edu.cmu.tetradapp.model.calculator.parser.ExpressionLexer;
import edu.cmu.tetradapp.model.calculator.parser.Token;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import pal.math.ConjugateDirectionSearch;
import pal.math.MultivariateFunction;

public class GeneralizedSemIm
implements TetradSerializable {
    /** Serialization version identifier for this class. */
    static final long serialVersionUID = 23L;
    /** The parameterized model; the one-argument constructor stores a copy of the model it is given. */
    private GeneralizedSemPm pm;
    /** Current value of each parameter, keyed by parameter name. */
    private Map<String, Double> parameterValues;
    /**
     * When true, simulateDataRecursive1 and simulateDataMinimizeSurface discard and redraw
     * any row for which a negative value is produced.
     */
    private boolean simulatePositiveDataOnly = false;

    /**
     * Builds an instantiated model from a parameterized model. A copy of {@code pm} is
     * stored, and every parameter reported by {@code pm.getParameters()} receives an initial
     * value obtained by evaluating the expression the PM associates with it.
     *
     * @param pm the parameterized model to instantiate.
     */
    public GeneralizedSemIm(GeneralizedSemPm pm) {
        this.pm = new GeneralizedSemPm(pm);
        this.parameterValues = new HashMap<String, Double>();

        // Names referenced by an initializing expression are resolved against the
        // parameter values assigned so far; unassigned names resolve to null.
        Context assignedSoFar = new Context() {

            @Override
            public Double getValue(String var) {
                return GeneralizedSemIm.this.parameterValues.get(var);
            }
        };

        for (String name : pm.getParameters()) {
            Expression initializer = pm.setParameterExpression(name);
            double startingValue = initializer.evaluate(assignedSoFar);
            this.parameterValues.put(name, startingValue);
        }
    }

    /**
     * Builds an instantiated model from a parameterized model and then, where possible,
     * copies parameter values over from a {@link SemIm}.
     *
     * <p>The copy is all-or-nothing: if any parameter of {@code pm} has no counterpart in
     * the SEM PM of {@code semIm}, the SEM IM is ignored altogether and the values computed
     * by the one-argument constructor are kept. Otherwise each parameter is set to the
     * value recorded in {@code semIm}; for parameters of type {@link ParamType#VAR} the
     * square root of that value is stored (presumably converting a variance into a standard
     * deviation -- confirm against the PM's error expressions).
     *
     * @param pm    the parameterized model to instantiate.
     * @param semIm the SEM IM from which parameter values are taken.
     */
    public GeneralizedSemIm(GeneralizedSemPm pm, SemIm semIm) {
        this(pm);
        SemPm semPm = semIm.getSemPm();
        Set<String> parameters = pm.getParameters();

        // Resolve every parameter exactly once. A single miss means the SEM IM is ignored.
        Map<String, Parameter> resolved = new HashMap<String, Parameter>();
        for (String parameter : parameters) {
            Parameter paramObject = semPm.getParameter(parameter);
            if (paramObject == null) {
                return;
            }
            resolved.put(parameter, paramObject);
        }

        // Iterate the original set (not the cache) so values are assigned in the same order as before.
        for (String parameter : parameters) {
            Parameter paramObject = resolved.get(parameter);
            double value = semIm.getParamValue(paramObject);
            if (paramObject.getType() == ParamType.VAR) {
                value = Math.sqrt(value);
            }
            this.setParameterValue(parameter, value);
        }
    }

    /**
     * Creates a simple exemplar of this class, built from
     * {@code GeneralizedSemPm.serializableInstance()}.
     *
     * @return a new instance.
     */
    public static GeneralizedSemIm serializableInstance() {
        GeneralizedSemPm examplePm = GeneralizedSemPm.serializableInstance();
        return new GeneralizedSemIm(examplePm);
    }

    /**
     * Returns a fresh copy of the parameterized model backing this IM, created with the
     * {@code GeneralizedSemPm} copy constructor on every call.
     *
     * @return a copy of the stored PM.
     */
    public GeneralizedSemPm getGeneralizedSemPm() {
        GeneralizedSemPm copy = new GeneralizedSemPm(this.pm);
        return copy;
    }

    /**
     * Assigns a new value to an existing parameter.
     *
     * @param parameter the name of the parameter to change.
     * @param value     the value to store.
     * @throws NullPointerException     if {@code parameter} is null.
     * @throws IllegalArgumentException if the model has no parameter of that name.
     */
    public void setParameterValue(String parameter, double value) {
        if (parameter == null) {
            throw new NullPointerException("Parameter not specified.");
        }
        if (this.parameterValues.containsKey(parameter)) {
            this.parameterValues.put(parameter, value);
            return;
        }
        throw new IllegalArgumentException("Not a parameter in this model: " + parameter);
    }

    /**
     * Looks up the current value of a parameter.
     *
     * @param parameter the name of the parameter.
     * @return the value currently stored for that parameter.
     * @throws NullPointerException     if {@code parameter} is null.
     * @throws IllegalArgumentException if the model has no parameter of that name.
     */
    public double getParameterValue(String parameter) {
        if (parameter == null) {
            throw new NullPointerException("Parameter not specified.");
        }
        if (this.parameterValues.containsKey(parameter)) {
            return this.parameterValues.get(parameter);
        }
        throw new IllegalArgumentException("Not a parameter in this model: " + parameter);
    }

    /**
     * Returns the expression string of {@code node} with every parameter token that has a
     * value in this IM replaced by that value, formatted with the number format supplied by
     * {@link NumberFormatUtil}. All other tokens, whitespace included, are copied unchanged.
     *
     * <p>This is the two-argument overload called with an empty substitution map, so both
     * overloads share one tokenizing loop and one null-argument policy.
     *
     * @param node the node whose expression is rendered.
     * @return the expression string with parameter values substituted in.
     * @throws NullPointerException if {@code node} is null.
     */
    public String getNodeSubstitutedString(Node node) {
        return this.getNodeSubstitutedString(node, Collections.<String, Double>emptyMap());
    }

    /**
     * Returns the expression string of {@code node} with parameter tokens replaced by
     * numbers. For each parameter token the value in {@code substitutedValues} is preferred;
     * if it has none, the value stored in this IM is used; if neither has a value the token
     * is left as written. All other tokens, whitespace included, are copied unchanged.
     *
     * @param node              the node whose expression is rendered.
     * @param substitutedValues overriding values, keyed by parameter name.
     * @return the expression string with parameter values substituted in.
     * @throws NullPointerException if either argument is null.
     */
    public String getNodeSubstitutedString(Node node, Map<String, Double> substitutedValues) {
        if (node == null || substitutedValues == null) {
            throw new NullPointerException();
        }

        NumberFormat formatter = NumberFormatUtil.getInstance().getNumberFormat();
        ExpressionLexer lexer = new ExpressionLexer(this.pm.getNodeExpressionString(node));
        StringBuilder rendered = new StringBuilder();

        for (Token token = lexer.nextTokenIncludingWhitespace();
             token != Token.EOF;
             token = lexer.nextTokenIncludingWhitespace()) {
            String text = lexer.getTokenString();

            // Only parameter tokens are candidates for replacement; overrides win over stored values.
            Double replacement = null;
            if (token == Token.PARAMETER) {
                replacement = substitutedValues.get(text);
                if (replacement == null) {
                    replacement = this.parameterValues.get(text);
                }
            }

            if (replacement == null) {
                rendered.append(text);
            } else {
                rendered.append(formatter.format(replacement));
            }
        }
        return rendered.toString();
    }

    /**
     * Renders the model as text in three sections: the substituted expression of every
     * variable node, the substituted expression of every error node, and the value of every
     * parameter in sorted name order.
     *
     * @return the textual description of this IM.
     */
    @Override
    public String toString() {
        List<String> sortedNames = new ArrayList<String>(this.pm.getParameters());
        Collections.sort(sortedNames);
        NumberFormat formatter = NumberFormatUtil.getInstance().getNumberFormat();
        StringBuilder out = new StringBuilder();
        GeneralizedSemPm pmCopy = this.getGeneralizedSemPm();

        out.append("\nVariable nodes:\n");
        for (Node variableNode : pmCopy.getVariableNodes()) {
            String substituted = this.getNodeSubstitutedString(variableNode);
            out.append("\n").append(variableNode).append(" = ").append(substituted);
        }

        out.append("\n\nErrors:\n");
        for (Node errorNode : pmCopy.getErrorNodes()) {
            String substituted = this.getNodeSubstitutedString(errorNode);
            out.append("\n").append(errorNode).append(" ~ ").append(substituted);
        }

        out.append("\n\nParameter values:\n");
        for (String name : sortedNames) {
            double value = this.getParameterValue(name);
            out.append("\n").append(name).append(" = ").append(formatter.format(value));
        }
        return out.toString();
    }

    /**
     * Simulates a data set by delegating to
     * {@link #simulateDataAvoidInfinity(int, boolean)}.
     *
     * <p>NOTE(review): in this decompiled source the delegate's body is a stub that
     * unconditionally throws {@code IllegalStateException}, so this method fails as well
     * until that body is restored.
     *
     * @param sampleSize      the number of rows to simulate.
     * @param latentDataSaved passed through to the delegate unchanged.
     * @return whatever the delegate returns.
     */
    public DataSet simulateData(int sampleSize, boolean latentDataSaved) {
        DataSet simulated = this.simulateDataAvoidInfinity(sampleSize, latentDataSaved);
        return simulated;
    }

    /**
     * Simulates data by evaluating every node's expression once per row, visiting nodes in
     * the order returned by the graph's {@code getFullTierOrdering()}.
     *
     * <p>Each evaluated value is recorded under the node's name so that later expressions in
     * the same row can refer to it; the record is cleared at the start of every row. Values
     * of nodes that appear in {@code pm.getVariableNodes()} are written to the data set;
     * values of other nodes (for example error terms) are recorded but not written.
     *
     * <p>If {@link #isSimulatePositiveDataOnly()} is true, a negative value for a written
     * node causes the whole row to be discarded and redrawn. No retry limit is applied.
     *
     * @param sampleSize      the number of rows to simulate.
     * @param latentDataSaved if true the full data set is returned; otherwise the result of
     *                        {@code DataUtils.restrictToMeasured} on it is returned.
     * @return the simulated data set.
     * @throws IllegalArgumentException if an expression refers to a name that is neither a
     *                                  parameter nor an already-evaluated node.
     */
    public DataSet simulateDataRecursive1(int sampleSize, boolean latentDataSaved) {
        final Map<String, Double> variableValues = new HashMap<String, Double>();

        // Parameter values take precedence over values computed so far in the current row.
        Context context = new Context() {

            @Override
            public Double getValue(String term) {
                Double value = GeneralizedSemIm.this.parameterValues.get(term);
                if (value != null) {
                    return value;
                }
                value = variableValues.get(term);
                if (value != null) {
                    return value;
                }
                throw new IllegalArgumentException("No value recorded for '" + term + "'");
            }
        };

        // The data set's columns are fresh ContinuousVariable copies rather than the PM's own nodes.
        List<Node> nonErrorVariables = this.pm.getVariableNodes();
        List<Node> continuousVariables = new LinkedList<Node>();
        for (Node node : nonErrorVariables) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            var.setNodeType(node.getNodeType());
            if (var.getNodeType() != NodeType.ERROR) {
                continuousVariables.add(var);
            }
        }
        DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);

        SemGraph graph = this.pm.getGraph();
        List<Node> tierOrdering = graph.getFullTierOrdering();

        // Column of each tier-ordered node, or -1 when the node is not a variable node.
        // Sized from the ordering itself because the simulation loop indexes it by tier position.
        int[] tierIndices = new int[tierOrdering.size()];
        for (int i = 0; i < tierIndices.length; ++i) {
            tierIndices[i] = nonErrorVariables.indexOf(tierOrdering.get(i));
        }

        ROW:
        for (int row = 0; row < sampleSize; ++row) {
            variableValues.clear();
            for (int tier = 0; tier < tierOrdering.size(); ++tier) {
                Node node = tierOrdering.get(tier);
                Expression expression = this.pm.getNodeExpression(node);
                double value = expression.evaluate(context);
                variableValues.put(node.getName(), value);

                int col = tierIndices[tier];
                if (col == -1) {
                    continue;
                }
                if (this.isSimulatePositiveDataOnly() && value < 0.0) {
                    // Redo this row from scratch; cells already written are overwritten on the retry.
                    --row;
                    continue ROW;
                }
                fullDataSet.setDouble(row, col, value);
            }
        }

        if (latentDataSaved) {
            return fullDataSet;
        }
        return DataUtils.restrictToMeasured(fullDataSet);
    }

    /**
     * Simulates data by searching, for each row, for variable values that reproduce
     * themselves when all node expressions are evaluated.
     *
     * <p>For each row an error value is drawn for every variable node and all variables are
     * reset to zero. A {@link ConjugateDirectionSearch} then minimizes the sum of squared
     * differences between a candidate vector of variable values and the vector obtained by
     * evaluating every variable's expression at that candidate. The search is repeated,
     * starting from the last recorded values, until the most recently evaluated objective
     * value is below 0.01. After every search pass the resulting values are written into the
     * row.
     *
     * <p>If {@link #isSimulatePositiveDataOnly()} is true, a negative value in a search
     * result causes the whole row to be redrawn. Neither the redraw nor the convergence loop
     * has an iteration limit.
     *
     * @param sampleSize      the number of rows to simulate.
     * @param latentDataSaved if true the full data set is returned; otherwise the result of
     *                        {@code DataUtils.restrictToMeasured} on it is returned.
     * @return the simulated data set.
     * @throws IllegalArgumentException if an expression evaluates to NaN or refers to a name
     *                                  for which no value is recorded.
     */
    public DataSet simulateDataMinimizeSurface(int sampleSize, boolean latentDataSaved) {
        final Map<String, Double> variableValues = new HashMap<String, Double>();
        final double funcTolerance = 1.0E-4;
        final double paramTolerance = 0.001;
        final double convergenceThreshold = 0.01;

        // The data set's columns are fresh ContinuousVariable copies rather than the PM's own nodes.
        final List<Node> variableNodes = this.pm.getVariableNodes();
        List<Node> continuousVariables = new LinkedList<Node>();
        for (Node node : variableNodes) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            var.setNodeType(node.getNodeType());
            if (var.getNodeType() != NodeType.ERROR) {
                continuousVariables.add(var);
            }
        }
        DataSet fullDataSet = new ColtDataSet(sampleSize, continuousVariables);

        // Parameter values take precedence over recorded variable/error values.
        final Context context = new Context() {

            @Override
            public Double getValue(String term) {
                Double value = GeneralizedSemIm.this.parameterValues.get(term);
                if (value != null) {
                    return value;
                }
                value = variableValues.get(term);
                if (value != null) {
                    return value;
                }
                throw new IllegalArgumentException("No value recorded for '" + term + "'");
            }
        };

        // Holds the objective value of the most recent evaluate() call so the
        // convergence loop below can read it.
        final double[] lastMetric = new double[1];

        MultivariateFunction function = new MultivariateFunction() {

            @Override
            public double evaluate(double[] doubles) {
                int count = variableNodes.size();

                // Expose the candidate point to the expressions.
                for (int i = 0; i < count; ++i) {
                    variableValues.put(variableNodes.get(i).getName(), doubles[i]);
                }

                // Image of the candidate under the model equations.
                double[] image = new double[doubles.length];
                for (int i = 0; i < count; ++i) {
                    Expression expression = GeneralizedSemIm.this.pm.getNodeExpression(variableNodes.get(i));
                    image[i] = expression.evaluate(context);
                    if (Double.isNaN(image[i])) {
                        throw new IllegalArgumentException("Undefined value for expression " + expression);
                    }
                }

                // Squared distance between the candidate and its image.
                double metric = 0.0;
                for (int i = 0; i < count; ++i) {
                    double diff = doubles[i] - image[i];
                    metric += diff * diff;
                }

                // The image, not the candidate, is left recorded after the call.
                for (int i = 0; i < count; ++i) {
                    variableValues.put(variableNodes.get(i).getName(), image[i]);
                }

                lastMetric[0] = metric;
                return metric;
            }

            @Override
            public int getNumArguments() {
                return variableNodes.size();
            }

            @Override
            public double getLowerBound(int i) {
                return -10000.0;
            }

            @Override
            public double getUpperBound(int i) {
                return 10000.0;
            }
        };

        ConjugateDirectionSearch search = new ConjugateDirectionSearch();
        search.step = 10.0;

        ROW:
        for (int row = 0; row < sampleSize; ++row) {
            // Draw one value per error term.
            for (int i = 0; i < variableNodes.size(); ++i) {
                Node error = this.pm.getErrorNode(variableNodes.get(i));
                Expression expression = this.pm.getNodeExpression(error);
                double value = expression.evaluate(context);
                if (Double.isNaN(value)) {
                    throw new IllegalArgumentException("Undefined value for expression: " + expression);
                }
                variableValues.put(error.getName(), value);
            }

            // Start every search for this row from the origin.
            for (int i = 0; i < variableNodes.size(); ++i) {
                variableValues.put(variableNodes.get(i).getName(), 0.0);
            }

            do {
                double[] values = new double[variableNodes.size()];
                for (int i = 0; i < values.length; ++i) {
                    values[i] = variableValues.get(variableNodes.get(i).getName());
                }

                search.optimize(function, values, funcTolerance, paramTolerance);

                for (int i = 0; i < variableNodes.size(); ++i) {
                    if (this.isSimulatePositiveDataOnly() && values[i] < 0.0) {
                        // Redo this row from scratch; cells already written are overwritten on the retry.
                        --row;
                        continue ROW;
                    }
                    variableValues.put(variableNodes.get(i).getName(), values[i]);
                    fullDataSet.setDouble(row, i, values[i]);
                }
                // Written as a negated '<' so that a NaN objective value also keeps the loop running.
            } while (!(lastMetric[0] < convergenceThreshold));
        }

        if (latentDataSaved) {
            return fullDataSet;
        }
        return DataUtils.restrictToMeasured(fullDataSet);
    }

    /**
     * Placeholder for a method whose body could not be recovered by the decompiler
     * (CFR 0.152, see the failure report inside the body).
     *
     * <p>As it stands this method ignores both arguments and unconditionally throws
     * {@code IllegalStateException}. Because {@code simulateData} delegates here, that
     * method fails the same way.
     *
     * <p>TODO(review): restore the real implementation from the original source or class
     * file; nothing in this file shows what the original body did.
     *
     * @param sampleSize      unused by the stub.
     * @param latentDataSaved unused by the stub.
     * @return never returns normally.
     * @throws IllegalStateException always.
     */
    public DataSet simulateDataAvoidInfinity(int sampleSize, boolean latentDataSaved) {
        /*
         * This method has failed to decompile.  When submitting a bug report, please provide this stack trace, and (if you hold appropriate legal rights) the relevant class file.
         * 
         * org.benf.cfr.reader.util.ConfusedCFRException: Tried to end blocks [2[DOLOOP]], but top level block is 8[UNCONDITIONALDOLOOP]
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.processEndingBlocks(Op04StructuredStatement.java:435)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op04StructuredStatement.buildNestedBlocks(Op04StructuredStatement.java:484)
         *     at org.benf.cfr.reader.bytecode.analysis.opgraph.Op03SimpleStatement.createInitialStructuredBlock(Op03SimpleStatement.java:736)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisInner(CodeAnalyser.java:850)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysisOrWrapFail(CodeAnalyser.java:278)
         *     at org.benf.cfr.reader.bytecode.CodeAnalyser.getAnalysis(CodeAnalyser.java:201)
         *     at org.benf.cfr.reader.entities.attributes.AttributeCode.analyse(AttributeCode.java:94)
         *     at org.benf.cfr.reader.entities.Method.analyse(Method.java:531)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseMid(ClassFile.java:1055)
         *     at org.benf.cfr.reader.entities.ClassFile.analyseTop(ClassFile.java:942)
         *     at org.benf.cfr.reader.Driver.doJarVersionTypes(Driver.java:257)
         *     at org.benf.cfr.reader.Driver.doJar(Driver.java:139)
         *     at org.benf.cfr.reader.CfrDriverImpl.analyse(CfrDriverImpl.java:76)
         *     at org.benf.cfr.reader.Main.main(Main.java:54)
         */
        throw new IllegalStateException("Decompilation failed");
    }

    /**
     * Simulates data by drawing error values and then applying the model equations a fixed
     * number of times (currently once) starting from all variables at zero.
     *
     * <p>Within a step every variable's expression is evaluated against the values from the
     * previous step, and only then are the new values recorded together. If any new value is
     * positive or negative infinity the row is discarded and redrawn; no retry limit is
     * applied.
     *
     * @param sampleSize      the number of rows to simulate.
     * @param latentDataSaved if true the full data set is returned; otherwise the result of
     *                        {@code DataUtils.restrictToMeasured} on it is returned.
     * @return the simulated data set.
     * @throws IllegalArgumentException if an expression evaluates to NaN or refers to a name
     *                                  for which no value is recorded.
     */
    public DataSet simulateDataNSteps(int sampleSize, boolean latentDataSaved) {
        final HashMap<String, Double> recorded = new HashMap<String, Double>();
        final int steps = 1;

        // The data set's columns are fresh ContinuousVariable copies rather than the PM's own nodes.
        List<Node> variableNodes = this.pm.getVariableNodes();
        LinkedList<Node> columns = new LinkedList<Node>();
        for (Node source : variableNodes) {
            ContinuousVariable copy = new ContinuousVariable(source.getName());
            copy.setNodeType(source.getNodeType());
            if (copy.getNodeType() != NodeType.ERROR) {
                columns.add(copy);
            }
        }
        ColtDataSet fullDataSet = new ColtDataSet(sampleSize, columns);

        // Parameter values take precedence over recorded variable/error values.
        Context context = new Context() {

            @Override
            public Double getValue(String term) {
                Double found = GeneralizedSemIm.this.parameterValues.get(term);
                if (found == null) {
                    found = recorded.get(term);
                }
                if (found == null) {
                    throw new IllegalArgumentException("No value recorded for '" + term + "'");
                }
                return found;
            }
        };

        int nodeCount = variableNodes.size();

        ROW:
        for (int row = 0; row < sampleSize; ++row) {
            // Draw one value per error term.
            for (int i = 0; i < nodeCount; ++i) {
                Node errorNode = this.pm.getErrorNode(variableNodes.get(i));
                Expression errorExpression = this.pm.getNodeExpression(errorNode);
                double drawn = errorExpression.evaluate(context);
                if (Double.isNaN(drawn)) {
                    throw new IllegalArgumentException("Undefined value for expression: " + errorExpression);
                }
                recorded.put(errorNode.getName(), drawn);
            }

            // Every variable starts the row at zero.
            for (int i = 0; i < nodeCount; ++i) {
                recorded.put(variableNodes.get(i).getName(), 0.0);
            }

            for (int step = 0; step < steps; ++step) {
                // Evaluate all equations against the previous step's values first...
                double[] next = new double[nodeCount];
                for (int i = 0; i < next.length; ++i) {
                    Expression equation = this.pm.getNodeExpression(variableNodes.get(i));
                    double result = equation.evaluate(context);
                    if (Double.isNaN(result)) {
                        throw new IllegalArgumentException("Undefined value for expression: " + equation);
                    }
                    next[i] = result;
                }

                // ...reject the row if anything blew up...
                for (int i = 0; i < next.length; ++i) {
                    if (Double.isInfinite(next[i])) {
                        --row;
                        continue ROW;
                    }
                }

                // ...and only then record the new values together.
                for (int i = 0; i < nodeCount; ++i) {
                    recorded.put(variableNodes.get(i).getName(), next[i]);
                }
            }

            for (int i = 0; i < nodeCount; ++i) {
                double cell = recorded.get(variableNodes.get(i).getName());
                fullDataSet.setDouble(row, i, cell);
            }
        }

        if (!latentDataSaved) {
            return DataUtils.restrictToMeasured(fullDataSet);
        }
        return fullDataSet;
    }

    /**
     * Returns a fresh copy of the parameterized model backing this IM, created with the
     * {@code GeneralizedSemPm} copy constructor on every call. Does the same thing as
     * {@code getGeneralizedSemPm()}.
     *
     * @return a copy of the stored PM.
     */
    public GeneralizedSemPm getSemPm() {
        GeneralizedSemPm copy = new GeneralizedSemPm(this.pm);
        return copy;
    }

    /**
     * Overwrites parameter values in bulk. Entries whose key is not a parameter of this
     * model are ignored, as are entries whose value is null: storing a null would make later
     * reads of that parameter (for example {@code getParameterValue} or {@code toString})
     * fail when the value is unboxed.
     *
     * @param parameterValues the new values, keyed by parameter name.
     * @throws NullPointerException if {@code parameterValues} is null.
     */
    public void setSubstitutions(Map<String, Double> parameterValues) {
        for (Map.Entry<String, Double> entry : parameterValues.entrySet()) {
            String parameter = entry.getKey();
            Double value = entry.getValue();
            if (value != null && this.parameterValues.containsKey(parameter)) {
                this.parameterValues.put(parameter, value);
            }
        }
    }

    /**
     * Reports whether the simulation methods that honor this flag redraw rows in which a
     * negative value is produced.
     *
     * @return the current value of the flag.
     */
    public boolean isSimulatePositiveDataOnly() {
        return simulatePositiveDataOnly;
    }

    /**
     * Sets whether the simulation methods that honor this flag redraw rows in which a
     * negative value is produced.
     *
     * @param simulatedPositiveDataOnly the new value of the flag.
     */
    public void setSimulatePositiveDataOnly(boolean simulatedPositiveDataOnly) {
        simulatePositiveDataOnly = simulatedPositiveDataOnly;
    }
}

