/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.score;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.data.SimpleDataLoader;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.score.Score;
import edu.cmu.tetrad.search.score.SemBicScore;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.Matrix;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;

/**
 * Penalized linear-Gaussian local score.
 *
 * <p>The local score of variable {@code i} given parents {@code P} is
 * {@code -(0.5 * n * log(varRy) + lambda(m0, pn) * |P|)}, where {@code varRy} is the residual
 * variance of {@code i} regressed on {@code P} (via {@link SemBicScore#getVarRy}), {@code n} is
 * the sample size, {@code pn} is the number of variables minus one, and {@code lambda} is found
 * by bisection in {@link #zhangShenLambda(int, double, double)} for the configured risk bound.
 *
 * <p>Note: {@link #localScore(int, int...)} is stateful. {@code m0} is the parent-set size of the
 * best score seen so far for {@code i}, so the penalty used can depend on call history. This class
 * is not synchronized.
 */
public class ZsbScore
implements Score {
    /** Variables being scored; indices passed to {@link #localScore} refer to this list. */
    private final List<Node> variables;
    /** Best local score seen so far for each variable (lazily allocated in localScore). */
    double[] maxScores;
    /** Parent-set size that produced {@code maxScores[i]}; used as m0 when picking lambda. */
    int[] estMaxParents;
    /** Residual variance recorded together with {@code maxScores[i]}. */
    double[] estMaxVarRys;
    /** Risk bound passed to zhangShenLambda. */
    private double riskBound = 0.001;
    private ICovarianceMatrix covariances;
    private int sampleSize;
    /** Cache of lambda values indexed by m0; must be cleared whenever riskBound changes. */
    private List<Double> lambdas;
    /** Raw data matrix (set only by the DataSet constructor), forwarded to getVarRy. */
    private Matrix data;

    /**
     * Creates a score from a covariance matrix.
     *
     * @param covMatrix the covariance matrix; must not be null.
     * @throws NullPointerException if {@code covMatrix} is null.
     * @throws IllegalArgumentException if some off-diagonal absolute correlation exceeds 1.
     */
    public ZsbScore(ICovarianceMatrix covMatrix) {
        if (covMatrix == null) {
            throw new NullPointerException("covMatrix must not be null");
        }
        this.setCovariances(covMatrix);
        this.variables = covMatrix.getVariables();
        this.sampleSize = covMatrix.getSampleSize();
    }

    /**
     * Creates a score from a data set. The covariance matrix is obtained through
     * {@link SimpleDataLoader#getCovarianceMatrix}, and the raw data is kept for residual-variance
     * calculations.
     *
     * @param dataSet                the data set.
     * @param precomputeCovariances  forwarded to {@code SimpleDataLoader.getCovarianceMatrix}.
     */
    public ZsbScore(DataSet dataSet, boolean precomputeCovariances) {
        this(SimpleDataLoader.getCovarianceMatrix(dataSet, precomputeCovariances));
        this.data = dataSet.getDoubleData();
    }

    /**
     * Finds, by bisection on [0, 10000], the largest lambda for which
     * {@code getP(pn, m0, lambda) < 1 - riskBound}, to within 1e-13 or double precision.
     *
     * @param m0        estimated parent-set size; must not exceed {@code pn}.
     * @param pn        number of candidate parents.
     * @param riskBound the risk bound.
     * @return the lower end of the final bisection interval.
     * @throws IllegalArgumentException if {@code m0 > pn}.
     */
    private static double zhangShenLambda(int m0, double pn, double riskBound) {
        if ((double) m0 > pn) {
            throw new IllegalArgumentException("m0 should not be > pn; m0 = " + m0 + " pn = " + pn);
        }
        double high = 10000.0;
        double low = 0.0;
        while (high - low > 1.0E-13) {
            double lambda = (high + low) / 2.0;
            // Above ~512, adjacent doubles are more than 1e-13 apart. The midpoint then rounds onto
            // an endpoint and the interval stops shrinking, so stop instead of looping forever.
            if (lambda <= low || lambda >= high) {
                break;
            }
            if (ZsbScore.getP(pn, m0, lambda) < 1.0 - riskBound) {
                low = lambda;
            } else {
                high = lambda;
            }
        }
        return low;
    }

    /**
     * Evaluates {@code 2 - (1 + exp(-(lambda - 1) / 2) * sqrt(lambda))^(pn - m0)}.
     */
    private static double getP(double pn, double m0, double lambda) {
        return 2.0 - FastMath.pow(1.0 + FastMath.exp(-(lambda - 1.0) / 2.0) * FastMath.sqrt(lambda), pn - m0);
    }

    /**
     * Returns the local score of variable {@code i} given {@code parents}. As a side effect, this
     * records the parent-set size and residual variance whenever the score is at least the best seen
     * so far for {@code i}.
     *
     * @param i       index of the child variable.
     * @param parents indices of the parent variables.
     * @return the penalized score (higher is better).
     * @throws RuntimeException if the regression encounters a singular matrix.
     */
    @Override
    public double localScore(int i, int ... parents) {
        int pn = this.variables.size() - 1;
        boolean calculateRowSubsets = false;
        if (this.estMaxParents == null) {
            int size = this.variables.size();
            this.estMaxParents = new int[size];
            this.maxScores = new double[size];
            this.estMaxVarRys = new double[size];
            for (int j = 0; j < size; ++j) {
                this.estMaxParents[j] = 0;
                this.maxScores[j] = Double.NEGATIVE_INFINITY;
                this.estMaxVarRys[j] = Double.NaN;
            }
        }
        int pi = parents.length;
        double varRy;
        try {
            varRy = SemBicScore.getVarRy(i, parents, this.data, this.covariances, calculateRowSubsets);
        }
        catch (SingularMatrixException e) {
            throw new RuntimeException("Singularity encountered when scoring " + LogUtilsSearch.getScoreFact(i, parents, this.variables), e);
        }
        int m0 = this.estMaxParents[i];
        double score = -(0.5 * (double) this.sampleSize * FastMath.log(varRy) + this.getLambda(m0, pn) * (double) pi);
        if (score >= this.maxScores[i]) {
            this.maxScores[i] = score;
            this.estMaxParents[i] = parents.length;
            this.estMaxVarRys[i] = varRy;
        }
        return score;
    }

    /**
     * Returns the score of {@code y} with {@code x} appended to {@code z}, minus the score of
     * {@code y} given {@code z}.
     */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, this.append(z, x)) - this.localScore(y, z);
    }

    /** @return the covariance matrix this score was built from. */
    public ICovarianceMatrix getCovariances() {
        return this.covariances;
    }

    /**
     * Stores the covariance matrix and sample size. Each off-diagonal correlation whose absolute
     * value exceeds 1.0 is printed, and an exception is thrown afterwards if any were found.
     */
    private void setCovariances(ICovarianceMatrix covariances) {
        CorrelationMatrix correlations = new CorrelationMatrix(covariances);
        this.covariances = covariances;
        boolean exists = false;
        double correlationThreshold = 1.0;
        for (int i = 0; i < correlations.getSize(); ++i) {
            for (int j = 0; j < correlations.getSize(); ++j) {
                if (i == j) {
                    continue;
                }
                double r = correlations.getValue(i, j);
                if (FastMath.abs(r) > correlationThreshold) {
                    System.out.println("Absolute correlation too high: " + r);
                    exists = true;
                }
            }
        }
        if (exists) {
            throw new IllegalArgumentException("Some correlations are too high (> " + correlationThreshold + ") in absolute value.");
        }
        this.sampleSize = covariances.getSampleSize();
    }

    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    /** An edge is considered an effect edge when its score bump is strictly positive. */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    /** @return a defensive copy of the variable list. */
    @Override
    public List<Node> getVariables() {
        return new ArrayList<Node>(this.variables);
    }

    /** @return {@code ceil(ln(sampleSize))}. */
    @Override
    public int getMaxDegree() {
        return (int) FastMath.ceil(FastMath.log(this.sampleSize));
    }

    /**
     * Returns true if the local score of {@code y} given {@code z} is NaN.
     */
    @Override
    public boolean determines(List<Node> z, Node y) {
        int i = this.variables.indexOf(y);
        int[] k = this.indices(z);
        double v = this.localScore(i, k);
        return Double.isNaN(v);
    }

    /**
     * Sets the risk bound used to choose lambda. Cached lambdas were computed with the old bound,
     * so the cache is cleared and rebuilt lazily.
     *
     * @param riskBound the new risk bound.
     */
    public void setRiskBound(double riskBound) {
        this.riskBound = riskBound;
        this.lambdas = null;
    }

    /**
     * Returns the (cached) lambda for {@code m0}, computing and caching every missing entry up to
     * {@code m0}. The cache is keyed only by m0; pn is assumed constant for this instance.
     */
    private double getLambda(int m0, int pn) {
        if (this.lambdas == null) {
            this.lambdas = new ArrayList<Double>();
        }
        for (int t = this.lambdas.size(); t <= m0; ++t) {
            this.lambdas.add(ZsbScore.zhangShenLambda(t, pn, this.riskBound));
        }
        return this.lambdas.get(m0);
    }

    /** Maps nodes to their indices in {@link #variables} (-1 for unknown nodes). */
    private int[] indices(List<Node> __adj) {
        int[] indices = new int[__adj.size()];
        for (int t = 0; t < __adj.size(); ++t) {
            indices[t] = this.variables.indexOf(__adj.get(t));
        }
        return indices;
    }
}

