/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.test;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.GraphUtils;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.regression.Regression;
import edu.cmu.tetrad.regression.RegressionDataset;
import edu.cmu.tetrad.regression.RegressionResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import org.apache.commons.math3.util.FastMath;

public final class IndTestFisherZConcatenateResiduals
implements IndependenceTest {
    private final List<Node> variables;
    private final ArrayList<Regression> regressions;
    private List<DataSet> dataSets;
    private double alpha;
    private double pValue = Double.NaN;
    private boolean verbose;

    public IndTestFisherZConcatenateResiduals(List<DataSet> dataSets, double alpha) {
        System.out.println("# data sets = " + dataSets.size());
        this.dataSets = dataSets;
        this.regressions = new ArrayList();
        for (DataSet dataSet : dataSets) {
            BoxDataSet _dataSet = new BoxDataSet(new DoubleDataBox(dataSet.getDoubleData().toArray()), dataSets.get(0).getVariables());
            this.regressions.add(new RegressionDataset(_dataSet));
        }
        this.setAlpha(alpha);
        this.variables = dataSets.get(0).getVariables();
        ArrayList<DataSet> dataSets2 = new ArrayList<DataSet>();
        for (DataSet set : dataSets) {
            BoxDataSet dataSet = new BoxDataSet(new DoubleDataBox(set.getDoubleData().toArray()), this.variables);
            dataSets2.add(dataSet);
        }
        this.dataSets = dataSets2;
    }

    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        throw new UnsupportedOperationException();
    }

    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> _z) {
        boolean independent;
        double fisherZ;
        int i;
        x = this.getVariable(this.variables, x.getName());
        List<Node> z = GraphUtils.replaceNodes(new ArrayList<Node>(_z), new ArrayList<Node>(this.variables));
        double[] residualsX = this.residuals(x, z);
        double[] residualsY = this.residuals(y, z);
        ArrayList<Double> residualsXFiltered = new ArrayList<Double>();
        ArrayList<Double> residualsYFiltered = new ArrayList<Double>();
        for (i = 0; i < residualsX.length; ++i) {
            if (Double.isNaN(residualsX[i]) || Double.isNaN(residualsY[i])) continue;
            residualsXFiltered.add(residualsX[i]);
            residualsYFiltered.add(residualsY[i]);
        }
        residualsX = new double[residualsXFiltered.size()];
        residualsY = new double[residualsYFiltered.size()];
        for (i = 0; i < residualsXFiltered.size(); ++i) {
            residualsX[i] = (Double)residualsXFiltered.get(i);
            residualsY[i] = (Double)residualsYFiltered.get(i);
        }
        if (residualsX.length != residualsY.length) {
            throw new IllegalArgumentException("Missing values handled.");
        }
        int sampleSize = residualsX.length;
        double r = StatUtils.correlation(residualsX, residualsY);
        if (r > 1.0) {
            r = 1.0;
        }
        if (r < -1.0) {
            r = -1.0;
        }
        if (Double.isNaN(fisherZ = FastMath.sqrt((double)(sampleSize - z.size()) - 3.0) * 0.5 * (FastMath.log(1.0 + r) - FastMath.log(1.0 - r)))) {
            return new IndependenceResult(new IndependenceFact(x, y, _z), true, Double.NaN, Double.NaN);
        }
        double pValue = 2.0 * (1.0 - RandomUtil.getInstance().normalCdf(0.0, 1.0, FastMath.abs(fisherZ)));
        if (Double.isNaN(pValue)) {
            throw new RuntimeException("Undefined p-value encountered for test: " + LogUtilsSearch.independenceFact(x, y, _z));
        }
        this.pValue = pValue;
        boolean bl = independent = pValue > this.alpha;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(LogUtilsSearch.independenceFactMsg(x, y, _z, this.pValue));
        }
        return new IndependenceResult(new IndependenceFact(x, y, _z), independent, pValue, pValue - this.getAlpha());
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("Significance out of range.");
        }
        this.alpha = alpha;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    public boolean determines(List<Node> z, Node x) throws UnsupportedOperationException {
        throw new UnsupportedOperationException();
    }

    @Override
    public DataSet getData() {
        return DataTransforms.concatenate(this.dataSets);
    }

    @Override
    public ICovarianceMatrix getCov() {
        ArrayList<DataSet> _dataSets = new ArrayList<DataSet>();
        for (DataSet d : this.dataSets) {
            _dataSets.add(DataTransforms.standardizeData(d));
        }
        return new CovarianceMatrix(DataTransforms.concatenate(_dataSets));
    }

    @Override
    public String toString() {
        return "Fisher Z, Concatenating Residuals";
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    private double[] residuals(Node node, List<Node> parents) {
        ArrayList<Double> _residuals = new ArrayList<Double>();
        Node target = this.dataSets.get(0).getVariable(node.getName());
        ArrayList<Node> regressors = new ArrayList<Node>();
        for (Node _regressor : parents) {
            Node variable = this.dataSets.get(0).getVariable(_regressor.getName());
            regressors.add(variable);
        }
        for (int m = 0; m < this.dataSets.size(); ++m) {
            RegressionResult result = this.regressions.get(m).regress(target, regressors);
            double[] residualsSingleDataset = result.getResiduals().toArray();
            double mean = StatUtils.mean(residualsSingleDataset);
            for (int i2 = 0; i2 < residualsSingleDataset.length; ++i2) {
                residualsSingleDataset[i2] = residualsSingleDataset[i2] - mean;
            }
            for (double d : residualsSingleDataset) {
                _residuals.add(d);
            }
        }
        double[] _f = new double[_residuals.size()];
        for (int k = 0; k < _residuals.size(); ++k) {
            _f[k] = (Double)_residuals.get(k);
        }
        return _f;
    }

    private Node getVariable(List<Node> variables, String name) {
        for (Node node : variables) {
            if (!name.equals(node.getName())) continue;
            return node;
        }
        return null;
    }
}

