/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.search.SemBicScore;
import edu.cmu.tetrad.util.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Local score for linear structure search in which every parent of a node costs a constant penalty
 * {@code lambda}:
 *
 * <pre>
 *   score(i | parents) = -(0.5 * N * ln(varRy) + lambda(m0, pn) * |parents|)
 * </pre>
 *
 * <p>Here {@code N} is the sample size and {@code pn} is the number of variables minus one.
 * {@code m0} is the parent count of the highest-scoring parent set seen so far for node {@code i}.
 * {@code lambda(m0, pn)} is obtained from {@link #zhangShenLambda(int, double, double)} using the
 * configured risk bound and is cached per {@code m0}.
 *
 * <p>The score is stateful: each call to {@link #localScore(int, int...)} may update the per-node
 * running maximum that fixes {@code m0} for later calls, so results can depend on the order in which
 * parent sets are scored. The class contains no synchronization.
 */
public class ZhangShenBoundScore
implements Score {
    private final List<Node> variables;
    // Risk bound handed to zhangShenLambda when penalties are (re)computed; see setRiskBound.
    private double riskBound = 0.001;
    // Per-node best local score seen so far; allocated lazily by localScore.
    double[] maxScores;
    // Per-node parent count of the best-scoring parent set seen so far; used as m0 for the penalty.
    int[] estMaxParents;
    // Per-node varRy of the best-scoring parent set seen so far.
    double[] estMaxVarRys;
    private ICovarianceMatrix covariances;
    private int sampleSize;
    private boolean verbose = false;
    // lambdas.get(t) == zhangShenLambda(t, pn, riskBound) for the current riskBound; null until first use.
    private List<Double> lambdas;
    // Raw data matrix when built from a DataSet; stays null when built from a covariance matrix.
    private Matrix data;
    private boolean changed = false;

    /**
     * Creates the score from a covariance matrix.
     *
     * @param covariances the covariance matrix; its variables and sample size are adopted.
     * @throws NullPointerException     if {@code covariances} is null.
     * @throws IllegalArgumentException if some off-diagonal correlation exceeds 1 in absolute value.
     */
    public ZhangShenBoundScore(ICovarianceMatrix covariances) {
        if (covariances == null) {
            throw new NullPointerException();
        }
        this.setCovariances(covariances);
        this.variables = covariances.getVariables();
        this.sampleSize = covariances.getSampleSize();
    }

    /**
     * Creates the score from a data set. The covariance matrix is built with
     * {@code SimpleDataLoader.getCovarianceMatrix}. The raw double data is also kept and passed on to
     * {@code SemBicScore.getVarRy}.
     *
     * @param dataSet the data set.
     */
    public ZhangShenBoundScore(DataSet dataSet) {
        this(SimpleDataLoader.getCovarianceMatrix(dataSet));
        this.data = dataSet.getDoubleData();
    }

    /**
     * Finds the per-parent penalty by bisection over {@code [0, 10000]}.
     *
     * <p>The search moves the lower end up while {@link #getP(double, double, double)} is below
     * {@code 1 - riskBound}, and the upper end down otherwise. It stops when the bracket is at most
     * 1e-13 wide, or when the bracket can no longer be split in floating point. The lower end of the
     * final bracket is returned.
     *
     * @param m0        parent count used in the exponent {@code pn - m0}; must not exceed {@code pn}.
     * @param pn        number of candidate parents.
     * @param riskBound target risk bound.
     * @return the lower end of the final bisection bracket (0 if {@code p} never fell below
     *         {@code 1 - riskBound}).
     * @throws IllegalArgumentException if {@code m0 > pn}.
     */
    public static double zhangShenLambda(int m0, double pn, double riskBound) {
        if ((double) m0 > pn) {
            throw new IllegalArgumentException("m0 should not be > pn; m0 = " + m0 + " pn = " + pn);
        }
        double high = 10000.0;
        double low = 0.0;
        while (high - low > 1.0E-13) {
            double lambda = (high + low) / 2.0;
            // Above ~512, neighbouring doubles are more than 1e-13 apart. Once low and high are
            // adjacent, the midpoint rounds onto one of them and the width test alone cannot end the
            // loop, so stop as soon as no strictly interior midpoint exists.
            if (lambda <= low || lambda >= high) {
                break;
            }
            double p = ZhangShenBoundScore.getP(pn, m0, lambda);
            if (p < 1.0 - riskBound) {
                low = lambda;
            } else {
                high = lambda;
            }
        }
        return low;
    }

    /**
     * Evaluates {@code 2 - (1 + exp(-(lambda - 1) / 2) * sqrt(lambda))^(pn - m0)}, the quantity that
     * {@link #zhangShenLambda(int, double, double)} compares against {@code 1 - riskBound}.
     *
     * @param pn     number of candidate parents.
     * @param m0     parent count subtracted from {@code pn} in the exponent.
     * @param lambda candidate penalty.
     * @return the value of the expression above.
     */
    public static double getP(double pn, double m0, double lambda) {
        return 2.0 - FastMath.pow(1.0 + FastMath.exp(-(lambda - 1.0) / 2.0) * FastMath.sqrt(lambda), pn - m0);
    }

    // Returns a copy of z with x appended at the end.
    private static int[] append(int[] z, int x) {
        int[] _z = Arrays.copyOf(z, z.length + 1);
        _z[z.length] = x;
        return _z;
    }

    // Maps nodes to their positions in this.variables (-1 for nodes not in the list).
    private int[] indices(List<Node> __adj) {
        int[] indices = new int[__adj.size()];
        for (int t = 0; t < __adj.size(); ++t) {
            indices[t] = this.variables.indexOf(__adj.get(t));
        }
        return indices;
    }

    /**
     * Returns {@code localScore(y, z + x) - localScore(y, z)}.
     */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, ZhangShenBoundScore.append(z, x)) - this.localScore(y, z);
    }

    /**
     * Returns {@code localScore(y, x) - localScore(y)}.
     */
    @Override
    public double localScoreDiff(int x, int y) {
        return this.localScoreDiff(x, y, new int[0]);
    }

    /**
     * Scores node {@code i} given {@code parents}. Side effect: if the result is at least the best
     * score seen so far for {@code i}, that parent count and its {@code varRy} become the new
     * per-node maximum, which changes {@code m0} for later calls.
     *
     * @param i       index of the child variable.
     * @param parents indices of the parent variables.
     * @return {@code -(0.5 * N * ln(varRy) + lambda(m0, pn) * parents.length)}; {@code varRy} comes
     *         from {@code SemBicScore.getVarRy} (presumably the residual variance of {@code i} given
     *         {@code parents} — confirm there).
     */
    @Override
    public double localScore(int i, int ... parents) throws RuntimeException {
        int pn = this.variables.size() - 1;
        boolean calculateRowSubsets = false;
        if (this.estMaxParents == null) {
            this.resetRunningMaxima();
        }
        int pi = parents.length;
        double varRy = SemBicScore.getVarRy(i, parents, this.data, this.covariances, calculateRowSubsets);
        int m0 = this.estMaxParents[i];
        double score = -(0.5 * (double) this.sampleSize * FastMath.log(varRy) + this.getLambda(m0, pn) * (double) pi);
        // '>=' so that ties move the estimate to the most recently scored parent set.
        if (score >= this.maxScores[i]) {
            this.maxScores[i] = score;
            this.estMaxParents[i] = parents.length;
            this.estMaxVarRys[i] = varRy;
        }
        return score;
    }

    // (Re)allocates the per-node running maxima: no parents, score -infinity, varRy NaN.
    private void resetRunningMaxima() {
        int n = this.variables.size();
        this.estMaxParents = new int[n];
        this.maxScores = new double[n];
        this.estMaxVarRys = new double[n];
        Arrays.fill(this.maxScores, Double.NEGATIVE_INFINITY);
        Arrays.fill(this.estMaxVarRys, Double.NaN);
    }

    // Returns zhangShenLambda(m0, pn, riskBound), computing and caching every entry up to m0 on demand.
    // The cache is keyed by m0 only; pn is always variables.size() - 1 for a given instance.
    private double getLambda(int m0, int pn) {
        if (this.lambdas == null) {
            this.lambdas = new ArrayList<>();
        }
        while (this.lambdas.size() <= m0) {
            this.lambdas.add(ZhangShenBoundScore.zhangShenLambda(this.lambdas.size(), pn, this.riskBound));
        }
        return this.lambdas.get(m0);
    }

    /**
     * Scores node {@code i} with the single parent {@code parent}.
     */
    @Override
    public double localScore(int i, int parent) {
        return this.localScore(i, new int[]{parent});
    }

    /**
     * Scores node {@code i} with no parents.
     */
    @Override
    public double localScore(int i) {
        return this.localScore(i, new int[0]);
    }

    /**
     * @return the covariance matrix this score was built from.
     */
    public ICovarianceMatrix getCovariances() {
        return this.covariances;
    }

    // Stores the covariances and sample size. Rejects matrices with any off-diagonal correlation
    // whose absolute value exceeds 1; each offending entry is printed before the exception.
    private void setCovariances(ICovarianceMatrix covariances) {
        CorrelationMatrix correlations = new CorrelationMatrix(covariances);
        this.covariances = covariances;
        boolean exists = false;
        double correlationThreshold = 1.0;
        for (int i = 0; i < correlations.getSize(); ++i) {
            for (int j = 0; j < correlations.getSize(); ++j) {
                if (i == j) {
                    continue;
                }
                double r = correlations.getValue(i, j);
                if (FastMath.abs(r) > correlationThreshold) {
                    System.out.println("Absolute correlation too high: " + r);
                    exists = true;
                }
            }
        }
        if (exists) {
            throw new IllegalArgumentException("Some correlations are too high (> " + correlationThreshold + ") in absolute value.");
        }
        this.sampleSize = covariances.getSampleSize();
    }

    /**
     * @return the sample size taken from the covariance matrix.
     */
    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    /**
     * @return true iff {@code bump} is strictly positive.
     */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /**
     * @return a fresh mutable copy of the variable list.
     */
    @Override
    public List<Node> getVariables() {
        return new ArrayList<>(this.variables);
    }

    /**
     * @return the first variable whose name equals {@code targetName}, or null if there is none.
     */
    @Override
    public Node getVariable(String targetName) {
        for (Node node : this.variables) {
            if (node.getName().equals(targetName)) {
                return node;
            }
        }
        return null;
    }

    /**
     * @return {@code ceil(ln(sampleSize))} as an int.
     */
    @Override
    public int getMaxDegree() {
        return (int) FastMath.ceil(FastMath.log(this.sampleSize));
    }

    /**
     * Reports whether scoring {@code y} with parents {@code z} yields NaN. Note that this goes through
     * {@link #localScore(int, int...)} and may therefore update the running maxima.
     */
    @Override
    public boolean determines(List<Node> z, Node y) {
        int i = this.variables.indexOf(y);
        int[] k = this.indices(z);
        double v = this.localScore(i, k);
        return Double.isNaN(v);
    }

    public boolean isChanged() {
        return this.changed;
    }

    public void setChanged(boolean b) {
        this.changed = b;
    }

    /**
     * Sets the risk bound used to derive the per-parent penalty.
     *
     * <p>Penalties are cached per {@code m0}, and the running maxima depend on the penalty in force
     * when they were recorded. An actual change of value therefore discards the cached penalties. If
     * scoring has already started, it also resets the running maxima, so later scores are not compared
     * against maxima computed under the old penalty. Setting the same value again is a no-op.
     *
     * @param riskBound the new risk bound.
     */
    public void setRiskBound(double riskBound) {
        if (Double.compare(riskBound, this.riskBound) == 0) {
            return;
        }
        this.riskBound = riskBound;
        this.lambdas = null;
        if (this.estMaxParents != null) {
            this.resetRunningMaxima();
        }
    }
}

