/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.CorrelationMatrix;
import edu.cmu.tetrad.data.CovarianceMatrix;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataUtils;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.Score;
import edu.cmu.tetrad.util.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.util.FastMath;
import org.jetbrains.annotations.NotNull;

/**
 * A linear-Gaussian {@link Score} whose complexity penalty is the lambda obtained by bisection in
 * {@link #zhangShenLambda(int, int, double)}.
 *
 * <p>For every node the score keeps running estimates: the best score seen so far
 * ({@code maxScores}), the parent-set size that achieved it ({@code estMinParents}) and the
 * residual variance of that model ({@code estVarRys}). These estimates feed the penalty of later
 * calls, so {@link #localScore(int, int...)} is stateful and not a pure function of its arguments.
 */
public class ZhangShenBoundTest implements Score {
    private ICovarianceMatrix covariances;
    private final List<Node> variables;
    private final int sampleSize;
    private boolean verbose = false;
    // Sample size used to weight the likelihood term of the score.
    private double N;
    // Cache of lambda values, indexed by the assumed parent-set size m0; grown on demand.
    private List<Double> lambdas;
    // Centered data matrix; null when the score was built from a covariance matrix.
    private Matrix data;
    private boolean calculateSquaredEuclideanNorms = false;
    // True when the data contain missing values, so covariances are computed per complete-row subset.
    private boolean calculateRowSubsets = false;
    // Per-node running estimates; allocated and seeded lazily on the first localScore call.
    double[] maxScores;
    int[] estMinParents;
    double[] estVarRys;
    private boolean changed = false;
    private double correlationThreshold = 1.0;
    // Stored via setPenaltyDiscount but not used by the scoring code in this class.
    private double penaltyDiscount = 1.0;
    private boolean takeLog = true;
    private double riskBound = 0.0;
    // Stored via setTrueErrorVariance but not used by the scoring code in this class.
    private double trueErrorVariance;

    /**
     * Builds the score from a covariance matrix.
     *
     * @param covariances the covariance matrix; must not be null.
     * @throws NullPointerException     if {@code covariances} is null.
     * @throws IllegalArgumentException if some off-diagonal absolute correlation exceeds the
     *                                  correlation threshold in effect at construction time.
     */
    public ZhangShenBoundTest(ICovarianceMatrix covariances) {
        if (covariances == null) {
            throw new NullPointerException();
        }
        this.setCovariances(covariances);
        this.variables = covariances.getVariables();
        this.sampleSize = covariances.getSampleSize();
        // The per-node estimates are intentionally left null: localScore seeds them with the
        // empty-model scores, exactly as for the DataSet constructor.
    }

