/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Boss;
import edu.cmu.tetrad.search.IndependenceResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.search.SemBicScore;
import edu.cmu.tetrad.search.TeyssierScorer;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.MatrixUtils;
import edu.cmu.tetrad.util.StatUtils;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

/**
 * Conditional-independence "test" that decides X _||_ Y | Z by search rather than by a
 * classical statistic.
 *
 * <p>It runs the {@link Boss} search (through a {@link TeyssierScorer} over a
 * {@link SemBicScore}) on the variables Z + {X, Y}. It then asks the scorer whether X and Y
 * are adjacent; X and Y are reported independent exactly when they are not adjacent.
 * A Fisher Z p-value based on partial correlation is also available through
 * {@link #getPValue(Node, Node, List)}.
 */
public final class IndTestTeyssier implements IndependenceTest {
    /** Variable -> column index in {@link #cor}; rebuilt by {@link #setVariables(List)}. */
    private final Map<Node, Integer> indexMap;
    /** Variable name -> variable; rebuilt by {@link #setVariables(List)}. */
    private final Map<String, Node> nameMap;
    /** Standard normal used to turn the Fisher Z statistic into a p-value. */
    private final NormalDistribution normal = new NormalDistribution(0.0, 1.0, 1.0E-15);
    /** Variable -> column index used to look up raw data values in {@link #getRows}. */
    private final Map<Node, Integer> nodesHash;
    /** Correlation matrix of the variables (always non-null after construction). */
    private final ICovarianceMatrix cor;
    /** Penalty discount handed to the {@link SemBicScore}; reused for subset tests. */
    private final double penaltyDiscount;
    private TeyssierScorer scorer;
    private List<Node> variables;
    private double alpha;
    /** Source data; null when the test was built from a covariance matrix. */
    private DataSet dataSet;
    private boolean verbose = true;
    /** p-value recorded by the most recent check / p-value computation. */
    private double p = Double.NaN;
    /** Partial correlation recorded by the most recent {@link #getPValue(Node, Node, List)}. */
    private double r = Double.NaN;
    private Boss boss;

