/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.cluster.KMeans;
import edu.cmu.tetrad.data.BoxDataSet;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DoubleDataBox;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceResult;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.SearchLogUtils;
import edu.cmu.tetrad.search.kernel.Kernel;
import edu.cmu.tetrad.search.kernel.KernelGaussian;
import edu.cmu.tetrad.search.kernel.KernelUtils;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.NumberFormatUtil;
import edu.cmu.tetrad.util.RandomUtil;
import edu.cmu.tetrad.util.TetradLogger;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Independence test based on the Hilbert-Schmidt Independence Criterion (HSIC) with Gaussian
 * kernels. It supports unconditional tests (empty conditioning set) and conditional tests.
 *
 * <p>The p-value comes from a permutation test. The observed HSIC statistic is compared with
 * {@link #getPerms()} statistics computed after permuting the rows of {@code y} while keeping the
 * Gram matrices of {@code x} and of the conditioning set fixed. For conditional tests the
 * permutation is restricted to k-means clusters of the conditioning variables. The p-value is the
 * fraction of permuted statistics that are not at or below the observed statistic. A p-value at
 * or below alpha is treated as evidence of dependence.
 *
 * <p>Gram matrices are built in full (centered) form. When the incomplete-Cholesky precision is
 * positive, incomplete-Cholesky factors are used instead.
 *
 * <p>{@link #checkIndependence} stores its p-value in an instance field. Instances should
 * therefore not be shared across threads without external synchronization.
 */
public final class IndTestHsic implements IndependenceTest {
    /** Variables of the data set, as an unmodifiable list. */
    private final List<Node> variables;
    /** Significance level; p-values above it are reported as independent. */
    private double alpha;
    /** Threshold placeholder; never computed by this class and reset to NaN by setAlpha. */
    private double thresh = Double.NaN;
    private static final NumberFormat nf = NumberFormatUtil.getInstance().getNumberFormat();
    /** The data the test runs on. */
    private final DataSet dataSet;
    /** p-value of the most recent checkIndependence call; NaN before the first call. */
    private double pValue = Double.NaN;
    /** Ridge added to the diagonal of the conditioning Gram matrix before it is inverted. */
    private double regularizer = 1.0E-4;
    /** Number of permutations used to approximate the null distribution. */
    private int perms = 100;
    /** Incomplete-Cholesky precision; any value > 0 selects the incomplete-Cholesky code path. */
    private double useIncompleteCholesky = 1.0E-18;
    /** Whether independence judgements are written to the Tetrad log. */
    private boolean verbose;

    /**
     * Creates a test over a continuous data set.
     *
     * @param dataSet the data; must be continuous
     * @param alpha   the significance level, in [0, 1]
     * @throws IllegalArgumentException if the data set is not continuous or alpha is out of range
     */
    public IndTestHsic(DataSet dataSet, double alpha) {
        if (!dataSet.isContinuous()) {
            throw new IllegalArgumentException("Data set must be continuous.");
        }
        this.variables = Collections.unmodifiableList(dataSet.getVariables());
        this.setAlpha(alpha);
        this.dataSet = dataSet;
    }

    /**
     * Creates a test over a raw data matrix, wrapping it in a {@link BoxDataSet}.
     *
     * @param data      the data matrix (rows are samples, columns correspond to {@code variables})
     * @param variables the variables naming the columns of {@code data}
     * @param alpha     the significance level, in [0, 1]
     * @throws IllegalArgumentException if alpha is out of range
     */
    public IndTestHsic(Matrix data, List<Node> variables, double alpha) {
        this.dataSet = new BoxDataSet(new DoubleDataBox(data.toArray()), variables);
        this.variables = Collections.unmodifiableList(variables);
        this.setAlpha(alpha);
    }

    /**
     * Returns a test over a subset of this test's variables. The subset test uses the same alpha,
     * permutation count, regularizer, incomplete-Cholesky precision and verbosity.
     *
     * @param vars the variables to keep; must be non-empty and all belong to this test
     * @return a new HSIC test over the selected columns
     * @throws IllegalArgumentException if {@code vars} is empty or contains a foreign variable
     */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        if (vars.isEmpty()) {
            throw new IllegalArgumentException("Subset may not be empty.");
        }
        int[] indices = new int[vars.size()];
        for (int i = 0; i < indices.length; ++i) {
            int index = this.variables.indexOf(vars.get(i));
            if (index < 0) {
                throw new IllegalArgumentException("All vars must be original vars");
            }
            indices[i] = index;
        }
        IndTestHsic subset = new IndTestHsic(this.dataSet.subsetColumns(indices), this.getAlpha());
        // Carry the tuning parameters over so the subset test behaves like this one.
        subset.regularizer = this.regularizer;
        subset.perms = this.perms;
        subset.useIncompleteCholesky = this.useIncompleteCholesky;
        subset.verbose = this.verbose;
        return subset;
    }

    /**
     * Tests whether {@code x} and {@code y} are independent given {@code z} using a permutation
     * test on the (conditional) HSIC statistic.
     *
     * @param y one variable
     * @param x the other variable
     * @param z the conditioning set; may be empty for an unconditional test
     * @return the independence judgement and its p-value (also available from {@link #getPValue()})
     */
    @Override
    public IndependenceResult checkIndependence(Node y, Node x, List<Node> z) {
        int m = this.sampleSize();

        // Gaussian kernels, with bandwidths chosen by KernelGaussian.setDefaultBw on the data.
        KernelGaussian yKernel = new KernelGaussian(1.0);
        KernelGaussian xKernel = new KernelGaussian(1.0);
        yKernel.setDefaultBw(this.dataSet, y);
        xKernel.setDefaultBw(this.dataSet, x);
        List<Kernel> zKernels = new ArrayList<>(z.size());
        for (Node node : z) {
            KernelGaussian zKernel = new KernelGaussian(1.0);
            zKernel.setDefaultBw(this.dataSet, node);
            zKernels.add(zKernel);
        }

        Matrix Ky = this.gramMatrix(Collections.singletonList(yKernel), this.dataSet, Collections.singletonList(y));
        Matrix Kx = this.gramMatrix(Collections.singletonList(xKernel), this.dataSet, Collections.singletonList(x));
        Matrix Kz = z.isEmpty() ? null : this.gramMatrix(zKernels, this.dataSet, z);
        double hsic = this.hsicStatistic(Ky, Kx, Kz, m);

        // For conditional tests, y is only permuted among rows in the same k-means cluster of z.
        List<List<Integer>> clusters = null;
        if (!z.isEmpty()) {
            KMeans kmeans = KMeans.randomClusters(FastMath.max(1, m / 3));
            kmeans.cluster(this.dataSet.subsetColumns(z).getDoubleData());
            clusters = kmeans.getClusters();
        }

        int ycol = this.dataSet.getColumn(y);
        double[] nullapprox = new double[this.perms];
        for (int i = 0; i < this.perms; ++i) {
            DataSet shuffleData = this.dataSet.copy();
            if (clusters == null) {
                this.shuffleColumn(shuffleData, ycol, m);
            } else {
                this.shuffleColumnWithinClusters(shuffleData, ycol, clusters);
            }
            yKernel.setDefaultBw(shuffleData, y);
            Matrix Kyn = this.gramMatrix(Collections.singletonList(yKernel), shuffleData, Collections.singletonList(y));
            // Kx and Kz stay fixed: only y is moved relative to x (and to z).
            nullapprox[i] = this.hsicStatistic(Kyn, Kx, Kz, m);
        }

        int notAbove = 0;
        for (double stat : nullapprox) {
            if (stat <= hsic) {
                ++notAbove;
            }
        }
        this.pValue = 1.0 - (double) notAbove / this.perms;

        // A small p-value rejects the null hypothesis of independence.
        boolean independent = this.pValue > this.alpha;
        if (this.verbose && independent) {
            TetradLogger.getInstance().forceLogMessage(SearchLogUtils.independenceFactMsg(x, y, z, this.getPValue()));
        }
        return new IndependenceResult(new IndependenceFact(x, y, z), independent, this.pValue);
    }

    /** Builds either an incomplete-Cholesky factor or a centered Gram matrix, per the precision setting. */
    private Matrix gramMatrix(List<Kernel> kernels, DataSet data, List<Node> nodes) {
        if (this.useIncompleteCholesky > 0.0) {
            return KernelUtils.incompleteCholeskyGramMatrix(kernels, data, nodes, this.useIncompleteCholesky);
        }
        return KernelUtils.constructCentralizedGramMatrix(kernels, data, nodes);
    }

    /** Dispatches to the matching HSIC estimator; {@code Kz == null} means an unconditional test. */
    private double hsicStatistic(Matrix Ky, Matrix Kx, Matrix Kz, int m) {
        boolean cholesky = this.useIncompleteCholesky > 0.0;
        if (Kz == null) {
            return cholesky ? this.empiricalHSICincompleteCholesky(Ky, Kx, m) : this.empiricalHSIC(Ky, Kx, m);
        }
        return cholesky ? this.empiricalHSICincompleteCholesky(Ky, Kx, Kz, m) : this.empiricalHSIC(Ky, Kx, Kz, m);
    }

    /** Writes a uniformly random permutation of column {@code col} (rows 0..m-1) into {@code target}. */
    private void shuffleColumn(DataSet target, int col, int m) {
        List<Integer> order = new ArrayList<>(m);
        for (int j = 0; j < m; ++j) {
            order.add(j);
        }
        RandomUtil.shuffle(order);
        for (int j = 0; j < m; ++j) {
            target.setDouble(j, col, this.dataSet.getDouble(order.get(j), col));
        }
    }

    /** Permutes column {@code col} into {@code target} independently within each cluster of row indices. */
    private void shuffleColumnWithinClusters(DataSet target, int col, List<List<Integer>> clusters) {
        for (List<Integer> members : clusters) {
            List<Integer> destinations = new ArrayList<>(members);
            RandomUtil.shuffle(destinations);
            for (int k = 0; k < destinations.size(); ++k) {
                target.setDouble(destinations.get(k), col, this.dataSet.getDouble(members.get(k), col));
            }
        }
    }

    /**
     * Computes the empirical unconditional HSIC, trace(Ky Kx) / (m - 1)^2, from full Gram matrices.
     * Only the diagonal of the product is formed, which avoids an O(m^3) matrix multiplication.
     *
     * @param Ky Gram matrix of y (as produced by the centered Gram construction)
     * @param Kx Gram matrix of x
     * @param m  number of samples
     * @return the HSIC estimate
     */
    public double empiricalHSIC(Matrix Ky, Matrix Kx, int m) {
        double empHSIC = 0.0;
        for (int i = 0; i < m; ++i) {
            empHSIC += this.matrixProductEntry(Ky, Kx, i, i);
        }
        return empHSIC / FastMath.pow((double) (m - 1), 2);
    }

    /**
     * Computes the empirical unconditional HSIC from incomplete-Cholesky factors Gy and Gx, after
     * multiplying each by KernelUtils.constructH(m). The trace is evaluated without forming any
     * m x m product other than the centering step.
     *
     * @param Gy factor for y (m rows)
     * @param Gx factor for x (m rows)
     * @param m  number of samples
     * @return the HSIC estimate
     */
    public double empiricalHSICincompleteCholesky(Matrix Gy, Matrix Gx, int m) {
        Matrix H = KernelUtils.constructH(m);
        Matrix Gcy = H.times(Gy);
        Matrix Gcx = H.times(Gx);
        // trace(Gcy Gcy^T Gcx Gcx^T) = trace((Gcy (Gcy^T Gcx)) Gcx^T)
        Matrix B = Gcy.times(Gcy.transpose().times(Gcx));
        Matrix Gcxt = Gcx.transpose();
        double empHSIC = 0.0;
        for (int i = 0; i < m; ++i) {
            empHSIC += this.matrixProductEntry(B, Gcxt, i, i);
        }
        return empHSIC / FastMath.pow((double) (m - 1), 2);
    }

    /**
     * Computes the empirical conditional HSIC from full Gram matrices. With
     * R = (Kz + regularizer * I)^-2, the raw value is
     * trace(Ky Kx) - 2 trace(Ky Kz R Kz Kx) + trace(Ky Kz R Kz Kx Kz R Kz).
     * That value is divided by (m - 1)^2, then scaled by m(m - 1) divided by the sum of squared
     * off-diagonal entries of Kz.
     */
    private double empiricalHSIC(Matrix Ky, Matrix Kx, Matrix Kz, int m) {
        Matrix Kyz = Ky.times(Kz);
        Matrix Kzx = Kz.times(Kx);

        Matrix Kzreg = Kz.copy();
        for (int i = 0; i < m; ++i) {
            Kzreg.set(i, i, Kzreg.get(i, i) + this.regularizer);
        }
        Matrix inverse = Kzreg.inverse();
        Matrix R = inverse.times(inverse);

        // Ky Kz R Kz Kx
        Matrix middle = Kyz.times(R).times(Kzx);
        // Ky Kz R Kz Kx Kz R; its product with Kz is only needed on the diagonal.
        Matrix lastLeft = middle.times(Kz).times(R);

        double empHSIC = 0.0;
        for (int i = 0; i < m; ++i) {
            empHSIC += this.matrixProductEntry(Ky, Kx, i, i);
            empHSIC += -2.0 * middle.get(i, i);
            empHSIC += this.matrixProductEntry(lastLeft, Kz, i, i);
        }
        empHSIC /= FastMath.pow((double) (m - 1), 2);

        double Bz = 0.0;
        for (int i = 0; i < m - 1; ++i) {
            for (int j = i + 1; j < m; ++j) {
                Bz += FastMath.pow(Kz.get(i, j), 2);
                Bz += FastMath.pow(Kz.get(j, i), 2);
            }
        }
        // Compute m(m - 1) in double arithmetic to avoid int overflow for large m.
        return empHSIC * ((double) m * (m - 1) / Bz);
    }

    /**
     * Computes the empirical conditional HSIC from incomplete-Cholesky factors. The regularized
     * inverse (Gz Gz^T + regularizer * I)^-1 is evaluated through the Woodbury identity so that
     * only a kz x kz matrix is inverted (kz = number of columns of Gz).
     *
     * @param Gy factor for y (m rows)
     * @param Gx factor for x (m rows)
     * @param Gz factor for the conditioning set (m rows)
     * @param m  number of samples
     * @return the conditional HSIC estimate, scaled by m / ((m - 1) * betaz)
     */
    public double empiricalHSICincompleteCholesky(Matrix Gy, Matrix Gx, Matrix Gz, int m) {
        int kz = Gz.columns();
        Matrix H = KernelUtils.constructH(m);
        Matrix Gcy = H.times(Gy);
        Matrix Gcx = H.times(Gx);
        Matrix Gcz = H.times(Gz);
        Matrix Gcyt = Gcy.transpose();
        Matrix A = Gcyt.times(Gcx);
        Matrix B = Gcy.times(A);
        Matrix Gcxt = Gcx.transpose();
        // First term: trace(Ky Kx).
        double empHSIC = 0.0;
        for (int i = 0; i < m; ++i) {
            empHSIC += this.matrixProductEntry(B, Gcxt, i, i);
        }
        Matrix Gytz = Gcyt.times(Gcz);
        Matrix Gczt = Gcz.transpose();
        Matrix Gztx = Gczt.times(Gcx);
        Matrix Gztz = Gczt.times(Gcz);
        Matrix Gztzr = Gztz.copy();
        for (int i = 0; i < kz; ++i) {
            Gztzr.set(i, i, Gztz.get(i, i) + this.regularizer);
        }
        // Woodbury: (Gz Gz^T + r I)^-1 = (1/r) I - (1/r) Gz (Gz^T Gz + r I)^-1 Gz^T.
        Matrix ZI = Gztzr.inverse();
        Matrix ZIzt = ZI.times(Gczt);
        Matrix Gzr = Gcz.copy();
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < kz; ++j) {
                Gzr.set(i, j, Gcz.get(i, j) * (-1.0 / this.regularizer));
            }
        }
        Matrix Zinv = Gzr.times(ZIzt);
        for (int i = 0; i < m; ++i) {
            Zinv.set(i, i, Zinv.get(i, i) + 1.0 / this.regularizer);
        }
        // Gz^T Zinv^2 Gz (kz x kz).
        Matrix Gztzinv = Gczt.times(Zinv);
        Matrix Gzinvz = Zinv.times(Gcz);
        Matrix Gztinv2z = Gztzinv.times(Gzinvz);
        Matrix Gytzztzinv2z = Gytz.times(Gztinv2z);
        Matrix Gytzztzinv2zztx = Gytzztzinv2z.times(Gztx);
        Matrix Gyytzztzinv2zztx = Gcy.times(Gytzztzinv2zztx);
        // Second term: trace(Ky Kz R Kz Kx), subtracted twice.
        double second = 0.0;
        for (int i = 0; i < m; ++i) {
            second += this.matrixProductEntry(Gyytzztzinv2zztx, Gcxt, i, i);
        }
        empHSIC -= 2.0 * second;
        // Third term: trace(Ky Kz R Kz Kx Kz R Kz).
        Matrix Gxtz = Gcxt.times(Gcz);
        Matrix Gxtzztinv2z = Gxtz.times(Gztinv2z);
        Matrix Gyytzztzinv2zztxxtzztinv2z = Gyytzztzinv2zztx.times(Gxtzztinv2z);
        for (int i = 0; i < m; ++i) {
            empHSIC += this.matrixProductEntry(Gyytzztzinv2zztxxtzztinv2z, Gczt, i, i);
        }
        // Sum of squared off-diagonal entries of Gcz Gcz^T.
        double betaz = 0.0;
        for (int i = 0; i < m - 1; ++i) {
            for (int j = i + 1; j < m; ++j) {
                betaz += FastMath.pow(this.matrixProductEntry(Gcz, Gczt, i, j), 2);
                betaz += FastMath.pow(this.matrixProductEntry(Gcz, Gczt, j, i), 2);
            }
        }
        return empHSIC * ((double) m / (betaz * (double) (m - 1)));
    }

    /**
     * Returns the stored threshold. This class never computes one, so the value is always NaN.
     *
     * @return the threshold (NaN)
     */
    public double getThreshold() {
        return this.thresh;
    }

    /**
     * Returns the p-value of the most recent test.
     *
     * @return the p-value of the last {@link #checkIndependence} call, or NaN if none has run
     */
    public double getPValue() {
        return this.pValue;
    }

    /**
     * Sets the significance level.
     *
     * @param alpha the significance level, in [0, 1]
     * @throws IllegalArgumentException if alpha is outside [0, 1]
     */
    @Override
    public void setAlpha(double alpha) {
        if (alpha < 0.0 || alpha > 1.0) {
            throw new IllegalArgumentException("Significance out of range.");
        }
        this.alpha = alpha;
        this.thresh = Double.NaN;
    }

    /**
     * Sets the incomplete-Cholesky precision. Values > 0 use incomplete-Cholesky factors; other
     * values use full centered Gram matrices.
     *
     * @param precision the precision passed to the incomplete-Cholesky construction
     */
    public void setIncompleteCholesky(double precision) {
        this.useIncompleteCholesky = precision;
    }

    /**
     * Sets the number of permutations used to approximate the null distribution.
     *
     * @param perms the number of permutations; must be at least 1
     * @throws IllegalArgumentException if {@code perms < 1}
     */
    public void setPerms(int perms) {
        if (perms < 1) {
            throw new IllegalArgumentException("Number of permutations must be at least 1: " + perms);
        }
        this.perms = perms;
    }

    /**
     * Sets the ridge added to the conditioning Gram matrix before inversion. The incomplete-Cholesky
     * conditional path divides by this value, so it should be positive.
     *
     * @param regularizer the ridge value
     */
    public void setRegularizer(double regularizer) {
        this.regularizer = regularizer;
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    /**
     * Returns the incomplete-Cholesky precision.
     *
     * @return the precision; values > 0 mean incomplete Cholesky is used
     */
    public double getPrecision() {
        return this.useIncompleteCholesky;
    }

    /**
     * Returns the number of permutations.
     *
     * @return the number of permutations per test
     */
    public int getPerms() {
        return this.perms;
    }

    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /**
     * Looks up a variable by name.
     *
     * @param name the variable name
     * @return the first variable with that name, or null if there is none
     */
    @Override
    public Node getVariable(String name) {
        for (Node variable : this.getVariables()) {
            if (variable.getName().equals(name)) {
                return variable;
            }
        }
        return null;
    }

    @Override
    public List<String> getVariableNames() {
        List<Node> vars = this.getVariables();
        List<String> variableNames = new ArrayList<>(vars.size());
        for (Node variable : vars) {
            variableNames.add(variable.getName());
        }
        return variableNames;
    }

    @Override
    public DataSet getData() {
        return this.dataSet;
    }

    /**
     * This kernel test does not use a covariance matrix.
     *
     * @return null
     */
    @Override
    public ICovarianceMatrix getCov() {
        return null;
    }

    /**
     * Returns the single data set this test runs on.
     *
     * @return a one-element list containing the data set
     */
    @Override
    public List<DataSet> getDataSets() {
        return Collections.singletonList(this.dataSet);
    }

    /**
     * Returns the number of rows in the data set.
     *
     * @return the sample size
     */
    @Override
    public int getSampleSize() {
        return this.sampleSize();
    }

    /**
     * This kernel test does not use covariance matrices.
     *
     * @return null
     */
    @Override
    public List<Matrix> getCovMatrices() {
        return null;
    }

    /**
     * Returns the p-value of the most recent test as the score.
     *
     * @return the last p-value
     */
    @Override
    public double getScore() {
        return this.getPValue();
    }

    @Override
    public String toString() {
        return "HSIC, alpha = " + nf.format(this.getAlpha());
    }

    @Override
    public boolean determines(List<Node> z, Node x) throws UnsupportedOperationException {
        throw new UnsupportedOperationException("Method not implemented");
    }

    /** Number of rows in the data set. */
    private int sampleSize() {
        return this.dataSet.getNumRows();
    }

    /** Returns entry (i, j) of X * Y without forming the full product. */
    private double matrixProductEntry(Matrix X, Matrix Y, int i, int j) {
        double entry = 0.0;
        for (int k = 0; k < X.columns(); ++k) {
            entry += X.get(i, k) * Y.get(k, j);
        }
        return entry;
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }
}

