/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.Vector;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.util.FastMath;

/**
 * Conditional-correlation independence test for use by LiNGAM-style searches.
 *
 * <p>To test whether X is independent of Y given Z, X and Y are each regressed on Z (via
 * {@link RegressionDataset}) and the two residual vectors are compared. Each residual vector is
 * passed through basis functions f_1..f_k (k = {@link #getNumFunctions()}; polynomial powers or
 * sine/cosine terms, see {@link Basis}). For every pair (f_m(rx), f_n(ry)) a nonparametric Fisher
 * Z statistic is computed, and the largest finite absolute value is used as the test score. The
 * p-value is the two-sided standard-normal p-value of that score.
 *
 * <p>Note: the constructor rescales every column of the supplied {@link DataSet} in place.
 */
public final class ConditionalCorrelationIndependenceLingam {

    /** Standard normal used for p-values; only its CDF is evaluated, so one instance is reused. */
    private static final NormalDistribution STANDARD_NORMAL = new NormalDistribution(0.0, 1.0);

    private final RegressionDataset regressionDataset;
    private final DataSet dataSet;
    /** Significance level. */
    private double alpha;
    /** Score from the most recent call to {@link #isIndependent} (0 if that call failed). */
    private double score;
    /** Number of basis functions applied to each residual vector. */
    private int numFunctions = 10;
    /** Cutoff from {@code StatUtils.getZForAlpha(alpha)}; subtracted in {@link #getScore()}. */
    private double cutoff;
    /** Retained for API compatibility; not used by the score computation in this class. */
    private double width = 1.0;
    /** Basis family used to transform residuals. */
    private Basis basis = Basis.Polynomial;

    /**
     * Creates the test.
     *
     * @param dataSet the data; every column is rescaled in place (see {@link #scale})
     * @param alpha the significance level; also used to initialize the score cutoff
     * @throws NullPointerException if {@code dataSet} is null
     */
    public ConditionalCorrelationIndependenceLingam(DataSet dataSet, double alpha) {
        if (dataSet == null) {
            throw new NullPointerException();
        }
        for (int j = 0; j < dataSet.getNumColumns(); ++j) {
            scale(dataSet, j);
        }
        this.dataSet = dataSet;
        this.regressionDataset = new RegressionDataset(dataSet);
        // Sets both alpha and the cutoff so getScore() is meaningful without a later setAlpha().
        setAlpha(alpha);
    }

    /**
     * Computes the test score for x _||_ y | z and remembers it for {@link #getPValue()} and
     * {@link #getScore()}.
     *
     * @return the maximum absolute nonparametric Fisher Z over all basis-function pairs of the
     *     residuals of x and y on z; 0.0 if the computation throws (the stack trace is printed)
     */
    public double isIndependent(Node x, Node y, List<Node> z) {
        try {
            double[] rx = residuals(x, z);
            double[] ry = residuals(y, z);
            this.score = independent(rx, ry);
        } catch (Exception e) {
            // Best-effort: treat a failed computation as "independent" (score 0), and keep the
            // stored score consistent with the returned value.
            e.printStackTrace();
            this.score = 0.0;
        }
        return this.score;
    }

    /**
     * Regresses {@code x} on {@code z} and returns the residuals as an array with one entry per
     * data-set row.
     */
    public double[] residuals(Node x, List<Node> z) {
        RegressionResult result = this.regressionDataset.regress(x, z);
        Vector residuals = result.getResiduals();
        int numRows = this.dataSet.getNumRows();
        double[] out = new double[numRows];
        for (int i = 0; i < numRows; ++i) {
            out[i] = residuals.get(i);
        }
        return out;
    }

    /** @return the number of basis functions applied to each residual vector */
    public int getNumFunctions() {
        return this.numFunctions;
    }

    /**
     * Sets the number of basis functions applied to each residual vector.
     *
     * @throws IllegalArgumentException if {@code numFunctions < 1}
     */
    public void setNumFunctions(int numFunctions) {
        if (numFunctions < 1) {
            throw new IllegalArgumentException("numFunctions must be at least 1: " + numFunctions);
        }
        this.numFunctions = numFunctions;
    }

    /**
     * Sets the basis family used to transform residuals.
     *
     * @throws NullPointerException if {@code basis} is null
     */
    public void setBasis(Basis basis) {
        if (basis == null) {
            throw new NullPointerException("basis");
        }
        this.basis = basis;
    }

    /** @return the stored width (not used by the score computation in this class) */
    public double getWidth() {
        return this.width;
    }

    /** Stores a width value (not used by the score computation in this class). */
    public void setWidth(double width) {
        this.width = width;
    }

    /** @return the two-sided p-value of the score from the most recent {@link #isIndependent} */
    public double getPValue() {
        return getPValue(this.score);
    }

    /** @return the two-sided standard-normal p-value {@code 2 * Phi(-|score|)} */
    public double getPValue(double score) {
        return 2.0 * STANDARD_NORMAL.cumulativeProbability(-FastMath.abs(score));
    }

    /** @return {@code |score| - cutoff}, where cutoff comes from {@link #setAlpha} */
    public double getScore() {
        return FastMath.abs(this.score) - this.cutoff;
    }

