/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.sem;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.graph.NodeType;
import edu.cmu.tetrad.graph.SemGraph;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.ProbUtils;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 * Two-stage least squares estimator for the coefficients of a SEM parametric
 * model whose latent variables each have measured indicators.
 *
 * <p>For every latent variable one measured child is chosen as its scaling
 * indicator (loading fixed to 1.0). Coefficients between latents and the
 * loadings of the remaining measured variables are then estimated by first
 * projecting the regressors onto a set of instruments (the other measured
 * children of the latents involved) and regressing on the projection.
 *
 * <p>If a node name is supplied at construction, only the equation of that
 * node is estimated and, when it is a latent with latent parents, the
 * asymptotic covariance of its coefficients is kept for
 * {@link #getEdgePValue(String)}.
 */
public class Tsls {
    private SemPm spm;
    private SemIm semIm;
    /** Names of the measured variables to use as scaling indicators, or null to choose automatically. */
    private List<String> fixedLoadings;
    private DataSet dataSet;
    /** Name of the single node to estimate, or null to estimate every equation. */
    private String nodeName;
    /** Asymptotic covariance of {@link #A_hat}; only set when {@link #nodeName} matched a latent equation. */
    private double[][] asymptLCovar;
    /** Coefficients of the most recently estimated latent equation, endogenous parents first. */
    private double[] A_hat;
    /** Parent names aligned with {@link #A_hat}; trailing entries beyond the parent count stay null. */
    private String[] lNames;

    /**
     * @param spm     the parametric model to estimate
     * @param dataSet the data; variables are looked up by node name
     * @param nm      name of the only node whose equation should be estimated, or null for all
     */
    public Tsls(SemPm spm, DataSet dataSet, String nm) {
        this.initialization(spm, dataSet, nm);
    }

    /**
     * @param spm           the parametric model to estimate
     * @param dataSet       the data; variables are looked up by node name
     * @param nm            name of the only node whose equation should be estimated, or null for all
     * @param fixedLoadings names of the measured variables to use as scaling indicators
     */
    public Tsls(SemPm spm, DataSet dataSet, String nm, List<String> fixedLoadings) {
        this.initialization(spm, dataSet, nm);
        this.fixedLoadings = fixedLoadings;
    }

    /** Stores the constructor arguments and clears any previous estimate. */
    private void initialization(SemPm spm, DataSet dataSet, String nm) {
        this.dataSet = dataSet;
        this.spm = spm;
        if (nm != null) {
            this.nodeName = nm;
        }
        this.semIm = null;
    }

    /**
     * Builds a fresh {@link SemIm} from the parametric model and fills in the
     * estimated coefficients.
     *
     * @return the estimated model, also available afterwards via {@link #getEstimatedSem()}
     */
    public SemIm estimate() {
        this.semIm = this.estimateCoeffs(new SemIm(this.spm));
        return this.semIm;
    }

    /**
     * @return the model produced by the last call to {@link #estimate()}, or null if none was made
     */
    public SemIm getEstimatedSem() {
        return this.semIm;
    }

    /**
     * Chooses one scaling indicator per latent variable and appends it to
     * {@code mx1} (latent has no parents) or {@code my1} (latent has parents),
     * in the iteration order of {@code semGraph.getNodes()}.
     *
     * <p>Without fixed loadings the measured child with the lexicographically
     * smallest name is used (null is appended when a latent has no measured
     * child). With fixed loadings the first listed variable that has the
     * latent as a parent is used.
     */
    private void setFixedNodes(SemGraph semGraph, List<Node> mx1, List<Node> my1) {
        for (Node latent : semGraph.getNodes()) {
            if (latent.getNodeType() != NodeType.LATENT) {
                continue;
            }
            boolean exogenous;
            Node indicator;
            if (this.fixedLoadings == null) {
                indicator = null;
                for (Node child : semGraph.getChildren(latent)) {
                    if (child.getNodeType() != NodeType.MEASURED) {
                        continue;
                    }
                    if (indicator == null || child.getName().compareTo(indicator.getName()) < 0) {
                        indicator = child;
                    }
                }
                exogenous = semGraph.getParents(latent).size() == 0;
            } else {
                indicator = this.findFixedIndicator(semGraph, latent);
                if (indicator == null) {
                    // NOTE(review): nothing is appended for this latent, so mx1/my1 stop
                    // lining up with the latent lists built by the caller — confirm that
                    // callers always list one fixed loading per latent.
                    continue;
                }
                exogenous = semGraph.getParents(latent).size() == 0;
                System.out.println((exogenous ? "Fixing mx1 = " : "Fixing my1 = ") + indicator.getName());
            }
            if (exogenous) {
                mx1.add(indicator);
            } else {
                my1.add(indicator);
            }
        }
    }

    /** Returns the first fixed-loading variable that has {@code latent} as a parent, or null. */
    private Node findFixedIndicator(SemGraph semGraph, Node latent) {
        for (String fixedLoading : this.fixedLoadings) {
            Node indicator = semGraph.getNode(fixedLoading);
            for (Node parent : semGraph.getParents(indicator)) {
                if (parent == latent) {
                    return indicator;
                }
            }
        }
        return null;
    }

    /**
     * Estimates the coefficients of {@code semIm} in place.
     *
     * <p>Latents are split into those without parents ({@code lx}) and those
     * with parents ({@code ly}); every other non-error node is treated as
     * observed. The equations of the {@code ly} latents are estimated first,
     * then the scaling loadings of the {@code lx} latents are set to 1.0, and
     * finally the loadings of the remaining observed variables are estimated.
     *
     * @param semIm the instantiated model to fill in
     * @return the same {@code semIm} instance
     */
    protected SemIm estimateCoeffs(SemIm semIm) {
        SemGraph semGraph = semIm.getSemPm().getGraph();
        List<Node> ly = new LinkedList<Node>();
        List<Node> lx = new LinkedList<Node>();
        List<Node> my1 = new LinkedList<Node>();
        List<Node> mx1 = new LinkedList<Node>();
        List<Node> observed = new LinkedList<Node>();

        for (Node node : semGraph.getNodes()) {
            NodeType type = node.getNodeType();
            if (type == NodeType.ERROR) {
                continue;
            }
            if (type != NodeType.LATENT) {
                observed.add(node);
            } else if (semGraph.getParents(node).size() == 0) {
                lx.add(node);
            } else {
                ly.add(node);
            }
        }
        this.setFixedNodes(semGraph, mx1, my1);

        this.estimateLatentEquations(semIm, semGraph, lx, ly, mx1, my1);

        for (Node exogenous : lx) {
            semIm.setParamValue(exogenous, mx1.get(lx.indexOf(exogenous)), 1.0);
        }

        this.estimateLoadings(semIm, semGraph, lx, ly, mx1, my1, observed);
        return semIm;
    }

    /**
     * Estimates, for each latent with parents, the coefficients from its latent
     * parents. The scaling indicators stand in for the latents; the other
     * measured children of the parents serve as instruments. Stops after the
     * equation of {@link #nodeName} when one was given.
     */
    private void estimateLatentEquations(SemIm semIm, SemGraph semGraph, List<Node> lx, List<Node> ly,
                                         List<Node> mx1, List<Node> my1) {
        for (Node current : ly) {
            if (this.nodeName != null && !this.nodeName.equals(current.getName())) {
                continue;
            }
            this.lNames = new String[lx.size() + ly.size()];

            // Regressors: scaling indicators of the parents, endogenous parents first.
            List<Node> parents = new LinkedList<Node>();
            List<Node> regressors = new LinkedList<Node>();
            List<Node> exoParents = new LinkedList<Node>();
            List<Node> exoRegressors = new LinkedList<Node>();
            for (Node parent : semGraph.getParents(current)) {
                if (parent.getNodeType() == NodeType.ERROR) {
                    continue;
                }
                if (lx.contains(parent)) {
                    exoRegressors.add(mx1.get(lx.indexOf(parent)));
                    exoParents.add(parent);
                } else {
                    regressors.add(my1.get(ly.indexOf(parent)));
                    parents.add(parent);
                }
            }
            parents.addAll(exoParents);
            regressors.addAll(exoRegressors);
            if (regressors.isEmpty()) {
                continue;
            }
            double[][] Z = this.dataColumns(regressors);
            int count = 0;
            for (Node parent : parents) {
                this.lNames[count++] = parent.getName();
            }

            // Instruments: the parents' measured children other than their scaling indicators.
            List<Node> instruments = new LinkedList<Node>();
            List<Node> exoInstruments = new LinkedList<Node>();
            for (Node parent : semGraph.getParents(current)) {
                if (parent.getNodeType() == NodeType.ERROR) {
                    continue;
                }
                List<Node> others = this.measuredChildren(semGraph, parent, null);
                if (lx.contains(parent)) {
                    others.remove(mx1.get(lx.indexOf(parent)));
                    exoInstruments.addAll(others);
                } else {
                    others.remove(my1.get(ly.indexOf(parent)));
                    instruments.addAll(others);
                }
            }
            instruments.addAll(exoInstruments);
            if (instruments.isEmpty()) {
                continue;
            }
            double[][] V = this.dataColumns(instruments);

            // Response: the scaling indicator of the current latent.
            Node ownIndicator = my1.get(ly.indexOf(current));
            double[] yi = this.readColumn(ownIndicator.getName());

            // First stage: project the regressors onto the instruments.
            double[][] Z_hat = MatrixUtils.product(V, MatrixUtils.product(MatrixUtils.inverse(MatrixUtils.product(MatrixUtils.transpose(V), V)), MatrixUtils.product(MatrixUtils.transpose(V), Z)));
            // Second stage: regress the response on the projected regressors.
            this.A_hat = MatrixUtils.product(MatrixUtils.inverse(MatrixUtils.product(MatrixUtils.transpose(Z_hat), Z_hat)), MatrixUtils.product(MatrixUtils.transpose(Z_hat), yi));

            semIm.setParamValue(current, ownIndicator, 1.0);
            int k = 0;
            for (Node parent : parents) {
                semIm.setParamValue(parent, current, this.A_hat[k++]);
            }

            if (this.nodeName != null) {
                // Only the requested equation was wanted; keep its covariance for p-values.
                this.computeAsymptLatentCovar(yi, this.A_hat, Z, Z_hat, this.dataSet.getNumRows());
                return;
            }
        }
    }

    /**
     * Estimates the loading of every observed variable that is not a scaling
     * indicator, using the scaling indicator of its latent parent as regressor
     * and the latent's remaining measured children as instruments.
     */
    private void estimateLoadings(SemIm semIm, SemGraph semGraph, List<Node> lx, List<Node> ly,
                                  List<Node> mx1, List<Node> my1, List<Node> observed) {
        for (Node current : observed) {
            if (this.nodeName != null && !this.nodeName.equals(current.getName())) {
                continue;
            }
            if (mx1.contains(current) || my1.contains(current)) {
                continue;
            }

            // The last non-error parent is taken as the latent behind this variable.
            Node latent = null;
            for (Node parent : semGraph.getParents(current)) {
                if (parent.getNodeType() != NodeType.ERROR) {
                    latent = parent;
                }
            }
            if (latent == null) {
                // No latent parent: there is no loading to estimate for this variable.
                continue;
            }

            List<Node> instruments = this.measuredChildren(semGraph, latent, current);
            Node fixedMeasurement = lx.contains(latent)
                    ? mx1.get(lx.indexOf(latent))
                    : my1.get(ly.indexOf(latent));
            instruments.remove(fixedMeasurement);
            if (instruments.isEmpty()) {
                continue;
            }

            // Data are addressed as (row, column), like every other read in this class.
            double[] C = this.readColumn(fixedMeasurement.getName());
            double[][] V = this.dataColumns(instruments);
            double[] yi = this.readColumn(current.getName());

            double[] C_hat = MatrixUtils.product(V, MatrixUtils.product(MatrixUtils.inverse(MatrixUtils.product(MatrixUtils.transpose(V), V)), MatrixUtils.product(MatrixUtils.transpose(V), C)));
            double loading = MatrixUtils.innerProduct(MatrixUtils.scalarProduct(1.0 / MatrixUtils.innerProduct(C_hat, C_hat), C_hat), yi);
            semIm.setParamValue(latent, current, loading);
        }
    }

    /**
     * Returns the measured children of {@code latent}, skipping the node that is
     * identical (by reference) to {@code exclude}; pass null to skip none.
     */
    private List<Node> measuredChildren(SemGraph semGraph, Node latent, Node exclude) {
        List<Node> measured = new LinkedList<Node>();
        for (Node child : semGraph.getChildren(latent)) {
            if (child.getNodeType() == NodeType.MEASURED && child != exclude) {
                measured.add(child);
            }
        }
        return measured;
    }

    /** Returns all rows of the data set column whose variable is named {@code name}. */
    private double[] readColumn(String name) {
        Node variable = this.dataSet.getVariable(name);
        int colIndex = this.dataSet.getVariables().indexOf(variable);
        int n = this.dataSet.getNumRows();
        double[] values = new double[n];
        for (int row = 0; row < n; ++row) {
            values[row] = this.dataSet.getDouble(row, colIndex);
        }
        return values;
    }

    /** Returns a rows-by-variables matrix holding the data columns of {@code variables}, in list order. */
    private double[][] dataColumns(List<Node> variables) {
        int n = this.dataSet.getNumRows();
        double[][] matrix = new double[n][variables.size()];
        int col = 0;
        for (Node variable : variables) {
            double[] values = this.readColumn(variable.getName());
            for (int row = 0; row < n; ++row) {
                matrix[row][col] = values[row];
            }
            ++col;
        }
        return matrix;
    }

    /**
     * Stores in {@link #asymptLCovar} the inverse of {@code Z_hat' Z_hat} scaled
     * by the residual variance of {@code y - Z * A_hat} (sum of squares divided
     * by {@code n}).
     */
    private void computeAsymptLatentCovar(double[] y, double[] A_hat, double[][] Z, double[][] Z_hat, double n) {
        double[] residuals = MatrixUtils.subtract(y, MatrixUtils.product(Z, A_hat));
        double sigma_ui = MatrixUtils.innerProduct(residuals, residuals) / n;
        double[][] covar = MatrixUtils.inverse(MatrixUtils.product(MatrixUtils.transpose(Z_hat), Z_hat));
        this.asymptLCovar = covar;
        for (double[] row : covar) {
            for (int j = 0; j < covar.length; ++j) {
                row[j] *= sigma_ui;
            }
        }
    }

    /**
     * Two-sided p-value, under a normal approximation, for the coefficient from
     * the latent parent named {@code source} in the last estimated latent
     * equation.
     *
     * @param source name of the parent latent
     * @return the p-value, or 0.0 when no covariance has been computed or
     *         {@code source} is not one of the recorded parents
     */
    public double getEdgePValue(String source) {
        if (this.asymptLCovar == null || this.lNames == null || source == null) {
            return 0.0;
        }
        for (int i = 0; i < this.lNames.length; ++i) {
            // Slots beyond the parent count are null, so compare from the non-null side.
            if (!source.equals(this.lNames[i])) {
                continue;
            }
            double z = Math.abs(this.A_hat[i] / Math.sqrt(this.asymptLCovar[i][i]));
            System.out.println("Asymptotic Z = " + z);
            return 2.0 * (1.0 - ProbUtils.normalCdf(z));
        }
        return 0.0;
    }
}