    /**
     * Builds the score from a data set. The data are centered. If there are no missing values, a
     * covariance matrix is computed up front. Otherwise covariances are computed on the
     * complete-row subset for each query.
     *
     * @param dataSet the data; must not be null.
     * @throws NullPointerException     if {@code dataSet} is null.
     * @throws IllegalArgumentException if some off-diagonal absolute correlation exceeds the
     *                                  correlation threshold (only checked when no values are missing).
     */
    public ZhangShenBoundTest(DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException();
        }
        this.variables = dataSet.getVariables();
        this.sampleSize = dataSet.getNumRows();
        // setCovariances is skipped when values are missing, so N must be set here as well;
        // otherwise the likelihood term would be multiplied by 0.
        this.N = this.sampleSize;
        DataSet centered = DataUtils.center(dataSet);
        this.data = centered.getDoubleData();
        if (!dataSet.existsMissingValue()) {
            this.setCovariances(new CovarianceMatrix(dataSet));
            this.calculateRowSubsets = false;
        } else {
            this.calculateRowSubsets = true;
        }
    }

    /** Maps each node to its index in {@code variables} (-1 if absent). */
    private int[] indices(List<Node> __adj) {
        int[] indices = new int[__adj.size()];
        for (int t = 0; t < __adj.size(); ++t) {
            indices[t] = this.variables.indexOf(__adj.get(t));
        }
        return indices;
    }

    /** Returns score(y | z ∪ {x}) - score(y | z). Both calls may update the running estimates. */
    @Override
    public double localScoreDiff(int x, int y, int[] z) {
        return this.localScore(y, ZhangShenBoundTest.append(z, x)) - this.localScore(y, z);
    }

    /** Returns score(y | {x}) - score(y | {}). */
    @Override
    public double localScoreDiff(int x, int y) {
        return this.localScoreDiff(x, y, new int[0]);
    }

    /**
     * Scores node {@code i} given {@code parents}.
     *
     * <p>With {@code takeLog}: {@code -(N log(varRy) + lambda * (|parents| + 1) * 2)}. Without it:
     * {@code -(N varRy + lambda * (|parents| + 1) * estVarRys[i])}. Here {@code lambda} is looked
     * up for the current estimate {@code estMinParents[i]}.
     *
     * <p>If the score beats the best seen so far for {@code i}, the running estimates for {@code i}
     * are updated and the {@code changed} flag is set.
     *
     * @param i       index of the child node.
     * @param parents indices of the parent nodes.
     * @return the penalized score.
     */
    @Override
    public double localScore(int i, int ... parents) throws RuntimeException {
        this.initializeEstimatesIfNeeded();
        double varRy = ZhangShenBoundTest.getVarRy(i, parents, this.data, this.covariances, this.calculateRowSubsets, this.calculateSquaredEuclideanNorms);
        double score = this.penalizedScore(i, varRy, parents.length + 1);
        if (score > this.maxScores[i]) {
            this.estMinParents[i] = parents.length;
            this.estVarRys[i] = varRy;
            this.maxScores[i] = score;
            this.changed = true;
            if (this.verbose) {
                System.out.println(Arrays.toString(this.estVarRys));
            }
        }
        return score;
    }

    /**
     * Allocates the per-node estimates on first use and seeds each node with its empty-model
     * residual variance and score. The variance is stored before the score is computed, so the
     * non-log penalty of the baseline uses it. Seeding does not set the {@code changed} flag.
     */
    private void initializeEstimatesIfNeeded() {
        if (this.estMinParents != null) {
            return;
        }
        int numVars = this.variables.size();
        this.estMinParents = new int[numVars];
        this.maxScores = new double[numVars];
        this.estVarRys = new double[numVars];
        for (int j = 0; j < numVars; ++j) {
            double varRy = ZhangShenBoundTest.getVarRy(j, new int[0], this.data, this.covariances, this.calculateRowSubsets, this.calculateSquaredEuclideanNorms);
            this.estVarRys[j] = varRy;
            this.maxScores[j] = this.penalizedScore(j, varRy, 1);
        }
    }

    /**
     * Computes the penalized score for node {@code i}. It uses the current estimates for
     * {@code i} and does not modify them.
     *
     * @param i     index of the child node.
     * @param varRy residual variance of {@code i} given its candidate parents.
     * @param pi    number of parents plus one.
     */
    private double penalizedScore(int i, double varRy, int pi) {
        double lambda = this.getLambda(this.estMinParents[i]);
        if (this.takeLog) {
            return -(this.N * FastMath.log(varRy) + lambda * (double) pi * 2.0);
        }
        return -(this.N * varRy + lambda * (double) pi * this.estVarRys[i]);
    }

    /**
     * Residual variance of node {@code i} regressed on {@code parents}.
     *
     * <p>If {@code calculateSquareEuclideanNorms} is set, this is the mean squared OLS residual
     * computed from {@code data}. Otherwise it is {@code b*' Cov b*}, where {@code Cov} is the
     * covariance of (i, parents) and {@code b* = (1, -b)}. {@code Cov} comes from
     * {@code covariances}, or from the complete rows of {@code data} when
     * {@code calculateRowSubsets} is set.
     *
     * @return the residual variance, or {@link Double#NEGATIVE_INFINITY} if a required matrix
     *     inverse is singular. With the log score this yields a NaN score, which
     *     {@link #determines(List, Node)} reports as determination.
     */
    public static double getVarRy(int i, int[] parents, Matrix data, ICovarianceMatrix covariances, boolean calculateRowSubsets, boolean calculateSquareEuclideanNorms) {
        try {
            if (calculateSquareEuclideanNorms) {
                return 1.0 / (double) data.rows() * ZhangShenBoundTest.getSquaredEuclideanNorm(i, parents, data);
            }
            int[] all = ZhangShenBoundTest.concat(i, parents);
            Matrix cov = ZhangShenBoundTest.getCov(ZhangShenBoundTest.getRows(i, parents, data, calculateRowSubsets), all, all, data, covariances);
            int[] pp = ZhangShenBoundTest.indexedParents(parents);
            Matrix covxx = cov.getSelection(pp, pp);
            Matrix covxy = cov.getSelection(pp, new int[]{0});
            Matrix b = covxx.inverse().times(covxy);
            Matrix bStar = ZhangShenBoundTest.bStar(b);
            return bStar.transpose().times(cov).times(bStar).get(0, 0);
        }
        catch (SingularMatrixException e) {
            // covariances is null when the score was built from a DataSet with missing values,
            // so fall back to reporting indices instead of node names.
            if (covariances != null) {
                List<Node> variables = covariances.getVariables();
                List<Node> p = new ArrayList<>();
                for (int _p : parents) {
                    p.add(variables.get(_p));
                }
                System.out.println("Singularity " + variables.get(i) + " | " + p);
            } else {
                System.out.println("Singularity " + i + " | " + Arrays.toString(parents));
            }
            return Double.NEGATIVE_INFINITY;
        }
    }

    /** Returns the (cached) lambda for assumed parent-set size {@code m}, filling the cache up to m. */
    private double getLambda(int m) {
        if (this.lambdas == null) {
            this.lambdas = new ArrayList<>();
        }
        while (this.lambdas.size() <= m) {
            this.lambdas.add(ZhangShenBoundTest.zhangShenLambda(this.variables.size(), this.lambdas.size(), this.riskBound));
        }
        return this.lambdas.get(m);
    }

    /**
     * Finds, by bisection on [0, 1e6] to a width of 1e-10, the lambda at which
     * {@link #getP(int, int, double)} crosses {@code 1 - riskBound}. The search assumes
     * {@code getP} is increasing in lambda over the searched range.
     *
     * @param pn        number of variables.
     * @param m0        assumed parent-set size.
     * @param riskBound risk bound; the target probability is {@code 1 - riskBound}.
     * @throws IllegalArgumentException if {@code pn == m0}.
     */
    public static double zhangShenLambda(int pn, int m0, double riskBound) {
        if (pn == m0) {
            throw new IllegalArgumentException("m0 should not equal pn");
        }
        double high = 1000000.0;
        double low = 0.0;
        while (high - low > 1.0E-10) {
            double lambda = (high + low) / 2.0;
            double p = ZhangShenBoundTest.getP(pn, m0, lambda);
            if (p < 1.0 - riskBound) {
                low = lambda;
            } else {
                high = lambda;
            }
        }
        return (high + low) / 2.0;
    }

    /** Returns {@code 2 - (1 + exp(-(lambda - 1) / 2) * sqrt(lambda))^(pn - m0)}. */
    public static double getP(int pn, int m0, double lambda) {
        return 2.0 - FastMath.pow(1.0 + FastMath.exp(-(lambda - 1.0) / 2.0) * FastMath.sqrt(lambda), pn - m0);
    }

    /**
     * Sum of squared OLS residuals of column {@code i} on columns {@code parents}, over all rows of
     * {@code data}. No intercept is fitted; the constructor centers the data.
     *
     * @throws SingularMatrixException if X'X cannot be inverted.
     */
    private static double getSquaredEuclideanNorm(int i, int[] parents, Matrix data) {
        int[] rows = new int[data.rows()];
        for (int t = 0; t < rows.length; ++t) {
            rows[t] = t;
        }
        Matrix y = data.getSelection(rows, new int[]{i});
        Matrix x = data.getSelection(rows, parents);
        Matrix xT = x.transpose();
        Matrix b = xT.times(x).inverse().times(xT.times(y));
        Matrix yhat = x.times(b);
        double sum = 0.0;
        for (int q = 0; q < data.rows(); ++q) {
            double diff = data.get(q, i) - yhat.get(q, 0);
            sum += diff * diff;
        }
        return sum;
    }

    /** Returns the column vector {@code (1, -b_1, ..., -b_k)'} for a k x 1 coefficient vector b. */
    @NotNull
    public static Matrix bStar(Matrix b) {
        Matrix byx = new Matrix(b.rows() + 1, 1);
        byx.set(0, 0, 1.0);
        for (int j = 0; j < b.rows(); ++j) {
            byx.set(j + 1, 0, -b.get(j, 0));
        }
        return byx;
    }

    /** Scores node {@code i} with a single parent. */
    @Override
    public double localScore(int i, int parent) {
        return this.localScore(i, new int[]{parent});
    }

    /** Scores node {@code i} with no parents. */
    @Override
    public double localScore(int i) {
        return this.localScore(i, new int[0]);
    }

    /** Returns the covariance matrix, or null if built from a DataSet with missing values. */
    public ICovarianceMatrix getCovariances() {
        return this.covariances;
    }

    @Override
    public int getSampleSize() {
        return this.sampleSize;
    }

    /** A score difference counts as an effect edge when it is strictly positive. */
    @Override
    public boolean isEffectEdge(double bump) {
        return bump > 0.0;
    }

    public boolean isVerbose() {
        return this.verbose;
    }

    /** Enables printing of the per-node variance estimates whenever they improve. */
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /** Returns the variable with the given name, or null if there is none. */
    @Override
    public Node getVariable(String targetName) {
        for (Node node : this.variables) {
            if (node.getName().equals(targetName)) {
                return node;
            }
        }
        return null;
    }

    /** Returns ceil(ln(sampleSize)). */
    @Override
    public int getMaxDegree() {
        return (int) FastMath.ceil(FastMath.log(this.sampleSize));
    }

    /**
     * Reports whether {@code z} determines {@code y}, signalled by a NaN local score. Note that this
     * calls localScore and may therefore update the running estimates.
     */
    @Override
    public boolean determines(List<Node> z, Node y) {
        int i = this.variables.indexOf(y);
        int[] k = this.indices(z);
        double v = this.localScore(i, k);
        return Double.isNaN(v);
    }

    /**
     * Stores the covariance matrix and sets N from its sample size. Throws if any off-diagonal
     * absolute correlation exceeds {@code correlationThreshold}; each offending value is printed first.
     */
    private void setCovariances(ICovarianceMatrix covariances) {
        CorrelationMatrix correlations = new CorrelationMatrix(covariances);
        this.covariances = covariances;
        boolean exists = false;
        for (int i = 0; i < correlations.getSize(); ++i) {
            for (int j = 0; j < correlations.getSize(); ++j) {
                if (i == j) {
                    continue;
                }
                double r = correlations.getValue(i, j);
                if (FastMath.abs(r) > this.correlationThreshold) {
                    System.out.println("Absolute correlation too high: " + r);
                    exists = true;
                }
            }
        }
        if (exists) {
            throw new IllegalArgumentException("Some correlations are too high (> " + this.correlationThreshold + ") in absolute value.");
        }
        this.N = covariances.getSampleSize();
    }

    /** Returns a copy of {@code z} with {@code x} appended. */
    private static int[] append(int[] z, int x) {
        int[] _z = Arrays.copyOf(z, z.length + 1);
        _z[z.length] = x;
        return _z;
    }

    /** Returns {1, ..., parents.length}: the parents' positions inside the (i, parents) submatrix. */
    private static int[] indexedParents(int[] parents) {
        int[] pp = new int[parents.length];
        for (int j = 0; j < pp.length; ++j) {
            pp[j] = j + 1;
        }
        return pp;
    }

    /** Returns {i, parents...}. */
    private static int[] concat(int i, int[] parents) {
        int[] all = new int[parents.length + 1];
        all[0] = i;
        System.arraycopy(parents, 0, all, 1, parents.length);
        return all;
    }

    /**
     * Returns the covariance submatrix for columns {@code _rows} x {@code cols}.
     *
     * <p>If {@code rows} is null, the submatrix is taken from {@code covarianceMatrix}. Otherwise
     * it is the sample covariance (denominator n - 1) of {@code data} restricted to {@code rows}.
     */
    private static Matrix getCov(List<Integer> rows, int[] _rows, int[] cols, Matrix data, ICovarianceMatrix covarianceMatrix) {
        if (rows == null) {
            return covarianceMatrix.getSelection(_rows, cols);
        }
        int n = rows.size();
        Matrix cov = new Matrix(_rows.length, cols.length);
        for (int i = 0; i < _rows.length; ++i) {
            for (int j = 0; j < cols.length; ++j) {
                double mui = 0.0;
                double muj = 0.0;
                for (int k : rows) {
                    mui += data.get(k, _rows[i]);
                    muj += data.get(k, cols[j]);
                }
                // Means divide by n; the previous code divided by n - 1, which biased them.
                mui /= (double) n;
                muj /= (double) n;
                double sum = 0.0;
                for (int k : rows) {
                    sum += (data.get(k, _rows[i]) - mui) * (data.get(k, cols[j]) - muj);
                }
                cov.set(i, j, sum / (double) (n - 1));
            }
        }
        return cov;
    }

    /**
     * Returns the rows where column {@code i} and all {@code parents} columns are non-NaN. Returns
     * null when {@code calculateRowSubsets} is false, meaning the global covariance matrix is used.
     */
    private static List<Integer> getRows(int i, int[] parents, Matrix data, boolean calculateRowSubsets) {
        if (!calculateRowSubsets) {
            return null;
        }
        List<Integer> rows = new ArrayList<>();
        rowLoop:
        for (int k = 0; k < data.rows(); ++k) {
            if (Double.isNaN(data.get(k, i))) {
                continue;
            }
            for (int p : parents) {
                if (Double.isNaN(data.get(k, p))) {
                    continue rowLoop;
                }
            }
            rows.add(k);
        }
        return rows;
    }

    /** Uses mean squared OLS residuals computed from the data instead of the covariance formula. */
    public void setCalculateSquaredEuclideanNorms(boolean calculateSquaredEuclideanNorms) {
        this.calculateSquaredEuclideanNorms = calculateSquaredEuclideanNorms;
    }

    /** Returns true if some localScore call has improved a node's running estimate since the last reset. */
    public boolean isChanged() {
        return this.changed;
    }

    public void setChanged(boolean b) {
        this.changed = b;
    }

    public void setPenaltyDiscount(double penaltyDiscount) {
        this.penaltyDiscount = penaltyDiscount;
    }

    public double getPenaltyDiscount() {
        return this.penaltyDiscount;
    }

    /** Sets the risk bound used for lambdas computed after this call; already-cached lambdas are kept. */
    public void setRiskBound(double riskBound) {
        this.riskBound = riskBound;
    }

    /** Sets the correlation threshold. Constructors check correlations only once, at construction time. */
    public void setCorrelationThreshold(double correlationThreshold) {
        this.correlationThreshold = correlationThreshold;
    }

    /** Chooses between the log-variance score and the raw-variance score. */
    public void setTakeLog(boolean takeLog) {
        this.takeLog = takeLog;
    }

    public void setTrueErrorVariance(double trueErrorVariance) {
        this.trueErrorVariance = trueErrorVariance;
    }
}

