/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix1D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.graph.Edge;
import edu.cmu.tetrad.graph.Edges;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.ParamType;
import edu.cmu.tetrad.sem.Parameter;
import edu.cmu.tetrad.sem.SemEstimator;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.ChoiceGenerator;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradSerializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A structural equation model in standardized form.
 *
 * <p>Free parameters are stored per edge in {@link #edgeParameters}:
 * <ul>
 *   <li>directed edges {@code a --> b} hold the standardized coefficient of {@code a} in the
 *       equation for {@code b};</li>
 *   <li>bidirected edges between error terms hold an error covariance.</li>
 * </ul>
 * Error variances are never stored. They are derived on demand by
 * {@link #calculateErrorVarianceFromParams(Node)} as {@code 1 - (variance explained by the other
 * parents)}, presumably so that every variable has unit variance. A parameter edit is accepted only
 * if every derived error variance stays positive and the resulting error covariance matrix stays
 * positive definite.
 *
 * <p>Only acyclic models are supported.
 */
public class StandardizedSemIm
implements TetradSerializable {
    static final long serialVersionUID = 23L;

    /** Width below which the bisection in the range search stops. */
    private static final double RANGE_TOLERANCE = 1.0E-10;

    /** Number of cases simulated when initializing from data. */
    private static final int DATA_INITIALIZATION_SAMPLE_SIZE = 1000;

    private SemPm semPm;
    private SemGraph semGraph;
    /** Directed edge -> standardized coefficient; bidirected error edge -> error covariance. */
    private Map<Edge, Double> edgeParameters;
    private DoubleMatrix2D implCovar;
    private DenseDoubleMatrix2D implCovarMeas;
    private List<Node> measuredNodes;
    /** Edge whose admissible range is cached in {@link #range}; avoids recomputing it per keystroke. */
    private Edge editingEdge;
    private ParameterRange range;

    /**
     * Builds a standardized SEM whose parameters are computed directly from {@code im}.
     *
     * @param im the SEM instantiation to standardize.
     * @throws IllegalArgumentException if the SEM graph contains a directed cycle.
     */
    public StandardizedSemIm(SemIm im) {
        this(im, Initialization.CALCULATE_FROM_SEM);
    }

    /**
     * Generates a simple exemplar of this class to test serialization.
     */
    public static StandardizedSemIm serializableInstance() {
        return new StandardizedSemIm(SemIm.serializableInstance());
    }

    /**
     * Builds a standardized SEM from {@code im}.
     *
     * @param im             the SEM instantiation to standardize.
     * @param initialization {@link Initialization#CALCULATE_FROM_SEM} rescales the parameters of
     *                       {@code im} analytically. Any other value (including {@code null}) simulates
     *                       data from {@code im}, standardizes it and re-estimates the model.
     * @throws IllegalArgumentException if the SEM graph contains a directed cycle.
     */
    public StandardizedSemIm(SemIm im, Initialization initialization) {
        this.semPm = new SemPm(im.getSemPm());
        this.semGraph = new SemGraph(this.semPm.getGraph());
        this.semGraph.setShowErrorTerms(true);
        if (this.semGraph.existsDirectedCycle()) {
            throw new IllegalArgumentException("The cyclic case is not handled.");
        }
        if (initialization == Initialization.CALCULATE_FROM_SEM) {
            this.edgeParameters = this.parametersFromSem(im);
        } else {
            this.edgeParameters = this.parametersFromData(im);
        }
        this.measuredNodes = Collections.unmodifiableList(this.semPm.getMeasuredNodes());
    }

    /**
     * Standardizes the parameters of {@code im} analytically.
     *
     * <p>Coefficients are rescaled by {@code sd(a) / sd(b)} using the implied standard deviations.
     * Error covariances become error correlations.
     */
    private Map<Edge, Double> parametersFromSem(SemIm im) {
        Map<Edge, Double> params = new HashMap<Edge, Double>();
        List<Node> nodes = im.getVariableNodes();
        DoubleMatrix2D impliedCovar = im.getImplCovar();
        for (Parameter parameter : im.getSemPm().getParameters()) {
            if (parameter.getType() == ParamType.COEF) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                int aindex = nodes.indexOf(a);
                int bindex = nodes.indexOf(b);
                double impliedStdA = Math.sqrt(impliedCovar.get(aindex, aindex));
                double impliedStdB = Math.sqrt(impliedCovar.get(bindex, bindex));
                double newCoef = impliedStdA * im.getEdgeCoef(a, b) / impliedStdB;
                params.put(Edges.directedEdge(a, b), newCoef);
            } else if (parameter.getType() == ParamType.COVAR) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                params.put(this.errorCovarianceEdge(a, b), errorCorrelation(im, a, b));
            }
        }
        return params;
    }

    /**
     * Estimates standardized parameters by simulating data from {@code im}, standardizing it, and
     * re-estimating the model on the standardized data.
     */
    private Map<Edge, Double> parametersFromData(SemIm im) {
        DataSet dataSet = im.simulateData(DATA_INITIALIZATION_SAMPLE_SIZE, false);
        DoubleMatrix2D standardizedData = DataUtils.standardizeData(dataSet.getDoubleData());
        ColtDataSet dataSetStandardized = ColtDataSet.makeData(dataSet.getVariables(), standardizedData);
        SemEstimator estimator = new SemEstimator(dataSetStandardized, im.getSemPm());
        SemIm imStandardized = estimator.estimate();
        Map<Edge, Double> params = new HashMap<Edge, Double>();
        for (Parameter parameter : imStandardized.getSemPm().getParameters()) {
            if (parameter.getType() == ParamType.COEF) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                params.put(Edges.directedEdge(a, b), imStandardized.getEdgeCoef(a, b));
            } else if (parameter.getType() == ParamType.COVAR) {
                Node a = parameter.getNodeA();
                Node b = parameter.getNodeB();
                // Previously this negated a value read from the original (non-estimated) im, which
                // flipped its sign relative to CALCULATE_FROM_SEM. It now uses the same estimated
                // model as the coefficients above.
                params.put(this.errorCovarianceEdge(a, b), errorCorrelation(imStandardized, a, b));
            }
        }
        return params;
    }

    /** Returns {@code errCovar(a, b) / sqrt(errVar(a) * errVar(b))} as reported by {@code source}. */
    private static double errorCorrelation(SemIm source, Node a, Node b) {
        return source.getErrCovar(a, b) / Math.sqrt(source.getErrVar(a) * source.getErrVar(b));
    }

    /** Returns the bidirected edge joining the exogenous (error) nodes of {@code a} and {@code b}. */
    private Edge errorCovarianceEdge(Node a, Node b) {
        return Edges.bidirectedEdge(this.semGraph.getExogenous(a), this.semGraph.getExogenous(b));
    }

    /**
     * Maps a bidirected edge onto the key used in {@link #edgeParameters} (the edge between the
     * exogenous nodes of its endpoints). Other edges are returned unchanged.
     */
    private Edge canonicalEdge(Edge edge) {
        if (Edges.isBidirectedEdge(edge)) {
            return this.errorCovarianceEdge(edge.getNode1(), edge.getNode2());
        }
        return edge;
    }

    /**
     * Reports whether {@code edge} is a free parameter of this model. A bidirected edge is matched
     * through the exogenous nodes of its endpoints.
     */
    public boolean containsParameter(Edge edge) {
        return this.edgeParameters.containsKey(this.canonicalEdge(edge));
    }

    /**
     * Sets the standardized coefficient of {@code a --> b} if it lies strictly inside the range
     * returned by {@link #getParameterRange(Edge)}.
     *
     * @return {@code true} if the value was stored; {@code false} if it was out of range, in which
     *         case the model is unchanged.
     * @throws NullPointerException if {@code a --> b} is not a coefficient parameter of this model.
     */
    public boolean setEdgeCoefficient(Node a, Node b, double coef) {
        Edge edge = Edges.directedEdge(a, b);
        if (this.edgeParameters.get(edge) == null) {
            throw new NullPointerException("Not a coefficient parameter in this model: " + edge);
        }
        return this.setIfInRange(edge, coef);
    }

    /**
     * Sets the error covariance between the error terms of {@code a} and {@code b} if it lies
     * strictly inside the admissible range.
     *
     * @return {@code true} if the value was stored; {@code false} if it was out of range.
     * @throws IllegalArgumentException if there is no such covariance parameter.
     */
    public boolean setErrorCovariance(Node a, Node b, double covar) {
        Edge edge = this.errorCovarianceEdge(a, b);
        if (this.edgeParameters.get(edge) == null) {
            throw new IllegalArgumentException("Not a covariance parameter in this model: " + edge);
        }
        return this.setIfInRange(edge, covar);
    }

    /**
     * Stores {@code value} for the (already canonical) {@code edge} if it is strictly inside the
     * edge's admissible range. The range is cached for as long as the same edge keeps being edited.
     */
    private boolean setIfInRange(Edge edge, double value) {
        if (this.editingEdge == null || !edge.equals(this.editingEdge)) {
            this.range = this.getParameterRange(edge);
            this.editingEdge = edge;
        }
        if (value > this.range.getLow() && value < this.range.getHigh()) {
            this.edgeParameters.put(edge, value);
            return true;
        }
        return false;
    }

    /**
     * Returns the standardized coefficient of {@code a --> b}.
     *
     * @throws IllegalArgumentException if {@code a --> b} is not a parameter of this model.
     */
    public double getEdgeCoefficient(Node a, Node b) {
        Edge edge = Edges.directedEdge(a, b);
        Double d = this.edgeParameters.get(edge);
        if (d == null) {
            throw new IllegalArgumentException("Not a directed edge in this model: " + edge);
        }
        return d;
    }

    /**
     * Returns the error covariance between the error terms of {@code a} and {@code b}.
     *
     * @throws IllegalArgumentException if there is no such covariance parameter.
     */
    public double getErrorCovariance(Node a, Node b) {
        Edge edge = this.errorCovarianceEdge(a, b);
        Double d = this.edgeParameters.get(edge);
        if (d == null) {
            throw new IllegalArgumentException("Not a covariance parameter in this model: " + edge);
        }
        return d;
    }

    /**
     * Returns the coefficient (directed edge) or error covariance (bidirected edge) for {@code edge}.
     *
     * @throws IllegalArgumentException for any other edge type, or if the edge is not a parameter.
     */
    public double getParameterValue(Edge edge) {
        if (Edges.isDirectedEdge(edge)) {
            return this.getEdgeCoefficient(edge.getNode1(), edge.getNode2());
        }
        if (Edges.isBidirectedEdge(edge)) {
            return this.getErrorCovariance(edge.getNode1(), edge.getNode2());
        }
        throw new IllegalArgumentException("Only directed and bidirected edges are supported: " + edge);
    }

    /**
     * Sets the coefficient or error covariance for {@code edge}. Out-of-range values are silently
     * ignored; use {@link #setEdgeCoefficient} or {@link #setErrorCovariance} to learn whether the
     * value was accepted.
     *
     * @throws IllegalArgumentException if {@code edge} is neither directed nor bidirected.
     */
    public void setParameterValue(Edge edge, double value) {
        if (Edges.isDirectedEdge(edge)) {
            this.setEdgeCoefficient(edge.getNode1(), edge.getNode2(), value);
        } else if (Edges.isBidirectedEdge(edge)) {
            this.setErrorCovariance(edge.getNode1(), edge.getNode2(), value);
        } else {
            throw new IllegalArgumentException("Only directed and bidirected edges are supported: " + edge);
        }
    }

    /** Returns the admissible range for the coefficient of {@code a --> b}. */
    public ParameterRange getCoefficientRange(Node a, Node b) {
        return this.getParameterRange(Edges.directedEdge(a, b));
    }

    /** Returns the admissible range for the error covariance between {@code a} and {@code b}. */
    public ParameterRange getCovarianceRange(Node a, Node b) {
        return this.getParameterRange(this.errorCovarianceEdge(a, b));
    }

    /**
     * Computes the open interval of values for {@code edge} that keep the model admissible (see
     * {@link #paramInBounds}), holding every other parameter fixed.
     *
     * <p>Each end is found by doubling a step away from the current value until the model becomes
     * inadmissible, then bisecting. If the doubling overflows, that end is reported as infinite. The
     * parameter's stored value is always restored, even if the search throws.
     *
     * @throws IllegalArgumentException if {@code edge} is not a parameter of this model.
     */
    public ParameterRange getParameterRange(Edge edge) {
        edge = this.canonicalEdge(edge);
        if (!this.edgeParameters.containsKey(edge)) {
            throw new IllegalArgumentException("Not an edge in this model: " + edge);
        }
        double original = this.edgeParameters.get(edge);
        double value = original;
        if (value == Double.NEGATIVE_INFINITY) {
            // Double.MIN_VALUE is the smallest *positive* double; the most negative finite one is -MAX_VALUE.
            value = -Double.MAX_VALUE;
        } else if (value == Double.POSITIVE_INFINITY) {
            value = Double.MAX_VALUE;
        }
        try {
            double rangeHigh = this.searchUpperBound(edge, value);
            double rangeLow = this.searchLowerBound(edge, value);
            return new ParameterRange(edge, value, rangeLow, rangeHigh);
        } finally {
            // paramInBounds overwrites the stored value while probing.
            this.edgeParameters.put(edge, original);
        }
    }

    /** Upper end of the admissible interval above {@code value}, or +infinity if none is found. */
    private double searchUpperBound(Edge edge, double value) {
        double step = 1.0;
        double high = value + step;
        while (this.paramInBounds(edge, high)) {
            // Grow an explicit step so that huge |value| (where value + 1 == value) still progresses.
            step *= 2.0;
            high = value + step;
            if (high == Double.POSITIVE_INFINITY) {
                return Double.POSITIVE_INFINITY;
            }
        }
        double low = value;
        while (high - low > RANGE_TOLERANCE) {
            double midpoint = (high + low) / 2.0;
            if (midpoint == low || midpoint == high) {
                break; // no representable double strictly between low and high
            }
            if (this.paramInBounds(edge, midpoint)) {
                low = midpoint;
            } else {
                high = midpoint;
            }
        }
        return (high + low) / 2.0;
    }

    /** Lower end of the admissible interval below {@code value}, or -infinity if none is found. */
    private double searchLowerBound(Edge edge, double value) {
        double step = 1.0;
        double low = value - step;
        while (this.paramInBounds(edge, low)) {
            step *= 2.0;
            low = value - step;
            if (low == Double.NEGATIVE_INFINITY) {
                return Double.NEGATIVE_INFINITY;
            }
        }
        double high = value;
        while (high - low > RANGE_TOLERANCE) {
            double midpoint = (high + low) / 2.0;
            if (midpoint == low || midpoint == high) {
                break; // no representable double strictly between low and high
            }
            if (this.paramInBounds(edge, midpoint)) {
                high = midpoint;
            } else {
                low = midpoint;
            }
        }
        return (high + low) / 2.0;
    }

    /** Returns the error variance derived from the current parameters, or NaN if it is not positive. */
    public double getErrorVariance(Node error) {
        return this.calculateErrorVarianceFromParams(error);
    }

    /** Returns a new map from each error node to its derived variance (NaN where undefined). */
    public Map<Node, Double> errorVariances() {
        HashMap<Node, Double> errorVariances = new HashMap<Node, Double>();
        for (Node error : this.getErrorNodes()) {
            errorVariances.put(error, this.getErrorVariance(error));
        }
        return errorVariances;
    }

    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder();
        NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
        buf.append("\nStandardized SEM:");
        buf.append("\n\nEdge coefficients (parameters):\n");
        for (Map.Entry<Edge, Double> entry : this.edgeParameters.entrySet()) {
            if (!Edges.isDirectedEdge(entry.getKey())) continue;
            buf.append("\n" + entry.getKey() + " " + nf.format(entry.getValue()));
        }
        buf.append("\n\nError covariances (parameters):\n");
        for (Map.Entry<Edge, Double> entry : this.edgeParameters.entrySet()) {
            if (!Edges.isBidirectedEdge(entry.getKey())) continue;
            buf.append("\n" + entry.getKey() + " " + nf.format(entry.getValue()));
        }
        buf.append("\n\nError variances (calculated):\n");
        for (Node error : this.getErrorNodes()) {
            double variance = this.getErrorVariance(error);
            // NumberFormat.format(Object) rejects non-Numbers, so "Undefined" must bypass the formatter.
            String formatted = Double.isNaN(variance) ? "Undefined" : nf.format(variance);
            buf.append("\n" + error + " " + formatted);
        }
        buf.append("\n");
        return buf.toString();
    }

    /** Returns the variable nodes of the underlying SemPm. */
    public List<Node> getVariableNodes() {
        return this.semPm.getVariableNodes();
    }

    /**
     * Returns a new matrix {@code B} over {@link #getVariableNodes()} in which {@code B[a][b]} is the
     * coefficient of {@code a --> b}; all other entries are 0.
     */
    public DoubleMatrix2D edgeCoef() {
        List<Node> variableNodes = this.getVariableNodes();
        DenseDoubleMatrix2D edgeCoef = new DenseDoubleMatrix2D(variableNodes.size(), variableNodes.size());
        for (Map.Entry<Edge, Double> entry : this.edgeParameters.entrySet()) {
            Edge edge = entry.getKey();
            if (Edges.isBidirectedEdge(edge)) continue;
            int aindex = variableNodes.indexOf(edge.getNode1());
            int bindex = variableNodes.indexOf(edge.getNode2());
            edgeCoef.set(aindex, bindex, entry.getValue());
        }
        return edgeCoef;
    }

    /** Returns the error covariance matrix, ordered like {@link #getErrorNodes()}. */
    public DoubleMatrix2D errCovar() {
        return this.errCovar(this.errorVariances());
    }

    /** Returns an all-zero means vector, one entry per variable node. */
    public double[] means() {
        return new double[this.semPm.getVariableNodes().size()];
    }

    /** Equivalent to {@link #simulateDataReducedForm(int, boolean)}. */
    public DataSet simulateData(int sampleSize, boolean latentDataSaved) {
        return this.simulateDataReducedForm(sampleSize, latentDataSaved);
    }

    /**
     * Simulates {@code sampleSize} cases as {@code x = (I - B^T)^{-1} e}, where {@code B} is
     * {@link #edgeCoef()} and {@code e} is drawn with the error covariance {@link #errCovar()}.
     *
     * @param latentDataSaved if {@code false}, the result is restricted to measured variables.
     */
    public DataSet simulateDataReducedForm(int sampleSize, boolean latentDataSaved) {
        int numVars = this.getVariableNodes().size();
        Algebra algebra = new Algebra();
        DoubleMatrix2D edgeCoefT = this.edgeCoef().viewDice();
        DoubleMatrix2D iMinusB = DoubleFactory2D.dense.identity(edgeCoefT.rows());
        iMinusB.assign(edgeCoefT, Functions.minus);
        DoubleMatrix2D inv = algebra.inverse(iMinusB);
        DenseDoubleMatrix2D sim = new DenseDoubleMatrix2D(sampleSize, numVars);
        DoubleMatrix2D cholesky = MatrixUtils.choleskyC(this.errCovar());
        for (int i = 0; i < sampleSize; ++i) {
            DoubleMatrix1D e = new DenseDoubleMatrix1D(this.exogenousData(cholesky, RandomUtil.getInstance()));
            sim.viewRow(i).assign(algebra.mult(inv, e));
        }
        ColtDataSet fullDataSet = ColtDataSet.makeContinuousData(this.getVariableNodes(), sim);
        if (latentDataSaved) {
            return fullDataSet;
        }
        return DataUtils.restrictToMeasured(fullDataSet);
    }

    /** Returns a copy of the implied covariance matrix over all variable nodes. */
    public DoubleMatrix2D getImplCovar() {
        return this.implCovar().copy();
    }

    /** Returns a copy of the implied covariance matrix restricted to measured nodes. */
    public DoubleMatrix2D getImplCovarMeas() {
        return this.implCovarMeas().copy();
    }

    /**
     * Builds the error covariance matrix, ordered like {@link #getErrorNodes()}.
     *
     * <p>The diagonal is taken from {@code errorVariances}. An error missing from the map is
     * recomputed. Off-diagonal entries come from bidirected edges between error nodes.
     */
    private DoubleMatrix2D errCovar(Map<Node, Double> errorVariances) {
        List<Node> errorNodes = this.getErrorNodes();
        int n = errorNodes.size();
        DenseDoubleMatrix2D errorCovar = new DenseDoubleMatrix2D(n, n);
        for (int index = 0; index < n; ++index) {
            Node error = errorNodes.get(index);
            // Reuse the supplied variances; recomputing each one walks every trek in the graph.
            Double variance = errorVariances.get(error);
            errorCovar.set(index, index, variance != null ? variance : this.getErrorVariance(error));
        }
        for (int index1 = 0; index1 < n; ++index1) {
            Node error1 = errorNodes.get(index1);
            for (int index2 = 0; index2 < n; ++index2) {
                Node error2 = errorNodes.get(index2);
                Edge edge = this.semGraph.getEdge(error1, error2);
                if (edge == null || !Edges.isBidirectedEdge(edge)) continue;
                errorCovar.set(index1, index2, this.getErrorCovariance(error1, error2));
            }
        }
        return errorCovar;
    }

    private DoubleMatrix2D implCovar() {
        this.computeImpliedCovar();
        return this.implCovar;
    }

    private DoubleMatrix2D implCovarMeas() {
        this.computeImpliedCovar();
        return this.implCovarMeas;
    }

    /** Recomputes {@link #implCovar} and its measured-node submatrix {@link #implCovarMeas}. */
    private void computeImpliedCovar() {
        DoubleMatrix2D edgeCoefT = new Algebra().transpose(this.edgeCoef());
        this.implCovar = MatrixUtils.impliedCovarC(edgeCoefT, this.errCovar());
        // Hoisted: getMeasuredNodes() copies the SemPm on every call.
        List<Node> measured = this.getMeasuredNodes();
        List<Node> variables = this.getVariableNodes();
        int size = measured.size();
        int[] indices = new int[size];
        for (int i = 0; i < size; ++i) {
            indices[i] = variables.indexOf(measured.get(i));
        }
        this.implCovarMeas = new DenseDoubleMatrix2D(size, size);
        for (int i = 0; i < size; ++i) {
            for (int j = 0; j < size; ++j) {
                this.implCovarMeas.set(i, j, this.implCovar.get(indices[i], indices[j]));
            }
        }
    }

    /** Returns the measured nodes of a fresh copy of the SemPm. */
    public List<Node> getMeasuredNodes() {
        return this.getSemPm().getMeasuredNodes();
    }

    /**
     * Draws one vector {@code L z}, where {@code z} is i.i.d. N(0, 1). Only entries on or below the
     * diagonal of {@code cholesky} are read; presumably {@code cholesky} is the lower Cholesky factor
     * of the error covariance.
     */
    private double[] exogenousData(DoubleMatrix2D cholesky, RandomUtil randomUtil) {
        double[] exoData = new double[cholesky.rows()];
        for (int i = 0; i < exoData.length; ++i) {
            exoData[i] = randomUtil.nextNormal(0.0, 1.0);
        }
        double[] point = new double[exoData.length];
        for (int i = 0; i < exoData.length; ++i) {
            double sum = 0.0;
            for (int j = 0; j <= i; ++j) {
                sum += cholesky.get(i, j) * exoData[j];
            }
            point[i] = sum;
        }
        return point;
    }

    /** Returns the exogenous (error) node of each variable node, in variable-node order. */
    public List<Node> getErrorNodes() {
        ArrayList<Node> errorNodes = new ArrayList<Node>();
        for (Node node : this.getVariableNodes()) {
            errorNodes.add(this.semGraph.getExogenous(node));
        }
        return errorNodes;
    }

    /** Returns a copy of the underlying SemPm. */
    public SemPm getSemPm() {
        return new SemPm(this.semPm);
    }

    /**
     * Stores {@code newValue} for {@code edge} and reports whether the model is still admissible. A
     * model is admissible when every derived error variance is defined (positive) and the error
     * covariance matrix is positive definite.
     *
     * <p>Side effect: the new value is left in {@link #edgeParameters}. Callers must restore it.
     */
    private boolean paramInBounds(Edge edge, double newValue) {
        this.edgeParameters.put(edge, newValue);
        HashMap<Node, Double> errorVariances = new HashMap<Node, Double>();
        for (Node node : this.semPm.getVariableNodes()) {
            Node error = this.semGraph.getExogenous(node);
            double variance = this.calculateErrorVarianceFromParams(error);
            if (Double.isNaN(variance)) {
                return false;
            }
            errorVariances.put(error, variance);
        }
        return MatrixUtils.isPositiveDefinite(this.errCovar(errorVariances));
    }

    /**
     * Derives the variance of {@code error} as {@code 1 - otherVariance}. The child is the error's
     * first child (an error term is expected to have exactly one child — TODO confirm).
     * {@code otherVariance} is built from two parts:
     * <ul>
     *   <li>the sum of squared coefficients of the child's non-error parents;</li>
     *   <li>for every pair of the child's parents (the error itself weighted by 1), twice the product
     *       of their coefficients and their trek-sum covariance.</li>
     * </ul>
     *
     * @return the derived variance, or NaN if it is not positive.
     */
    private double calculateErrorVarianceFromParams(Node error) {
        error = this.semGraph.getNode(error.getName());
        Node child = this.semGraph.getChildren(error).get(0);
        List<Node> parents = this.semGraph.getParents(child);
        double otherVariance = 0.0;
        for (Node parent : parents) {
            if (parent == error) continue;
            double coef = this.getEdgeCoefficient(parent, child);
            otherVariance += coef * coef;
        }
        if (parents.size() >= 2) {
            ChoiceGenerator gen = new ChoiceGenerator(parents.size(), 2);
            int[] indices;
            while ((indices = gen.next()) != null) {
                Node node1 = parents.get(indices[0]);
                Node node2 = parents.get(indices[1]);
                double coef1 = this.coefficientIntoChild(node1, child);
                double coef2 = this.coefficientIntoChild(node2, child);
                otherVariance += 2.0 * coef1 * coef2 * this.trekCovariance(node1, node2);
            }
        }
        double errorVariance = 1.0 - otherVariance;
        return errorVariance <= 0.0 ? Double.NaN : errorVariance;
    }

    /** Coefficient of {@code parent --> child}; error-type parents have an implicit coefficient of 1. */
    private double coefficientIntoChild(Node parent, Node child) {
        return parent.getNodeType() != NodeType.ERROR ? this.getEdgeCoefficient(parent, child) : 1.0;
    }

    /**
     * Sums, over all treks (including bidirected edges) between {@code node1} and {@code node2}, the
     * product of the edge factors along the trek.
     */
    private double trekCovariance(Node node1, Node node2) {
        List<List<Node>> treks = GraphUtils.treksIncludingBidirected(this.semGraph, node1, node2);
        double cov = 0.0;
        for (List<Node> trek : treks) {
            double product = 1.0;
            for (int i = 1; i < trek.size(); ++i) {
                product *= this.trekEdgeFactor(trek.get(i - 1), trek.get(i));
            }
            cov += product;
        }
        return cov;
    }

    /**
     * Returns the factor contributed by the graph edge between two consecutive trek nodes:
     * <ul>
     *   <li>a bidirected edge contributes its stored covariance;</li>
     *   <li>a directed edge contributes its coefficient, read in the edge's own direction;</li>
     *   <li>an edge with no stored parameter (e.g. error --> variable) contributes 1.</li>
     * </ul>
     */
    private double trekEdgeFactor(Node from, Node to) {
        Edge edge = this.semGraph.getEdge(from, to);
        if (Edges.isBidirectedEdge(edge)) {
            return this.edgeParameters.get(edge);
        }
        if (!this.edgeParameters.containsKey(edge)) {
            return 1.0;
        }
        return this.semGraph.isParentOf(from, to)
                ? this.getEdgeCoefficient(from, to)
                : this.getEdgeCoefficient(to, from);
    }

    /** An admissible interval {@code (low, high)} for one parameter, plus its value at query time. */
    public static final class ParameterRange
    implements TetradSerializable {
        static final long serialVersionUID = 23L;
        private Edge edge;
        private double coef;
        private double low;
        private double high;

        /**
         * @param edge the parameter's edge.
         * @param coef the parameter's value when the range was computed.
         * @param low  lower end of the admissible interval (may be -infinity).
         * @param high upper end of the admissible interval (may be +infinity).
         */
        public ParameterRange(Edge edge, double coef, double low, double high) {
            this.edge = edge;
            this.coef = coef;
            this.low = low;
            this.high = high;
        }

        /**
         * Generates a simple exemplar of this class to test serialization.
         */
        public static ParameterRange serializableInstance() {
            return new ParameterRange(Edge.serializableInstance(), 1.0, 1.0, 1.0);
        }

        public Edge getEdge() {
            return this.edge;
        }

        public double getCoef() {
            return this.coef;
        }

        public double getLow() {
            return this.low;
        }

        public double getHigh() {
            return this.high;
        }

        @Override
        public String toString() {
            StringBuilder buf = new StringBuilder();
            buf.append("\n\nRange for " + this.edge);
            buf.append("\nCurrent value = " + this.coef);
            buf.append("\nLow end of range = " + this.low);
            buf.append("\nHigh end of range = " + this.high);
            return buf.toString();
        }
    }

    /** How the standardized parameters are obtained from the input SemIm. */
    public static enum Initialization {
        /** Rescale the input SemIm's parameters analytically. */
        CALCULATE_FROM_SEM,
        /** Simulate data from the input SemIm, standardize it, and re-estimate. */
        INITIALIZE_FROM_DATA;

    }
}

