/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import cern.jet.math.Functions;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Fisher's Z test of conditional independence computed from regression residuals.
 *
 * <p>To test {@code X _||_ Y | Z}, X and Y are each regressed (least squares, via
 * {@code (Z'Z)^-1 Z'}) on the columns of Z. Fisher's Z transform is then applied to the
 * correlation of the two residual vectors, and its absolute value is compared with the
 * two-sided normal cutoff for the current alpha.
 *
 * <p>NOTE(review): despite the class name, the inverse of {@code Z'Z} is obtained with
 * {@link Matrix#inverse()}; confirm how that behaves for singular {@code Z'Z}. Also, no
 * intercept column is added to Z, so the residual correlation equals the usual partial
 * correlation only if the data are centered — confirm against callers.
 */
public final class IndTestFisherZGeneralizedInverse implements IndependenceTest {
    private static final NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();

    /** Data matrix (rows = samples, columns = variables in {@link #variables} order). */
    private final DoubleMatrix2D data;
    private final List<Node> variables;
    private final DataSet dataSet;
    private double alpha;
    /** Cached |Z| cutoff for {@link #alpha}; NaN means "recompute on next use". */
    private double thresh = Double.NaN;
    /** Fisher's Z statistic of the most recent {@link #checkIndependence} call. */
    private double fishersZ;
    private boolean verbose;

    /**
     * Creates the test over the given data set.
     *
     * @param dataSet the data; its variables define the nodes this test understands
     * @param alpha   significance level, must lie in [0, 1]
     * @throws IllegalArgumentException if alpha is outside [0, 1] or NaN
     */
    public IndTestFisherZGeneralizedInverse(DataSet dataSet, double alpha) {
        if (!(alpha >= 0.0) || !(alpha <= 1.0)) {
            throw new IllegalArgumentException("Alpha must be in [0, 1]");
        }
        this.dataSet = dataSet;
        this.data = new DenseDoubleMatrix2D(dataSet.getDoubleData().toArray());
        this.variables = Collections.unmodifiableList(dataSet.getVariables());
        this.setAlpha(alpha);
    }

    /** Not supported by this test; always returns {@code null}. */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        return null;
    }

    /**
     * Tests whether {@code xVar} and {@code yVar} are independent given {@code z}.
     *
     * @return a result judged independent when |Fisher's Z| does not exceed the cutoff for
     *     alpha; if the residual correlation is NaN the result is "dependent" with a NaN p-value
     * @throws NullPointerException     if {@code z} or any of its elements is null
     * @throws IllegalArgumentException if the Fisher's Z score is undefined (e.g. too few samples)
     */
    @Override
    public IndependenceResult checkIndependence(Node xVar, Node yVar, List<Node> z) {
        if (z == null) {
            throw new NullPointerException();
        }
        for (Node node : z) {
            if (node == null) {
                throw new NullPointerException();
            }
        }

        DoubleMatrix1D x = this.data.viewColumn(this.variables.indexOf(xVar));
        DoubleMatrix1D y = this.data.viewColumn(this.variables.indexOf(yVar));

        DoubleMatrix1D xRes;
        DoubleMatrix1D yRes;
        if (z.isEmpty()) {
            // Regressing on no columns predicts zero, so the residuals are the raw columns;
            // this also avoids inverting a 0x0 matrix.
            xRes = x;
            yRes = y;
        } else {
            Algebra algebra = new Algebra();
            DoubleMatrix2D zMatrix = this.data.viewSelection(allRows(), columnIndices(z));
            DoubleMatrix2D coefficients = coefficientOperator(algebra, zMatrix);
            xRes = residual(algebra, zMatrix, coefficients, x);
            yRes = residual(algebra, zMatrix, coefficients, y);
        }

        double r = StatUtils.correlation(xRes.toArray(), yRes.toArray());

        if (Double.isNaN(this.thresh)) {
            this.thresh = cutoffGaussian();
        }

        if (Double.isNaN(r)) {
            // Do not let the previous test's statistic leak into getPValue()/getScore().
            this.fishersZ = Double.NaN;
            if (this.verbose) {
                TetradLogger.getInstance().log("independencies",
                        SearchLogUtils.independenceFactMsg(xVar, yVar, z, Double.NaN));
            }
            return new IndependenceResult(new IndependenceFact(xVar, yVar, z), false, Double.NaN);
        }

        // Guard against round-off pushing |r| slightly above 1.
        r = FastMath.max(-1.0, FastMath.min(1.0, r));

        this.fishersZ = FastMath.sqrt((double) (sampleSize() - z.size()) - 3.0) * 0.5
                * (FastMath.log(1.0 + r) - FastMath.log(1.0 - r));
        if (Double.isNaN(this.fishersZ)) {
            throw new IllegalArgumentException("The Fisher's Z score for independence fact "
                    + xVar + " _||_ " + yVar + " | " + z + " is undefined.");
        }

        boolean independent = FastMath.abs(this.fishersZ) <= this.thresh;
        double pValue = getPValue();

        if (this.verbose) {
            String message = SearchLogUtils.independenceFactMsg(xVar, yVar, z, pValue);
            TetradLogger.getInstance().log("independencies", message);
            if (independent) {
                TetradLogger.getInstance().forceLogMessage(message);
            }
        }
        return new IndependenceResult(new IndependenceFact(xVar, yVar, z), independent, pValue);
    }

    /**
     * Returns the two-sided p-value {@code 2 * (1 - normalCdf(0.0, 1.0, |z|))} of the most
     * recent test's Fisher's Z statistic (NaN if that test produced a NaN correlation).
     */
    public double getPValue() {
        return 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0.0, 1.0, FastMath.abs(this.fishersZ)));
    }

    /**
     * Sets the significance level and invalidates the cached cutoff so the next test uses it.
     *
     * @throws IllegalArgumentException if alpha is outside [0, 1] or NaN
     */
    @Override
    public void setAlpha(double alpha) {
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Significance out of range.");
        }
        this.alpha = alpha;
        // The cutoff depends on alpha; recompute it lazily on the next test.
        this.thresh = Double.NaN;
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /** Returns the variable with the given name, or {@code null} if there is none. */
    @Override
    public Node getVariable(String name) {
        for (Node variable : this.variables) {
            if (variable.getName().equals(name)) {
                return variable;
            }
        }
        return null;
    }

    @Override
    public List<String> getVariableNames() {
        List<String> names = new ArrayList<String>(this.variables.size());
        for (Node variable : this.variables) {
            names.add(variable.getName());
        }
        return names;
    }

    @Override
    public String toString() {
        return "Fisher's Z - Generalized Inverse, alpha = " + nf.format(this.getAlpha());
    }

    /**
     * Finds, by bisection, the point c where {@code normalCdf(0.0, 1.0, c)} reaches
     * {@code 1 - alpha/2}: the |Z| cutoff for a two-sided test at level alpha.
     */
    private double cutoffGaussian() {
        double upperTail = 1.0 - this.getAlpha() / 2.0;
        double epsilon = 1.0E-14;
        double lowerBound = -1.0;
        double upperBound = 0.0;
        // Step a unit-width bracket upward until it contains the target quantile.
        while (RandomUtil.getInstance().normalCdf(0.0, 1.0, upperBound) < upperTail) {
            lowerBound += 1.0;
            upperBound += 1.0;
        }
        while (upperBound >= lowerBound + epsilon) {
            double midPoint = lowerBound + (upperBound - lowerBound) / 2.0;
            if (RandomUtil.getInstance().normalCdf(0.0, 1.0, midPoint) <= upperTail) {
                lowerBound = midPoint;
            } else {
                upperBound = midPoint;
            }
        }
        return lowerBound;
    }

    private int sampleSize() {
        return this.data.rows();
    }

    /** Returns the data-matrix column index of each node, in order. */
    private int[] columnIndices(List<Node> nodes) {
        int[] cols = new int[nodes.size()];
        for (int i = 0; i < cols.length; i++) {
            cols[i] = this.variables.indexOf(nodes.get(i));
        }
        return cols;
    }

    /** Returns the indices 0..rows-1 of every data row. */
    private int[] allRows() {
        int[] rows = new int[this.data.rows()];
        for (int i = 0; i < rows.length; i++) {
            rows[i] = i;
        }
        return rows;
    }

    /** Returns {@code (Z'Z)^-1 Z'}, which maps a target column to its regression coefficients. */
    private static DoubleMatrix2D coefficientOperator(Algebra algebra, DoubleMatrix2D zMatrix) {
        DoubleMatrix2D zt = algebra.transpose(zMatrix);
        DoubleMatrix2D ztz = algebra.mult(zt, zMatrix);
        Matrix inverse = new Matrix(ztz.toArray()).inverse();
        DoubleMatrix2D g = new DenseDoubleMatrix2D(inverse.toArray());
        // Materialize the transposed view as a dense copy before multiplying, as before.
        return algebra.mult(g, zt.copy());
    }

    /** Returns {@code target - Z * (coefficients * target)}, the least-squares residual. */
    private static DoubleMatrix1D residual(Algebra algebra, DoubleMatrix2D zMatrix,
                                           DoubleMatrix2D coefficients, DoubleMatrix1D target) {
        DoubleMatrix1D fitted = algebra.mult(zMatrix, algebra.mult(coefficients, target));
        return target.copy().assign(fitted, Functions.minus);
    }

    /**
     * Reports whether {@code xVar} is (nearly) determined by {@code zList}. It regresses xVar on
     * zList and compares {@code SSE / (rows - (|zList| + 1))} with alpha.
     *
     * @return false for an empty {@code zList}; otherwise whether that variance is below alpha
     * @throws NullPointerException if {@code zList} or any of its elements is null
     */
    @Override
    public boolean determines(List<Node> zList, Node xVar) {
        if (zList == null) {
            throw new NullPointerException();
        }
        if (zList.isEmpty()) {
            return false;
        }
        for (Node node : zList) {
            if (node == null) {
                throw new NullPointerException();
            }
        }

        Algebra algebra = new Algebra();
        DoubleMatrix2D zMatrix = this.data.viewSelection(allRows(), columnIndices(zList));
        DoubleMatrix1D x = this.data.viewColumn(this.variables.indexOf(xVar));
        DoubleMatrix1D xRes = residual(algebra, zMatrix, coefficientOperator(algebra, zMatrix), x);

        double sse = xRes.aggregate(Functions.plus, Functions.square);
        double variance = sse / (double) (this.data.rows() - (zList.size() + 1));
        boolean determined = variance < this.getAlpha();

        if (determined) {
            StringBuilder sb = new StringBuilder();
            sb.append("Determination found: ").append(xVar).append(" is determined by {");
            for (int i = 0; i < zList.size(); ++i) {
                if (i > 0) {
                    sb.append(", ");
                }
                sb.append(zList.get(i));
            }
            sb.append("}");
            sb.append(" SSE = ").append(nf.format(sse));
            TetradLogger.getInstance().log("independencies", sb.toString());
            // NOTE(review): duplicates the logger output on stdout; consider removing.
            System.out.println(sb);
        }
        return determined;
    }

    @Override
    public DataSet getData() {
        return this.dataSet;
    }

    /** Not available for this test; always returns {@code null}. */
    @Override
    public ICovarianceMatrix getCov() {
        return null;
    }

    /** Not available for this test; always returns {@code null}. */
    @Override
    public List<DataSet> getDataSets() {
        return null;
    }

    /** Returns the number of rows (samples) in the data. */
    @Override
    public int getSampleSize() {
        return sampleSize();
    }

    /** Not available for this test; always returns {@code null}. */
    @Override
    public List<Matrix> getCovMatrices() {
        return null;
    }

    /** Returns the p-value of the most recent test. */
    @Override
    public double getScore() {
        return this.getPValue();
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