    /** Sets alpha and recomputes the cutoff via {@code StatUtils.getZForAlpha(alpha)}. */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
        this.cutoff = StatUtils.getZForAlpha(alpha);
    }

    /** @return the significance level */
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Maximum finite |nonparametric Fisher Z| over all pairs (f_m(x), f_n(y)), m, n in
     * 1..numFunctions. Returns {@code Double.NEGATIVE_INFINITY} if no pair gives a finite value.
     */
    private double independent(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException(
                    "Residual vectors differ in length: " + x.length + " vs " + y.length);
        }
        int k = getNumFunctions();
        // Each transform depends on only one index, so compute and standardize it once
        // instead of once per (m, n) pair.
        double[][] fx = new double[k][];
        double[][] fy = new double[k][];
        for (int m = 1; m <= k; ++m) {
            fx[m - 1] = standardize(applyBasis(m, x));
            fy[m - 1] = standardize(applyBasis(m, y));
        }
        double maxScore = Double.NEGATIVE_INFINITY;
        for (double[] a : fx) {
            for (double[] b : fy) {
                double s = FastMath.abs(nonparametricFisherZ(a, b));
                if (!Double.isNaN(s) && !Double.isInfinite(s) && s > maxScore) {
                    maxScore = s;
                }
            }
        }
        return maxScore;
    }

    /** Applies basis function {@code index} elementwise to {@code values}, returning a new array. */
    private double[] applyBasis(int index, double[] values) {
        double[] out = new double[values.length];
        for (int i = 0; i < values.length; ++i) {
            out[i] = function(index, values[i]);
        }
        return out;
    }

    /**
     * Rescales the non-NaN values of column {@code col} in place to {@code min + (d - min) / (max -
     * min)}, i.e. into {@code [min, min + 1]}. Columns with no positive range (constant or
     * all-NaN) are left unchanged instead of being turned into NaN.
     */
    private void scale(DataSet dataSet, int col) {
        int rows = dataSet.getNumRows();
        // Double.MIN_VALUE is the smallest *positive* double, so it is not a valid seed for max.
        double max = Double.NEGATIVE_INFINITY;
        double min = Double.POSITIVE_INFINITY;
        for (int i = 0; i < rows; ++i) {
            double d = dataSet.getDouble(i, col);
            if (Double.isNaN(d)) {
                continue;
            }
            if (d > max) {
                max = d;
            }
            if (d < min) {
                min = d;
            }
        }
        double range = max - min;
        if (!(range > 0.0)) {
            return;
        }
        for (int i = 0; i < rows; ++i) {
            double d = dataSet.getDouble(i, col);
            if (Double.isNaN(d)) {
                continue;
            }
            dataSet.setDouble(i, col, min + (d - min) / range);
        }
    }

    /**
     * Nonparametric Fisher Z for two already-standardized samples of equal length:
     * {@code sqrt(N) * atanh(r) / sqrt(mean(x^2 * y^2))}, with r from {@code
     * StatUtils.covariance}.
     */
    private double nonparametricFisherZ(double[] xs, double[] ys) {
        double r = StatUtils.covariance(xs, ys);
        int n = xs.length;
        double z = 0.5 * FastMath.sqrt(n) * (FastMath.log(1.0 + r) - FastMath.log(1.0 - r));
        return z / FastMath.sqrt(moment22(xs, ys));
    }

    /** @return the mean of {@code x[j]^2 * y[j]^2} */
    private double moment22(double[] x, double[] y) {
        int n = x.length;
        double sum = 0.0;
        for (int j = 0; j < n; ++j) {
            sum += x[j] * x[j] * y[j] * y[j];
        }
        return sum / (double) n;
    }

    /**
     * Evaluates basis function {@code index} at {@code x}.
     *
     * <ul>
     *   <li>Polynomial: {@code x^index}; an infinite result is reported as NaN.
     *   <li>Cosine: odd index gives {@code sin(((index + 1) / 2) * x)}, even index gives {@code
     *       cos((index / 2) * x)} (integer division).
     * </ul>
     */
    private double function(int index, double x) {
        if (this.basis == Basis.Polynomial) {
            double g = 1.0;
            for (int i = 1; i <= index; ++i) {
                g *= x;
            }
            return Double.isInfinite(g) ? Double.NaN : g;
        }
        if (this.basis == Basis.Cosine) {
            int i = (index + 1) / 2;
            if (index % 2 == 1) {
                return FastMath.sin((double) i * x);
            }
            return FastMath.cos((double) i * x);
        }
        throw new IllegalStateException("That basis is not configured: " + this.basis);
    }

    /**
     * Returns a standardized copy of {@code data}: mean subtracted, then divided by the
     * population standard deviation (divisor N). The input array is not modified.
     */
    private double[] standardize(double[] data) {
        double[] out = data.clone();
        double sum = 0.0;
        for (double d : out) {
            sum += d;
        }
        double mean = sum / (double) out.length;
        for (int i = 0; i < out.length; ++i) {
            out[i] = out[i] - mean;
        }
        double var = 0.0;
        for (double d : out) {
            var += d * d;
        }
        var /= (double) out.length;
        double sd = FastMath.sqrt(var);
        for (int i = 0; i < out.length; ++i) {
            out[i] = out[i] / sd;
        }
        return out;
    }

    /** Basis families for transforming residuals. */
    public enum Basis {
        Polynomial,
        Cosine
    }

    /** Kernel choices; not used by the score computation in this class. */
    public enum Kernel {
        Epinechnikov,
        Gaussian
    }
}

