/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.test;

import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.data.VerticalDoubleDataBox;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.test.RowsSettable;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

public final class IndTestFisherZ
implements IndependenceTest,
RowsSettable {
    // Variable name -> position of that variable in the variable list; used to index into the
    // correlation matrix and the data columns.
    private final Map<String, Integer> indexMap;
    // Variable name -> the Node carrying that name.
    private final Map<String, Node> nameMap;
    // Normal distribution with mean 0 and standard deviation 1, used to turn the Fisher Z
    // statistic into a p-value. The third argument is the inverse cumulative accuracy.
    private final NormalDistribution normal = new NormalDistribution(0.0, 1.0, 1.0E-15);
    // Node -> position in the variable list; used as the data column when scanning for
    // missing values.
    private final Map<Node, Integer> nodesHash;
    // Sample size fixed at construction (number of data rows, or the covariance matrix's
    // reported sample size).
    private final int sampleSize;
    // Precomputed correlation matrix. Null when the data contain missing values or when an
    // explicit row subset has been set; tests then compute correlations from the raw data.
    private ICovarianceMatrix cor = null;
    // The variables this test knows about.
    private List<Node> variables;
    // Significance level; a test reports independence when its p-value exceeds this.
    private double alpha;
    // The underlying data. Left null when the test is constructed from a covariance matrix.
    private DataSet dataSet;
    // When true, independence findings are written to the Tetrad log.
    private boolean verbose = true;
    // Partial correlation from the most recent p-value computation; NaN until a test has run.
    private double r = Double.NaN;
    // Explicit subset of data rows to use, or null for no explicit subset.
    private List<Integer> rows = null;

    /**
     * Builds a Fisher Z test from a continuous data set.
     *
     * <p>When the data contain no missing values, a correlation matrix is computed once here and
     * its variable list is used. When values are missing, no matrix is precomputed (the
     * correlation field stays {@code null}) and the data set's own variable list is used, wrapped
     * as unmodifiable.
     *
     * @param dataSet the data to test; must be continuous. A copy is stored.
     * @param alpha   the significance level, in [0, 1].
     * @throws IllegalArgumentException if the data are not continuous or alpha is out of range.
     */
    public IndTestFisherZ(DataSet dataSet, double alpha) {
        // Validate before copying so that rejected input does not pay for a copy of the data.
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("Data set must be continuous.");
        }
        this.dataSet = dataSet.copy();

        if (dataSet.existsMissingValue()) {
            // Written with negated comparisons so that NaN is rejected as well.
            if (!(alpha >= 0.0) || !(alpha <= 1.0)) {
                throw new IllegalArgumentException("Alpha must be in [0, 1]");
            }
            this.variables = Collections.unmodifiableList(dataSet.getVariables());
        } else {
            this.cor = new CorrelationMatrix(dataSet);
            this.variables = this.cor.getVariables();
        }

        // The lookup tables are built the same way in both modes.
        this.indexMap = this.indexMap(this.variables);
        this.nameMap = this.nameMap(this.variables);
        this.setAlpha(alpha);
        HashMap<Node, Integer> nodesHash = new HashMap<Node, Integer>();
        for (int j = 0; j < this.variables.size(); ++j) {
            nodesHash.put(this.variables.get(j), j);
        }
        this.nodesHash = nodesHash;
        this.sampleSize = dataSet.getNumRows();
    }

    /**
     * Builds a Fisher Z test from a raw data matrix and its variables.
     *
     * <p>The matrix is transposed and wrapped in a {@link BoxDataSet} backed by a
     * {@link VerticalDoubleDataBox}; the correlation matrix is then loaded from that data set.
     *
     * @param data      the data matrix.
     * @param variables the variables for the data; stored as an unmodifiable view.
     * @param alpha     the significance level, in [0, 1].
     */
    public IndTestFisherZ(Matrix data, List<Node> variables, double alpha) {
        VerticalDoubleDataBox box = new VerticalDoubleDataBox(data.transpose().toArray());
        this.dataSet = new BoxDataSet(box, variables);
        this.cor = SimpleDataLoader.getCorrelationMatrix(this.dataSet);
        this.variables = Collections.unmodifiableList(variables);
        this.indexMap = this.indexMap(variables);
        this.nameMap = this.nameMap(variables);
        this.setAlpha(alpha);

        // Record each variable's position in the list.
        HashMap<Node, Integer> positions = new HashMap<Node, Integer>();
        int position = 0;
        for (Node variable : variables) {
            positions.put(variable, position);
            ++position;
        }
        this.nodesHash = positions;
        this.sampleSize = this.dataSet.getNumRows();
    }

    /**
     * Builds a Fisher Z test from a covariance matrix.
     *
     * <p>The matrix is converted to a correlation matrix. No data set is stored in this mode, and
     * the sample size is taken from the correlation matrix.
     *
     * @param covMatrix the covariance matrix supplying variables and sample size.
     * @param alpha     the significance level, in [0, 1].
     */
    public IndTestFisherZ(ICovarianceMatrix covMatrix, double alpha) {
        this.cor = new CorrelationMatrix(covMatrix);
        this.variables = covMatrix.getVariables();
        this.indexMap = this.indexMap(this.variables);
        this.nameMap = this.nameMap(this.variables);
        this.setAlpha(alpha);

        // Record each variable's position in the list.
        HashMap<Node, Integer> positions = new HashMap<Node, Integer>();
        int position = 0;
        for (Node variable : this.variables) {
            positions.put(variable, position);
            ++position;
        }
        this.nodesHash = positions;
        this.sampleSize = this.cor.getSampleSize();
    }

    /**
     * Returns a new Fisher Z test restricted to the given variables, using the matching submatrix
     * of this test's correlation matrix and the same alpha.
     *
     * @param vars the variables to keep; must be non-empty and all belong to this test.
     * @return a new test over the selected variables.
     * @throws IllegalArgumentException if {@code vars} is empty or contains an unknown variable.
     */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        if (vars.isEmpty()) {
            throw new IllegalArgumentException("Subset may not be empty.");
        }
        for (Node candidate : vars) {
            if (!this.variables.contains(candidate)) {
                throw new IllegalArgumentException("All vars must be original vars");
            }
        }

        // Translate each requested variable into its position in the correlation matrix.
        int[] positions = new int[vars.size()];
        int slot = 0;
        for (Node candidate : vars) {
            positions[slot] = this.indexMap.get(candidate.getName());
            ++slot;
        }
        return new IndTestFisherZ(this.cor.getSubmatrix(positions), this.getAlpha());
    }

    /**
     * Tests whether x is independent of y given z, judging independence by whether the p-value
     * exceeds alpha.
     *
     * @param x the first variable.
     * @param y the second variable.
     * @param z the conditioning set.
     * @return the result, carrying the fact, the decision, the p-value, and {@code alpha - p} as
     *         the score.
     * @throws RuntimeException if a singular matrix is encountered (the original exception is
     *                          attached as the cause) or if the p-value is NaN.
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> z) {
        double p;
        try {
            p = this.getPValue(x, y, z);
        } catch (SingularMatrixException e) {
            // Keep the original exception as the cause so the stack trace shows where it arose.
            throw new RuntimeException(
                    "Singular matrix encountered for test: " + LogUtilsSearch.independenceFact(x, y, z), e);
        }

        // Reject an undefined p-value before making or logging any decision based on it.
        if (Double.isNaN(p)) {
            throw new RuntimeException("Undefined p-value encountered in for test: " + LogUtilsSearch.independenceFact(x, y, z));
        }

        boolean independent = p > this.alpha;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(LogUtilsSearch.independenceFactMsg(x, y, z, p));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, p, this.alpha - p);
    }

    /**
     * Computes the two-sided p-value of the Fisher Z test for x and y given z.
     *
     * <p>With a correlation matrix available, the partial correlation comes from it and the stored
     * sample size is used. Otherwise the rows to use are obtained for x, y and z, and the partial
     * correlation and sample size come from those rows. The partial correlation is also stored in
     * the {@code r} field.
     *
     * @param x the first variable.
     * @param y the second variable.
     * @param z the conditioning set.
     * @return the p-value.
     * @throws SingularMatrixException if a matrix inversion fails.
     */
    private double getPValue(Node x, Node y, Set<Node> z) throws SingularMatrixException {
        final double correlation;
        final int n;
        if (this.covMatrix() == null) {
            List<Node> involved = new ArrayList<Node>(z);
            involved.add(x);
            involved.add(y);
            List<Integer> usableRows = this.getRows(involved, this.nodesHash);
            correlation = this.getR(x, y, z, usableRows);
            n = usableRows.size();
        } else {
            correlation = this.partialCorrelation(x, y, z, null);
            n = this.sampleSize();
        }
        this.r = correlation;

        // Fisher transform of |r|, scaled by sqrt(n - 3 - |z|).
        double magnitude = FastMath.abs(correlation);
        double q = 0.5 * (StrictMath.log(1.0 + magnitude) - StrictMath.log(1.0 - magnitude));
        double degrees = (double) n - 3.0 - (double) z.size();
        double fisherZ = FastMath.sqrt(degrees) * q;
        return 2.0 * (1.0 - this.normal.cumulativeProbability(fisherZ));
    }

    /**
     * Returns {@code -n * ln(1 - r^2) - ln(n)}, where n is the sample size and r is the partial
     * correlation stored by the most recent p-value computation. The result is NaN if no test has
     * been run yet, since r starts out as NaN.
     *
     * @return the score for the most recent test.
     */
    public double getBic() {
        int n = this.sampleSize();
        double rSquared = this.r * this.r;
        return (double) (-n) * FastMath.log(1.0 - rSquared) - FastMath.log(n);
    }

    /**
     * Returns the significance level currently in use.
     *
     * @return alpha.
     */
    @Override
    public double getAlpha() {
        return alpha;
    }

    /**
     * Sets the significance level.
     *
     * @param alpha the new significance level; must be a number in [0, 1].
     * @throws IllegalArgumentException if alpha is NaN or lies outside [0, 1].
     */
    @Override
    public void setAlpha(double alpha) {
        // Phrased as a negated range test so NaN, which fails every comparison, is rejected too.
        // An accepted NaN would make "p > alpha" false for every test.
        if (!(alpha >= 0.0 && alpha <= 1.0)) {
            throw new IllegalArgumentException("Significance out of range: " + alpha);
        }
        this.alpha = alpha;
    }

    /**
     * Returns the variables this test operates on.
     *
     * @return the stored variable list itself, not a copy.
     */
    @Override
    public List<Node> getVariables() {
        return variables;
    }

    /**
     * Replaces the variable list with a copy of the given list, and passes the given list on to
     * the correlation matrix when one exists.
     *
     * <p>The name, index and node lookup tables built at construction are not rebuilt here.
     *
     * @param variables the replacement variables; must have the same count as the current list.
     * @throws IllegalArgumentException if the number of variables differs.
     */
    public void setVariables(List<Node> variables) {
        if (variables.size() != this.variables.size()) {
            throw new IllegalArgumentException("Wrong # of variables.");
        }
        this.variables = new ArrayList<Node>(variables);
        // The matrix is null when the data have missing values or a row subset is set; there is
        // then nothing to update, and dereferencing it would throw a NullPointerException.
        if (this.cor != null) {
            this.cor.setVariables(variables);
        }
    }

    /**
     * Looks up a variable by name.
     *
     * @param name the variable name.
     * @return the variable registered under that name at construction, or null if there is none.
     */
    @Override
    public Node getVariable(String name) {
        Node match = nameMap.get(name);
        return match;
    }

    /**
     * Returns the data set held by this test.
     *
     * @return the data set, or null when the test was built from a covariance matrix.
     */
    @Override
    public DataSet getData() {
        return dataSet;
    }

    /**
     * Returns the correlation matrix held by this test.
     *
     * @return the matrix, or null when none is currently held.
     */
    @Override
    public ICovarianceMatrix getCov() {
        return cor;
    }

    /**
     * Returns a new single-element list containing this test's data set.
     *
     * @return a fresh mutable list holding the data set.
     */
    @Override
    public List<DataSet> getDataSets() {
        List<DataSet> single = new ArrayList<DataSet>();
        single.add(dataSet);
        return single;
    }

    /**
     * Returns the sample size fixed at construction.
     *
     * @return the sample size.
     */
    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * Reports whether independence findings are logged.
     *
     * @return true if verbose logging is on.
     */
    @Override
    public boolean isVerbose() {
        return verbose;
    }

    /**
     * Turns logging of independence findings on or off.
     *
     * @param verbose true to log independence findings.
     */
    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * Returns a short description of the test including alpha formatted with the pattern
     * {@code 0.0###}.
     *
     * @return the description.
     */
    @Override
    public String toString() {
        DecimalFormat alphaFormat = new DecimalFormat("0.0###");
        return "Fisher Z, alpha = " + alphaFormat.format(getAlpha());
    }

    /**
     * Reports determinism for z and x by checking whether the correlation submatrix over z can be
     * inverted. A singular submatrix is reported on standard output and yields true.
     *
     * @param z the candidate determining variables.
     * @param x the variable named in the printed message; it is not used in the computation.
     * @return true if inverting the submatrix over z throws a {@link SingularMatrixException};
     *         false otherwise, including when z is empty.
     */
    public boolean determines(List<Node> z, Node x) throws UnsupportedOperationException {
        int[] parents = new int[z.size()];
        int slot = 0;
        for (Node parent : z) {
            parents[slot] = this.indexMap.get(parent.getName());
            ++slot;
        }
        if (parents.length == 0) {
            return false;
        }

        Matrix Czz = this.cor.getSelection(parents, parents);
        try {
            // Only whether inversion succeeds matters; the inverse itself is discarded.
            Czz.inverse();
        } catch (SingularMatrixException e) {
            System.out.println(LogUtilsSearch.determinismDetected(new HashSet<Node>(z), x));
            return true;
        }
        return false;
    }

    /**
     * Computes the partial correlation of x and y given the conditioning set.
     *
     * <p>The correlation submatrix is ordered x, y, then the conditioning variables in sorted
     * order. It is taken from the stored correlation matrix when one exists; otherwise a
     * covariance matrix is computed from the given rows and converted to correlations.
     *
     * @param x    the first variable.
     * @param y    the second variable.
     * @param _z   the conditioning set.
     * @param rows the data rows to use when no correlation matrix is stored; unused otherwise.
     * @return the partial correlation.
     * @throws SingularMatrixException if a matrix inversion fails.
     */
    private double partialCorrelation(Node x, Node y, Set<Node> _z, List<Integer> rows) throws SingularMatrixException {
        List<Node> conditioning = new ArrayList<Node>(_z);
        Collections.sort(conditioning);

        int[] indices = new int[conditioning.size() + 2];
        indices[0] = this.indexMap.get(x.getName());
        indices[1] = this.indexMap.get(y.getName());
        int slot = 2;
        for (Node conditioner : conditioning) {
            indices[slot] = this.indexMap.get(conditioner.getName());
            ++slot;
        }

        Matrix submatrix;
        if (this.cor == null) {
            submatrix = MatrixUtils.convertCovToCorr(this.getCov(rows, indices));
        } else {
            submatrix = this.cor.getSelection(indices, indices);
        }
        return StatUtils.partialCorrelationPrecisionMatrix(submatrix);
    }

    /**
     * Computes the covariance matrix of the given data columns over the given rows.
     *
     * <p>Each column mean is the sum over the rows divided by the number of rows. Each covariance
     * entry is the sum of products of deviations from those means, divided by the number of rows.
     *
     * @param rows the data row indices to include.
     * @param cols the data column indices; entry (i, j) of the result is for cols[i] and cols[j].
     * @return a cols.length by cols.length covariance matrix.
     */
    private Matrix getCov(List<Integer> rows, int[] cols) {
        int size = cols.length;
        int n = rows.size();

        // Compute each column mean once. The mean divides by n; dividing by n - 1 would shift
        // every deviation and bias the covariances whenever the true means are non-zero.
        double[] means = new double[size];
        for (int i = 0; i < size; ++i) {
            double sum = 0.0;
            for (int k : rows) {
                sum += this.dataSet.getDouble(k, cols[i]);
            }
            means[i] = sum / (double) n;
        }

        Matrix cov = new Matrix(size, size);
        for (int i = 0; i < size; ++i) {
            // The matrix is symmetric, so compute the upper triangle and mirror it.
            for (int j = i; j < size; ++j) {
                double products = 0.0;
                for (int k : rows) {
                    products += (this.dataSet.getDouble(k, cols[i]) - means[i])
                            * (this.dataSet.getDouble(k, cols[j]) - means[j]);
                }
                double value = products / (double) n;
                cov.set(i, j, value);
                cov.set(j, i, value);
            }
        }
        return cov;
    }

    /**
     * Returns the partial correlation of x and y given z over the given rows; delegates to
     * {@code partialCorrelation}.
     */
    private double getR(Node x, Node y, Set<Node> z, List<Integer> rows) {
        double partial = partialCorrelation(x, y, z, rows);
        return partial;
    }

    /** Returns the sample size fixed at construction. */
    private int sampleSize() {
        return sampleSize;
    }

    /** Returns the stored correlation matrix, or null when none is currently held. */
    private ICovarianceMatrix covMatrix() {
        return cor;
    }

    /**
     * Builds a lookup from variable name to variable. If two variables share a name, the later one
     * in the list replaces the earlier one.
     *
     * @param variables the variables to index.
     * @return a new concurrent map keyed by name.
     */
    private Map<String, Node> nameMap(List<Node> variables) {
        Map<String, Node> byName = new ConcurrentHashMap<String, Node>();
        for (int i = 0; i < variables.size(); ++i) {
            Node variable = variables.get(i);
            byName.put(variable.getName(), variable);
        }
        return byName;
    }

    /**
     * Builds a lookup from variable name to the variable's position in the given list. If two
     * variables share a name, the later position replaces the earlier one.
     *
     * @param variables the variables to index.
     * @return a new map keyed by name.
     */
    private Map<String, Integer> indexMap(List<Node> variables) {
        Map<String, Integer> positions = new HashMap<String, Integer>();
        int position = 0;
        for (Node variable : variables) {
            positions.put(variable.getName(), position);
            ++position;
        }
        return positions;
    }

    /**
     * Returns the data rows to use for a test over the given variables.
     *
     * <p>If an explicit row subset has been set, that list is returned as is. Otherwise every row
     * of the data set is scanned and kept only if none of the given variables is NaN in it.
     *
     * @param allVars   the variables involved in the test.
     * @param nodesHash lookup from variable to its data column.
     * @return the explicit row subset, or a new list of the rows complete for {@code allVars}.
     */
    private List<Integer> getRows(List<Node> allVars, Map<Node, Integer> nodesHash) {
        if (this.rows != null) {
            return this.rows;
        }
        List<Integer> complete = new ArrayList<Integer>();
        for (int row = 0; row < this.dataSet.getNumRows(); ++row) {
            boolean hasMissing = false;
            for (Node node : allVars) {
                if (Double.isNaN(this.dataSet.getDouble(row, nodesHash.get(node)))) {
                    hasMissing = true;
                    break;
                }
            }
            if (!hasMissing) {
                complete.add(row);
            }
        }
        return complete;
    }

    /**
     * Returns the explicit row subset, if one has been set.
     *
     * @return the row list, or null when no subset is set.
     */
    @Override
    public List<Integer> getRows() {
        return rows;
    }

    /**
     * Sets or clears the explicit subset of data rows used by the test.
     *
     * <p>With a non-null list, every index is validated, the list is stored as is (not copied),
     * and the precomputed correlation matrix is dropped so tests compute from the selected rows.
     * With null, the subset is cleared and the correlation matrix is rebuilt from the data set,
     * unless the data contain missing values, in which case no matrix is held, as at
     * construction.
     *
     * @param rows the row indices to use, each in [0, sample size), or null to clear the subset.
     * @throws IllegalArgumentException if any index is negative or not less than the sample size.
     */
    @Override
    public void setRows(List<Integer> rows) {
        if (rows == null) {
            // Mirror the DataSet constructor: a matrix is only precomputed for complete data.
            this.cor = this.dataSet.existsMissingValue() ? null : new CorrelationMatrix(this.dataSet);
            this.rows = null;
            return;
        }
        for (int i = 0; i < rows.size(); ++i) {
            int row = rows.get(i);
            // Valid indices stop at sampleSize - 1, so sampleSize itself must be rejected.
            if (row < 0 || row >= this.sampleSize()) {
                throw new IllegalArgumentException("Row index = " + i + "=" + row + " is out of bounds.");
            }
        }
        this.rows = rows;
        this.cor = null;
    }
}