    /**
     * Builds the test from a continuous data set.
     *
     * @param dataSet         continuous data; a correlation matrix is computed from it.
     * @param penaltyDiscount penalty discount for the underlying {@link SemBicScore}.
     * @throws IllegalArgumentException if the data set is not continuous.
     */
    public IndTestTeyssier(DataSet dataSet, double penaltyDiscount) {
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("Data set must be continuous.");
        }
        this.dataSet = dataSet;
        this.penaltyDiscount = penaltyDiscount;
        this.cor = new CorrelationMatrix(dataSet);
        // With missing values the variable list is taken (read-only) from the data set;
        // otherwise from the correlation matrix. This mirrors the original two branches.
        // NOTE(review): a CorrelationMatrix is built even when values are missing - confirm
        // how CorrelationMatrix treats NaNs.
        this.variables = dataSet.existsMissingValue()
                ? Collections.unmodifiableList(dataSet.getVariables())
                : this.cor.getVariables();
        this.indexMap = this.indexMap(this.variables);
        this.nameMap = this.nameMap(this.variables);
        this.nodesHash = new HashMap<>(this.indexMap);
        this.initSearch(new SemBicScore(dataSet));
    }

    /**
     * Builds the test from a covariance matrix. No raw data set is retained.
     *
     * @param covMatrix       covariance matrix; it is converted to a correlation matrix.
     * @param penaltyDiscount penalty discount for the underlying {@link SemBicScore}.
     */
    public IndTestTeyssier(ICovarianceMatrix covMatrix, double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
        this.cor = new CorrelationMatrix(covMatrix);
        this.variables = covMatrix.getVariables();
        this.indexMap = this.indexMap(this.variables);
        this.nameMap = this.nameMap(this.variables);
        this.nodesHash = new HashMap<>(this.indexMap);
        this.initSearch(new SemBicScore(covMatrix));
    }

    /**
     * Applies the penalty discount to the score, then creates the scorer and the BOSS search.
     */
    private void initSearch(SemBicScore score) {
        score.setPenaltyDiscount(this.penaltyDiscount);
        this.scorer = new TeyssierScorer(null, score);
        this.scorer.score(this.variables);
        this.boss = new Boss(this.scorer);
    }

    /**
     * Returns a test restricted to {@code vars}, using the same penalty discount, alpha and
     * verbosity as this test.
     *
     * @throws IllegalArgumentException if {@code vars} is empty or contains a variable not in
     *                                  this test.
     */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        if (vars.isEmpty()) {
            throw new IllegalArgumentException("Subset may not be empty.");
        }
        int[] indices = new int[vars.size()];
        for (int i = 0; i < indices.length; ++i) {
            Node var = vars.get(i);
            if (!this.variables.contains(var)) {
                throw new IllegalArgumentException("All vars must be original vars");
            }
            indices[i] = this.indexMap.get(var);
        }
        ICovarianceMatrix newCovMatrix = this.cor.getSubmatrix(indices);
        // The second constructor argument is the penalty discount, not alpha.
        IndTestTeyssier subset = new IndTestTeyssier(newCovMatrix, this.penaltyDiscount);
        subset.setAlpha(this.getAlpha());
        subset.setVerbose(this.verbose);
        return subset;
    }

    /**
     * Decides x _||_ y | z by searching over z + {x, y} and checking adjacency of x and y.
     *
     * <p>The reported p-value is a pseudo value: 1.0 when independent, 0.0 when dependent.
     * It is also recorded for {@link #getPValue()} and {@link #getScore()}.
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, List<Node> z) {
        List<Node> perm = new ArrayList<>(z);
        perm.add(x);
        perm.add(y);
        // Assumes bestOrder leaves the scorer in the best order it found - confirm in Boss.
        this.boss.bestOrder(perm);
        // Adjacent means dependent, so independence is the negation.
        boolean independent = !this.scorer.adjacent(x, y);
        double pValue = independent ? 1.0 : 0.0;
        this.p = pValue;
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, pValue);
    }

    /** Returns the p-value recorded by the most recent check or p-value computation (NaN if none). */
    public double getPValue() {
        return this.p;
    }

    /**
     * Computes the two-sided Fisher Z p-value for the partial correlation of x and y given z.
     * The p-value and the partial correlation are also recorded on this instance.
     *
     * @throws SingularMatrixException if the partial correlation cannot be computed.
     */
    public double getPValue(Node x, Node y, List<Node> z) throws SingularMatrixException {
        double p;
        int n;
        double r;
        if (this.covMatrix() != null) {
            r = this.partialCorrelation(x, y, z, null);
            n = this.sampleSize();
        } else {
            // Row-wise deletion path. It is unreachable while cor is always set; kept for
            // completeness.
            ArrayList<Node> allVars = new ArrayList<Node>(z);
            allVars.add(x);
            allVars.add(y);
            List<Integer> rows = this.getRows(allVars, this.nodesHash);
            r = this.getR(x, y, z, rows);
            n = rows.size();
        }
        this.r = r;
        // Fisher z-transform of |r|, i.e. atanh(|r|).
        double q = 0.5 * (StrictMath.log(1.0 + FastMath.abs(r)) - StrictMath.log(1.0 - FastMath.abs(r)));
        double fisherZ = FastMath.sqrt((double) n - 3.0 - (double) z.size()) * q;
        this.p = p = 2.0 * (1.0 - this.normal.cumulativeProbability(fisherZ));
        return p;
    }

    /** Partial correlation of x and y given z, from {@link #cor} or (if absent) from the given rows. */
    private double partialCorrelation(Node x, Node y, List<Node> z, List<Integer> rows) throws SingularMatrixException {
        Matrix cor;
        int[] indices = new int[z.size() + 2];
        indices[0] = this.indexMap.get(x);
        indices[1] = this.indexMap.get(y);
        for (int i = 0; i < z.size(); ++i) {
            indices[i + 2] = this.indexMap.get(z.get(i));
        }
        if (this.cor != null) {
            cor = this.cor.getSelection(indices, indices);
        } else {
            Matrix cov = this.getCov(rows, indices);
            cor = MatrixUtils.convertCovToCorr(cov);
        }
        return StatUtils.partialCorrelationPrecisionMatrix(cor);
    }

    /**
     * Computes the covariance of the given data columns over the given rows.
     *
     * <p>Each mean is the column sum divided by the row count. The covariance divides by the
     * row count as well; the scale cancels once it is converted to a correlation.
     */
    private Matrix getCov(List<Integer> rows, int[] cols) {
        int n = rows.size();
        double[] means = new double[cols.length];
        for (int i = 0; i < cols.length; ++i) {
            double sum = 0.0;
            for (int k : rows) {
                sum += this.dataSet.getDouble(k, cols[i]);
            }
            means[i] = sum / n;
        }
        Matrix cov = new Matrix(cols.length, cols.length);
        for (int i = 0; i < cols.length; ++i) {
            for (int j = i; j < cols.length; ++j) {
                double sum = 0.0;
                for (int k : rows) {
                    sum += (this.dataSet.getDouble(k, cols[i]) - means[i])
                            * (this.dataSet.getDouble(k, cols[j]) - means[j]);
                }
                double c = sum / n;
                cov.set(i, j, c);
                cov.set(j, i, c);
            }
        }
        return cov;
    }

    private double getR(Node x, Node y, List<Node> z, List<Integer> rows) {
        return this.partialCorrelation(x, y, z, rows);
    }

    /** BIC-style statistic from the partial correlation recorded by the last {@link #getPValue(Node, Node, List)}. */
    public double getBic() {
        return (double) (-this.sampleSize()) * FastMath.log(1.0 - this.r * this.r) - FastMath.log(this.sampleSize());
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * @throws IllegalArgumentException if alpha is outside [0, 1].
     */
    @Override
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("Significance out of range: " + alpha);
        }
        this.alpha = alpha;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Replaces the variables, which must match the current ones in count and order, and
     * rebuilds the name/index lookups so they are keyed by the new nodes.
     *
     * @throws IllegalArgumentException if the number of variables differs.
     */
    public void setVariables(List<Node> variables) {
        if (variables.size() != this.variables.size()) {
            throw new IllegalArgumentException("Wrong # of variables.");
        }
        this.cor.setVariables(variables);
        this.variables = new ArrayList<>(variables);
        this.indexMap.clear();
        this.indexMap.putAll(this.indexMap(this.variables));
        this.nameMap.clear();
        this.nameMap.putAll(this.nameMap(this.variables));
        this.nodesHash.clear();
        this.nodesHash.putAll(this.indexMap);
    }

    @Override
    public Node getVariable(String name) {
        return this.nameMap.get(name);
    }

    @Override
    public List<String> getVariableNames() {
        List<String> variableNames = new ArrayList<>(this.variables.size());
        for (Node variable : this.getVariables()) {
            variableNames.add(variable.getName());
        }
        return variableNames;
    }

    /**
     * Reports determinism when the correlation submatrix of z is singular.
     * Note that x is not used by this check.
     */
    @Override
    public boolean determines(List<Node> z, Node x) throws UnsupportedOperationException {
        int[] parents = new int[z.size()];
        for (int j = 0; j < parents.length; ++j) {
            parents[j] = this.cor.getVariables().indexOf(z.get(j));
        }
        if (parents.length > 0) {
            Matrix Czz = this.cor.getSelection(parents, parents);
            try {
                Czz.inverse();
            } catch (SingularMatrixException e) {
                if (this.verbose) {
                    System.out.println(SearchLogUtils.determinismDetected(z, x));
                }
                return true;
            }
        }
        return false;
    }

    /** Returns the source data set, or null if built from a covariance matrix. */
    @Override
    public DataSet getData() {
        return this.dataSet;
    }

    @Override
    public String toString() {
        return "Teyssier, alpha = " + new DecimalFormat("0.0###").format(this.getAlpha());
    }

    private int sampleSize() {
        return this.covMatrix().getSampleSize();
    }

    private ICovarianceMatrix covMatrix() {
        return this.cor;
    }

    /** Builds a thread-safe name -> node lookup. */
    private Map<String, Node> nameMap(List<Node> variables) {
        Map<String, Node> nameMap = new ConcurrentHashMap<>();
        for (Node node : variables) {
            nameMap.put(node.getName(), node);
        }
        return nameMap;
    }

    /** Builds a thread-safe node -> position lookup. */
    private Map<Node, Integer> indexMap(List<Node> variables) {
        Map<Node, Integer> indexMap = new ConcurrentHashMap<>();
        for (int i = 0; i < variables.size(); ++i) {
            indexMap.put(variables.get(i), i);
        }
        return indexMap;
    }

    @Override
    public ICovarianceMatrix getCov() {
        return this.cor;
    }

    /** Returns a list holding the source data set (the element is null if built from a covariance matrix). */
    @Override
    public List<DataSet> getDataSets() {
        List<DataSet> dataSets = new ArrayList<>();
        dataSets.add(this.dataSet);
        return dataSets;
    }

    @Override
    public int getSampleSize() {
        return this.cor.getSampleSize();
    }

    /** Not supported by this test; returns null. */
    @Override
    public List<Matrix> getCovMatrices() {
        return null;
    }

    /** Returns alpha minus the last recorded p-value; positive values indicate dependence. */
    @Override
    public double getScore() {
        return this.alpha - this.p;
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Returns the indices of data rows with no missing (NaN) value in any of {@code allVars}. */
    private List<Integer> getRows(List<Node> allVars, Map<Node, Integer> nodesHash) {
        List<Integer> rows = new ArrayList<>();
        for (int k = 0; k < this.dataSet.getNumRows(); ++k) {
            boolean complete = true;
            for (Node node : allVars) {
                if (Double.isNaN(this.dataSet.getDouble(k, nodesHash.get(node)))) {
                    complete = false;
                    break;
                }
            }
            if (complete) {
                rows.add(k);
            }
        }
        return rows;
    }
}

