/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.Simulator;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.Parameters;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradSerializable;
import edu.cmu.tetrad.util.Vector;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;

/**
 * A SEM instantiated model whose free parameters are held in standardized form.
 *
 * <p>Directed edges carry standardized edge coefficients. Bidirected edges, keyed between the
 * exogenous (error) nodes of their endpoints, carry error correlations. Error variances are not
 * stored. Each one is derived from the other parameters as
 * {@code 1 - (variance contributed by the variable's other parents)} and is {@code NaN} when
 * that difference is not positive. Time-lag models and cyclic models are rejected by the
 * constructor.
 */
public class StandardizedSemIm
implements Simulator {
    private static final long serialVersionUID = 23L;

    /** Width below which the range bisection stops. */
    private static final double BISECTION_TOLERANCE = 1.0E-10;

    /** Sample size read from the {@code "sampleSize"} parameter; also used for data-based initialization. */
    private final int sampleSize;
    private final SemPm semPm;
    /** Copy of the PM's graph with error terms shown. */
    private final SemGraph semGraph;
    /** Free parameters: directed edge to coefficient, bidirected error edge to error correlation. */
    private final Map<Edge, Double> edgeParameters;

    // Lazily computed values derived from edgeParameters; cleared by invalidateCaches().
    private Matrix edgeCoef;
    private Matrix errorCovar;
    private Map<Node, Double> errorVariances;
    private Matrix implCovar;
    private Matrix implCovarMeas;

    // The edge whose range is cached in 'range', so repeated edits of one parameter
    // do not recompute its range every time.
    private Edge editingEdge;
    private ParameterRange range;

    /**
     * Creates a standardized IM whose parameters are calculated directly from {@code im}.
     *
     * @param im         the SEM IM to standardize.
     * @param parameters parameters; {@code "sampleSize"} is read from it.
     */
    public StandardizedSemIm(SemIm im, Parameters parameters) {
        this(im, Initialization.CALCULATE_FROM_SEM, parameters);
    }

    /**
     * Creates a standardized IM from {@code im}.
     *
     * @param im             the SEM IM to standardize.
     * @param initialization whether to calculate the standardized parameters from {@code im}
     *                       directly, or to estimate them from standardized data simulated from it.
     * @param parameters     parameters; {@code "sampleSize"} is read from it.
     * @throws IllegalArgumentException if the model is a time-lag model or contains a directed cycle.
     */
    public StandardizedSemIm(SemIm im, Initialization initialization, Parameters parameters) {
        if (im.getSemPm().getGraph().isTimeLagModel()) {
            throw new IllegalArgumentException("Standardized SEM IM with a time lag model with latent variables is not supported.");
        }
        this.semPm = new SemPm(im.getSemPm());
        this.semGraph = new SemGraph(this.semPm.getGraph());
        this.semGraph.setShowErrorTerms(true);
        this.sampleSize = parameters.getInt("sampleSize");
        if (this.semGraph.paths().existsDirectedCycle()) {
            throw new IllegalArgumentException("The cyclic case is not handled.");
        }
        this.edgeParameters = initialization == Initialization.CALCULATE_FROM_SEM
                ? this.parametersFromSem(im)
                : this.parametersFromData(im);
    }

    /** Rescales each coefficient of {@code im} by sd(a)/sd(b) and converts error covariances to correlations. */
    private Map<Edge, Double> parametersFromSem(SemIm im) {
        Map<Edge, Double> params = new HashMap<>();
        List<Node> nodes = im.getVariableNodes();
        Matrix impliedCovar = im.getImplCovar(true);
        for (Parameter parameter : im.getSemPm().getParameters()) {
            if (parameter.getType() == ParamType.COEF) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                int aIndex = nodes.indexOf(a);
                int bIndex = nodes.indexOf(b);
                double stdA = FastMath.sqrt(impliedCovar.get(aIndex, aIndex));
                double stdB = FastMath.sqrt(impliedCovar.get(bIndex, bIndex));
                params.put(Edges.directedEdge(a, b), stdA / stdB * im.getEdgeCoef(a, b));
            } else if (parameter.getType() == ParamType.COVAR) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                params.put(this.errorEdge(a, b), errorCorrelation(im, a, b));
            }
        }
        return params;
    }

    /** Estimates the parameters from standardized data of size {@code sampleSize} simulated from {@code im}. */
    private Map<Edge, Double> parametersFromData(SemIm im) {
        DataSet dataSet = im.simulateData(this.sampleSize, false);
        Matrix standardized = DataTransforms.standardizeData(dataSet.getDoubleData());
        BoxDataSet dataSetStandardized = new BoxDataSet(new VerticalDoubleDataBox(standardized.toArray()), dataSet.getVariables());
        SemIm imStandardized = new SemEstimator(dataSetStandardized, im.getSemPm()).estimate();
        Map<Edge, Double> params = new HashMap<>();
        for (Parameter parameter : imStandardized.getSemPm().getParameters()) {
            if (parameter.getType() == ParamType.COEF) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                params.put(Edges.directedEdge(a, b), imStandardized.getEdgeCoef(a, b));
            } else if (parameter.getType() == ParamType.COVAR) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                // Take the correlation from the estimated model, as is done for the coefficients,
                // and with the same sign convention as parametersFromSem.
                params.put(this.errorEdge(a, b), errorCorrelation(imStandardized, a, b));
            }
        }
        return params;
    }

    /** Returns errCov(a, b) / sqrt(errVar(a) * errVar(b)) for the given IM. */
    private static double errorCorrelation(SemIm source, Node a, Node b) {
        return source.getErrCovar(a, b) / FastMath.sqrt(source.getErrVar(a) * source.getErrVar(b));
    }

    /** Returns the bidirected edge between the exogenous (error) nodes of {@code a} and {@code b}. */
    private Edge errorEdge(Node a, Node b) {
        return Edges.bidirectedEdge(this.semGraph.getExogenous(a), this.semGraph.getExogenous(b));
    }

    /** Maps a bidirected edge onto the error-node edge used as its key; other edges are returned unchanged. */
    private Edge canonicalize(Edge edge) {
        if (Edges.isBidirectedEdge(edge)) {
            return this.errorEdge(edge.getNode1(), edge.getNode2());
        }
        return edge;
    }

    /** Clears every value cached from {@link #edgeParameters}; call after any parameter change. */
    private void invalidateCaches() {
        this.edgeCoef = null;
        this.errorCovar = null;
        this.errorVariances = null;
    }

    /**
     * Returns an instance built from {@code SemIm.serializableInstance()} for serialization tests.
     *
     * @return a serializable instance.
     */
    public static StandardizedSemIm serializableInstance() {
        return new StandardizedSemIm(SemIm.serializableInstance(), new Parameters());
    }

    /**
     * Returns the sample size read from the parameters at construction.
     *
     * @return the sample size.
     */
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Tells whether {@code edge} is a free parameter of this model. Bidirected edges are first
     * mapped onto the edge between their endpoints' error nodes.
     *
     * @param edge the edge to look up.
     * @return true if the model has a parameter for the edge.
     */
    public boolean containsParameter(Edge edge) {
        return this.edgeParameters.containsKey(this.canonicalize(edge));
    }

    /**
     * Sets the coefficient of a&nbsp;&rarr;&nbsp;b if the value lies strictly inside the range
     * that keeps the model valid.
     *
     * @param a    the parent.
     * @param b    the child.
     * @param coef the new coefficient.
     * @return true if the value was stored, false if it lay outside the open range.
     * @throws NullPointerException if a&nbsp;&rarr;&nbsp;b is not a coefficient parameter.
     */
    public boolean setEdgeCoefficient(Node a, Node b, double coef) {
        Edge edge = Edges.directedEdge(a, b);
        if (this.edgeParameters.get(edge) == null) {
            throw new NullPointerException("Not a coefficient parameter in this model: " + edge);
        }
        return this.setIfInRange(edge, coef);
    }

    /**
     * Sets the error covariance between a and b if the value lies strictly inside the range
     * that keeps the model valid.
     *
     * @param a     the first variable.
     * @param b     the second variable.
     * @param covar the new error covariance.
     * @return true if the value was stored, false if it lay outside the open range.
     * @throws IllegalArgumentException if there is no covariance parameter between a and b.
     */
    public boolean setErrorCovariance(Node a, Node b, double covar) {
        Edge edge = this.errorEdge(a, b);
        if (this.edgeParameters.get(edge) == null) {
            throw new IllegalArgumentException("Not a covariance parameter in this model: " + edge);
        }
        return this.setIfInRange(edge, covar);
    }

    /** Stores {@code value} for {@code edge} if it is inside the (cached) open parameter range. */
    private boolean setIfInRange(Edge edge, double value) {
        if (!edge.equals(this.editingEdge)) {
            this.range = this.getParameterRange(edge);
            this.editingEdge = edge;
        }
        if (value > this.range.getLow() && value < this.range.getHigh()) {
            this.edgeParameters.put(edge, value);
            // Derived matrices must reflect the new value.
            this.invalidateCaches();
            return true;
        }
        return false;
    }

    /**
     * Returns the coefficient of a&nbsp;&rarr;&nbsp;b.
     *
     * @param a the parent.
     * @param b the child.
     * @return the coefficient, or {@code NaN} if there is no such parameter.
     */
    public double getEdgeCoef(Node a, Node b) {
        Double coef = this.edgeParameters.get(Edges.directedEdge(a, b));
        return coef == null ? Double.NaN : coef;
    }

    /**
     * Returns the error covariance between a and b.
     *
     * @param a the first variable (or its error node).
     * @param b the second variable (or its error node).
     * @return the stored error covariance.
     * @throws IllegalArgumentException if there is no covariance parameter between a and b.
     */
    public double getErrorCovariance(Node a, Node b) {
        Edge edge = this.errorEdge(a, b);
        Double covar = this.edgeParameters.get(edge);
        if (covar == null) {
            throw new IllegalArgumentException("Not a covariance parameter in this model: " + edge);
        }
        return covar;
    }

    /**
     * Returns the value of the parameter on a directed or bidirected edge.
     *
     * @param edge the edge.
     * @return its coefficient or error covariance.
     * @throws IllegalArgumentException if the edge is neither directed nor bidirected.
     */
    public double getParameterValue(Edge edge) {
        if (Edges.isDirectedEdge(edge)) {
            return this.getEdgeCoef(edge.getNode1(), edge.getNode2());
        }
        if (Edges.isBidirectedEdge(edge)) {
            return this.getErrorCovariance(edge.getNode1(), edge.getNode2());
        }
        throw new IllegalArgumentException("Only directed and bidirected edges are supported: " + edge);
    }

    /**
     * Sets the parameter on a directed or bidirected edge. Out-of-range values are silently
     * ignored, as with {@link #setEdgeCoefficient} and {@link #setErrorCovariance}.
     *
     * @param edge  the edge.
     * @param value the new value.
     * @throws IllegalArgumentException if the edge is neither directed nor bidirected.
     */
    public void setParameterValue(Edge edge, double value) {
        if (Edges.isDirectedEdge(edge)) {
            this.setEdgeCoefficient(edge.getNode1(), edge.getNode2(), value);
        } else if (Edges.isBidirectedEdge(edge)) {
            this.setErrorCovariance(edge.getNode1(), edge.getNode2(), value);
        } else {
            throw new IllegalArgumentException("Only directed and bidirected edges are supported: " + edge);
        }
    }

    /**
     * Returns the valid range of the coefficient a&nbsp;&rarr;&nbsp;b.
     *
     * @param a the parent.
     * @param b the child.
     * @return the range.
     */
    public ParameterRange getCoefficientRange(Node a, Node b) {
        return this.getParameterRange(Edges.directedEdge(a, b));
    }

    /**
     * Returns the valid range of the error covariance between a and b.
     *
     * @param a the first variable.
     * @param b the second variable.
     * @return the range.
     */
    public ParameterRange getCovarianceRange(Node a, Node b) {
        return this.getParameterRange(this.errorEdge(a, b));
    }

    /**
     * Computes the interval of values for the parameter on {@code edge} that keeps every derived
     * error variance positive and the error covariance matrix positive definite, holding all
     * other parameters fixed. Each bound is found by doubling the step away from the current
     * value until the model becomes invalid (or the step overflows to infinity), then bisecting.
     * The stored parameter value is restored before returning, even if probing fails.
     *
     * @param edge a directed edge, or a bidirected edge (mapped onto its error nodes).
     * @return the range, with {@code coef} set to the current (finite-clamped) value.
     * @throws IllegalArgumentException if the edge is not a parameter of this model.
     */
    public ParameterRange getParameterRange(Edge edge) {
        edge = this.canonicalize(edge);
        if (!this.edgeParameters.containsKey(edge)) {
            throw new IllegalArgumentException("Not an edge in this model: " + edge);
        }
        double initial = this.edgeParameters.get(edge);
        // Clamp infinities to the finite extreme of the same sign so the search arithmetic stays
        // finite. Note Double.MIN_VALUE is the smallest positive double, not the most negative one.
        if (initial == Double.NEGATIVE_INFINITY) {
            initial = -Double.MAX_VALUE;
        } else if (initial == Double.POSITIVE_INFINITY) {
            initial = Double.MAX_VALUE;
        }
        try {
            double rangeHigh = this.searchUpperBound(edge, initial);
            double rangeLow = this.searchLowerBound(edge, initial);
            return new ParameterRange(edge, initial, rangeLow, rangeHigh);
        } finally {
            // paramInBounds overwrites the stored value while probing; put it back.
            this.edgeParameters.put(edge, initial);
            this.invalidateCaches();
        }
    }

    /** Returns the largest value above {@code value} that keeps the model valid (may be +Infinity). */
    private double searchUpperBound(Edge edge, double value) {
        double high = value + 1.0;
        while (this.paramInBounds(edge, high)) {
            high = value + 2.0 * (high - value);
            if (high == Double.POSITIVE_INFINITY) {
                return Double.POSITIVE_INFINITY;
            }
        }
        double low = value;
        while (high - low > BISECTION_TOLERANCE) {
            double midpoint = (high + low) / 2.0;
            if (midpoint == low || midpoint == high) {
                break; // adjacent doubles: no further progress is possible
            }
            if (this.paramInBounds(edge, midpoint)) {
                low = midpoint;
            } else {
                high = midpoint;
            }
        }
        return (high + low) / 2.0;
    }

    /** Returns the smallest value below {@code value} that keeps the model valid (may be -Infinity). */
    private double searchLowerBound(Edge edge, double value) {
        double low = value - 1.0;
        while (this.paramInBounds(edge, low)) {
            low = value - 2.0 * (value - low);
            if (low == Double.NEGATIVE_INFINITY) {
                return Double.NEGATIVE_INFINITY;
            }
        }
        double high = value;
        while (high - low > BISECTION_TOLERANCE) {
            double midpoint = (high + low) / 2.0;
            if (midpoint == low || midpoint == high) {
                break; // adjacent doubles: no further progress is possible
            }
            if (this.paramInBounds(edge, midpoint)) {
                high = midpoint;
            } else {
                low = midpoint;
            }
        }
        return (high + low) / 2.0;
    }

    /**
     * Returns the error variance implied by the current parameters.
     *
     * @param error an error node (looked up in this model's graph by name).
     * @return 1 minus the variance contributed by the child's other parents, or {@code NaN} if
     *         that is not positive.
     */
    public double getErrorVariance(Node error) {
        return this.calculateErrorVarianceFromParams(error);
    }

    /** Returns (and caches) the derived error variance for every error node. */
    private Map<Node, Double> errorVariances() {
        if (this.errorVariances == null) {
            Map<Node, Double> variances = new HashMap<>();
            for (Node error : this.getErrorNodes()) {
                variances.put(error, this.getErrorVariance(error));
            }
            this.errorVariances = variances;
        }
        return this.errorVariances;
    }

    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        buf.append("\nStandardized SEM:");
        buf.append("\n\nEdge coefficients (parameters):\n");
        for (Map.Entry<Edge, Double> entry : this.edgeParameters.entrySet()) {
            if (Edges.isDirectedEdge(entry.getKey())) {
                buf.append("\n").append(entry.getKey()).append(" ").append(nf.format(entry.getValue()));
            }
        }
        buf.append("\n\nError covariances (parameters):\n");
        for (Map.Entry<Edge, Double> entry : this.edgeParameters.entrySet()) {
            if (Edges.isBidirectedEdge(entry.getKey())) {
                buf.append("\n").append(entry.getKey()).append(" ").append(nf.format(entry.getValue()));
            }
        }
        buf.append("\n\nError variances (calculated):\n");
        for (Node error : this.getErrorNodes()) {
            double variance = this.getErrorVariance(error);
            buf.append("\n").append(error).append(" ").append(Double.isNaN(variance) ? "Undefined" : nf.format(variance));
        }
        buf.append("\n");
        return buf.toString();
    }

    /**
     * Returns the PM's variable nodes.
     *
     * @return the variable nodes.
     */
    public List<Node> getVariableNodes() {
        return this.semPm.getVariableNodes();
    }

    /** Returns (and caches) the matrix B with B[a][b] = coefficient of a -> b, indexed by variable order. */
    private Matrix edgeCoef() {
        if (this.edgeCoef == null) {
            List<Node> variableNodes = this.getVariableNodes();
            Matrix coefficients = new Matrix(variableNodes.size(), variableNodes.size());
            for (Map.Entry<Edge, Double> entry : this.edgeParameters.entrySet()) {
                Edge edge = entry.getKey();
                if (Edges.isBidirectedEdge(edge)) {
                    continue;
                }
                int from = variableNodes.indexOf(edge.getNode1());
                int to = variableNodes.indexOf(edge.getNode2());
                coefficients.set(from, to, entry.getValue());
            }
            this.edgeCoef = coefficients;
        }
        return this.edgeCoef;
    }

    /**
     * Returns the variable means, which are all zero in this model.
     *
     * @return an array of zeros, one per variable.
     */
    public double[] means() {
        return new double[this.semPm.getVariableNodes().size()];
    }

    @Override
    public DataSet simulateData(int sampleSize, boolean latentDataSaved) {
        return this.simulateDataReducedForm(sampleSize, latentDataSaved);
    }

    /**
     * Simulates data as X = (I - B<sup>T</sup>)<sup>-1</sup> e, with each e<sub>i</sub> drawn
     * from N(0, derived error variance of variable i).
     *
     * <p>NOTE(review): only the diagonal of the error covariance matrix is used, so error
     * covariances on bidirected edges are not reflected in the simulated data. Confirm this
     * is intended.
     *
     * @param sampleSize      number of rows to simulate.
     * @param latentDataSaved if false, latent columns are removed.
     * @return the simulated data set.
     */
    public DataSet simulateDataReducedForm(int sampleSize, boolean latentDataSaved) {
        this.invalidateCaches();
        int numVars = this.getVariableNodes().size();
        Matrix coefficients = this.edgeCoef();
        Matrix b = coefficients.transpose();
        Matrix iMinusBInv = Matrix.identity(b.getNumRows()).minus(b).inverse();

        // Loop-invariant: per-variable error standard deviations.
        Matrix errCov = this.errCovar(this.errorVariances(), false);
        double[] errorSd = new double[coefficients.getNumColumns()];
        for (int i = 0; i < errorSd.length; ++i) {
            errorSd[i] = FastMath.sqrt(errCov.get(i, i));
        }

        Matrix sim = new Matrix(sampleSize, numVars);
        for (int row = 0; row < sampleSize; ++row) {
            Vector e = new Vector(errorSd.length);
            for (int i = 0; i < errorSd.length; ++i) {
                e.set(i, RandomUtil.getInstance().nextNormal(0.0, errorSd[i]));
            }
            sim.assignRow(row, iMinusBInv.times(e));
        }

        List<Node> continuousVars = new ArrayList<>();
        for (Node node : this.getVariableNodes()) {
            ContinuousVariable var = new ContinuousVariable(node.getName());
            var.setNodeType(node.getNodeType());
            continuousVars.add(var);
        }
        BoxDataSet fullDataSet = new BoxDataSet(new DoubleDataBox(sim.toArray()), continuousVars);
        return latentDataSaved ? fullDataSet : DataTransforms.restrictToMeasured(fullDataSet);
    }

    /**
     * Returns the implied covariance matrix over all variables, recomputed on each call.
     * The returned matrix is the internal field, not a copy.
     *
     * @return the implied covariance matrix.
     */
    public Matrix getImplCovar() {
        return this.implCovar();
    }

    /**
     * Returns a copy of the implied covariance matrix restricted to the measured variables.
     *
     * @return the measured implied covariance matrix.
     */
    public Matrix getImplCovarMeas() {
        return this.implCovarMeas().copy();
    }

    /**
     * Builds the error covariance matrix: derived error variances on the diagonal and stored
     * error covariances wherever the graph has a bidirected edge between error nodes. The
     * {@code errorVariances} map is used only for the matrix dimension.
     */
    private Matrix errCovar(Map<Node, Double> errorVariances, boolean recalculate) {
        if (!recalculate && this.errorCovar != null) {
            return this.errorCovar;
        }
        List<Node> errorNodes = this.getErrorNodes();
        Matrix covar = new Matrix(errorVariances.size(), errorVariances.size());
        for (int i = 0; i < errorNodes.size(); ++i) {
            covar.set(i, i, this.getErrorVariance(errorNodes.get(i)));
        }
        for (int i = 0; i < errorNodes.size(); ++i) {
            Node error1 = errorNodes.get(i);
            for (int j = 0; j < errorNodes.size(); ++j) {
                Node error2 = errorNodes.get(j);
                Edge edge = this.semGraph.getEdge(error1, error2);
                if (edge != null && Edges.isBidirectedEdge(edge)) {
                    covar.set(i, j, this.getErrorCovariance(error1, error2));
                }
            }
        }
        this.errorCovar = covar;
        return covar;
    }

    private Matrix implCovar() {
        this.computeImpliedCovar();
        return this.implCovar;
    }

    private Matrix implCovarMeas() {
        this.computeImpliedCovar();
        return this.implCovarMeas;
    }

    /** Recomputes the full and measured-only implied covariance matrices from the current parameters. */
    private void computeImpliedCovar() {
        Matrix edgeCoefT = this.edgeCoef().transpose();
        this.implCovar = MatrixUtils.impliedCovar(edgeCoefT, this.errCovar(this.errorVariances(), true));
        // Hoisted: getMeasuredNodes() copies the SemPm on every call.
        List<Node> measured = this.getMeasuredNodes();
        List<Node> variables = this.getVariableNodes();
        int size = measured.size();
        int[] fullIndex = new int[size];
        for (int i = 0; i < size; ++i) {
            fullIndex[i] = variables.indexOf(measured.get(i));
        }
        this.implCovarMeas = new Matrix(size, size);
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                this.implCovarMeas.set(i, j, this.implCovar.get(fullIndex[i], fullIndex[j]));
            }
        }
    }

    /**
     * Returns the measured nodes of a copy of the PM.
     *
     * @return the measured nodes.
     */
    public List<Node> getMeasuredNodes() {
        return this.getSemPm().getMeasuredNodes();
    }

    /**
     * Returns the exogenous (error) node of each variable, in variable order.
     *
     * @return the error nodes.
     */
    public List<Node> getErrorNodes() {
        List<Node> errorNodes = new ArrayList<>();
        for (Node node : this.getVariableNodes()) {
            errorNodes.add(this.semGraph.getExogenous(node));
        }
        return errorNodes;
    }

    /**
     * Returns a copy of this model's SEM PM.
     *
     * @return a new SemPm copied from the internal one.
     */
    public SemPm getSemPm() {
        return new SemPm(this.semPm);
    }

    /**
     * Stores {@code newValue} for {@code edge} (without restoring it) and reports whether every
     * derived error variance is defined and the error covariance matrix is positive definite.
     */
    private boolean paramInBounds(Edge edge, double newValue) {
        this.edgeParameters.put(edge, newValue);
        Map<Node, Double> variances = new HashMap<>();
        for (Node node : this.semPm.getVariableNodes()) {
            Node error = this.semGraph.getExogenous(node);
            double variance = this.calculateErrorVarianceFromParams(error);
            if (Double.isNaN(variance)) {
                return false;
            }
            variances.put(error, variance);
        }
        return MatrixUtils.isPositiveDefinite(this.errCovar(variances, true));
    }

    /**
     * Derives an error variance as 1 minus the variance contributed by the child's other parents:
     * the sum of squared coefficients plus 2*c1*c2*cov over each parent pair, where cov is the
     * sum over treks between the pair of the product of the edge parameters along the trek.
     * Returns NaN when the result is not positive.
     */
    private double calculateErrorVarianceFromParams(Node error) {
        Node errorNode = this.semGraph.getNode(error.getName());
        // Assumes the error node has exactly one child (its variable).
        Node child = this.semGraph.getChildren(errorNode).iterator().next();
        List<Node> parents = new ArrayList<>(this.semGraph.getParents(child));
        double otherVariance = 0.0;
        for (Node parent : parents) {
            if (parent == errorNode) {
                continue;
            }
            double coef = this.getEdgeCoef(parent, child);
            otherVariance += coef * coef;
        }
        if (parents.size() >= 2) {
            ChoiceGenerator gen = new ChoiceGenerator(parents.size(), 2);
            int[] pair;
            while ((pair = gen.next()) != null) {
                Node node1 = parents.get(pair[0]);
                Node node2 = parents.get(pair[1]);
                double coef1 = this.parentCoef(node1, child);
                double coef2 = this.parentCoef(node2, child);
                double cov = 0.0;
                for (List<Node> trek : this.semGraph.paths().treksIncludingBidirected(node1, node2)) {
                    cov += this.trekProduct(trek);
                }
                otherVariance += 2.0 * coef1 * coef2 * cov;
            }
        }
        double errorVariance = 1.0 - otherVariance;
        return errorVariance <= 0.0 ? Double.NaN : errorVariance;
    }

    /** Coefficient of parent -> child, or 1.0 when the parent is an error node. */
    private double parentCoef(Node parent, Node child) {
        return parent.getNodeType() != NodeType.ERROR ? this.getEdgeCoef(parent, child) : 1.0;
    }

    /** Product of the edge parameters along a trek; edges without a stored parameter count as 1.0. */
    private double trekProduct(List<Node> trek) {
        double product = 1.0;
        for (int i = 1; i < trek.size(); ++i) {
            Node from = trek.get(i - 1);
            Node to = trek.get(i);
            Edge edge = this.semGraph.getEdge(from, to);
            double factor;
            if (Edges.isBidirectedEdge(edge)) {
                factor = this.edgeParameters.get(edge);
            } else if (!this.edgeParameters.containsKey(edge)) {
                factor = 1.0;
            } else if (this.semGraph.isParentOf(from, to)) {
                factor = this.getEdgeCoef(from, to);
            } else {
                factor = this.getEdgeCoef(to, from);
            }
            product *= factor;
        }
        return product;
    }

    /** How the standardized parameters are initialized. */
    public enum Initialization {
        /** Calculate them directly from the given SEM IM. */
        CALCULATE_FROM_SEM,
        /** Estimate them from standardized data simulated from the given SEM IM. */
        INITIALIZE_FROM_DATA
    }

    /** The valid interval for one parameter, together with its current value. */
    public static final class ParameterRange
    implements TetradSerializable {
        private static final long serialVersionUID = 23L;
        private final Edge edge;
        private final double coef;
        private final double low;
        private final double high;

        /**
         * @param edge the parameter's edge.
         * @param coef the parameter's current value.
         * @param low  the low end of the range.
         * @param high the high end of the range.
         */
        public ParameterRange(Edge edge, double coef, double low, double high) {
            this.edge = edge;
            this.coef = coef;
            this.low = low;
            this.high = high;
        }

        /**
         * Returns a simple instance for serialization tests.
         *
         * @return a serializable instance.
         */
        public static ParameterRange serializableInstance() {
            return new ParameterRange(Edge.serializableInstance(), 1.0, 1.0, 1.0);
        }

        public Edge getEdge() {
            return this.edge;
        }

        public double getCoef() {
            return this.coef;
        }

        public double getLow() {
            return this.low;
        }

        public double getHigh() {
            return this.high;
        }

        @Override
        public String toString() {
            return "\n\nRange for " + this.edge + "\nCurrent value = " + this.coef + "\nLow end of range = " + this.low + "\nHigh end of range = " + this.high;
        }
    }
}

