/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.test;

import edu.cmu.tetrad.data.DataModel;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.data.DataTransforms;
import edu.cmu.tetrad.data.ICovarianceMatrix;
import edu.cmu.tetrad.graph.IndependenceFact;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndependenceTest;
import edu.cmu.tetrad.search.test.IndependenceResult;
import edu.cmu.tetrad.search.utils.LogUtilsSearch;
import edu.cmu.tetrad.util.Matrix;
import edu.cmu.tetrad.util.StatUtils;
import edu.cmu.tetrad.util.TetradLogger;
import edu.cmu.tetrad.util.Vector;
import edu.pitt.csb.mgm.EigenDecomposition;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.math3.distribution.GammaDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.linear.SingularMatrixException;
import org.apache.commons.math3.random.SynchronizedRandomGenerator;
import org.apache.commons.math3.random.Well44497b;
import org.apache.commons.math3.util.FastMath;

/**
 * Kernel-based conditional independence (KCI) test for continuous data.
 *
 * <p>The data are standardized once at construction, and a Gaussian-kernel bandwidth is derived
 * for each column from its median absolute deviation. The kernel is {@code exp(-d / width)}, where
 * {@code d} is the squared Euclidean distance between rows.
 *
 * <p>For an unconditional query X _||_ Y the statistic is based on {@code trace(Kx * Ky)} of the
 * centered kernel matrices. For a conditional query X _||_ Y | Z, the kernels of (X, Z) and of Y
 * are first multiplied on both sides by {@code Rz = epsilon * (Kz + epsilon * I)^-1}.
 *
 * <p>The p-value comes from one of two sources:
 * <ul>
 *   <li>a moment-matched gamma approximation ({@link #setApproximate(boolean)}); or</li>
 *   <li>{@code numBootstraps} simulated draws of a weighted sum of chi-square(1) variables,
 *       weighted by the leading kernel eigenvalues. Eigenvalues at or below
 *       {@code threshold * max} are dropped.</li>
 * </ul>
 *
 * <p>Results are cached per {@link IndependenceFact}. Cached p-values are reused as-is. The
 * independent/dependent verdict is re-derived from the cached p-value at the current alpha, so
 * {@link #setAlpha(double)} also applies to facts tested earlier.
 * NOTE(review): changing the width multiplier, threshold, epsilon, bootstrap count or
 * approximation mode does not invalidate the cache.
 */
public class Kci implements IndependenceTest {
    // Standardized copy of the data passed to the constructor.
    private final DataSet data;
    // Variables of the data passed to the constructor, in column order.
    private final List<Node> variables;
    // Per-column bandwidth computed from the full standardized data (index = column index).
    private final double[] h;
    // Standard normal source with a fixed seed; squared draws give chi-square(1) samples.
    private final NormalDistribution normal = new NormalDistribution(new SynchronizedRandomGenerator(new Well44497b(193924L)), 0.0, 1.0);
    // Variable -> column index in this.data.
    private final Map<Node, Integer> hash;
    // Cache of computed results, keyed by fact.
    private final Map<IndependenceFact, IndependenceResult> facts = new ConcurrentHashMap<>();
    private double alpha;
    private boolean approximate;
    private double threshold = 0.01;
    private int numBootstraps = 5000;
    private double widthMultiplier = 1.0;
    private double epsilon = 0.001;
    private boolean verbose;

    /**
     * Creates the test.
     *
     * @param data  the data; it is standardized internally and the argument itself is not modified
     *              by this constructor
     * @param alpha the significance level; a fact is judged independent when {@code p > alpha}
     */
    public Kci(DataSet data, double alpha) {
        this.data = DataTransforms.standardizeData(data);
        this.variables = data.getVariables();
        this.hash = new HashMap<>();
        for (int i = 0; i < this.variables.size(); i++) {
            this.hash.put(this.variables.get(i), i);
        }

        // Bandwidths are computed once from the full standardized data and reused for every query.
        double[][] dataCols = this.data.getDoubleData().transpose().toArray();
        this.h = new double[this.variables.size()];
        for (int i = 0; i < this.data.getNumColumns(); i++) {
            this.h[i] = h(this.variables.get(i), dataCols, this.hash);
        }
        this.alpha = alpha;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public IndependenceTest indTestSubset(List<Node> vars) {
        throw new UnsupportedOperationException("Method not implemented.");
    }

    /**
     * Tests x _||_ y | z.
     *
     * <p>If the current thread is interrupted, this returns "independent" with NaN p-value and
     * score, and caches nothing.
     *
     * @return a result whose verdict is {@code p > alpha} and whose score is {@code alpha - p}
     * @throws RuntimeException if a singular matrix escapes the computation (with the original
     *                          {@link SingularMatrixException} as the cause)
     */
    @Override
    public IndependenceResult checkIndependence(Node x, Node y, Set<Node> z) {
        try {
            if (Thread.currentThread().isInterrupted()) {
                return new IndependenceResult(new IndependenceFact(x, y, z), true, Double.NaN, Double.NaN);
            }

            IndependenceFact fact = new IndependenceFact(x, y, z);
            IndependenceResult cached = this.facts.get(fact);
            if (cached != null) {
                // Re-judge at the current alpha so setAlpha() affects previously cached facts too.
                return report(fact, cached.getPValue());
            }

            List<Node> allVars = new ArrayList<>();
            allVars.add(x);
            allVars.add(y);
            allVars.addAll(z);

            int[] cols = new int[allVars.size()];
            for (int i = 0; i < allVars.size(); i++) {
                cols[i] = this.hash.get(allVars.get(i));
            }
            List<Integer> rows = getRows(this.data);
            int[] rowIndices = new int[rows.size()];
            for (int i = 0; i < rows.size(); i++) {
                rowIndices[i] = rows.get(i);
            }

            DataSet subset = this.data.subsetRowsColumns(rowIndices, cols);
            double[][] columns = subset.getDoubleData().transpose().toArray();

            // Maps each queried variable to its column in `columns` / its slot in `bandwidths`.
            Map<Node, Integer> localIndex = new HashMap<>();
            for (int i = 0; i < allVars.size(); i++) {
                localIndex.put(allVars.get(i), i);
            }

            int N = subset.getNumRows();
            Matrix I = Matrix.identity(N);
            Matrix H = centeringMatrix(N);
            double[] bandwidths = bandwidthsFor(allVars, N);

            IndependenceResult result = z.isEmpty()
                    ? isIndependentUnconditional(x, y, fact, columns, bandwidths, N, H, localIndex)
                    : isIndependentConditional(x, y, z, fact, columns, N, H, I, bandwidths, localIndex);
            return report(fact, result.getPValue());
        } catch (SingularMatrixException e) {
            throw new RuntimeException("Singularity encountered when testing " + LogUtilsSearch.independenceFact(x, y, z), e);
        }
    }

    /** Returns the variables of the data passed to the constructor. */
    @Override
    public List<Node> getVariables() {
        return this.variables;
    }

    /** Returns the variable with the given name from the standardized data. */
    @Override
    public Node getVariable(String name) {
        return this.data.getVariable(name);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    public boolean determines(List<Node> z, Node y) {
        throw new UnsupportedOperationException();
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    @Override
    public String toString() {
        return "KCI, alpha = " + new DecimalFormat("0.0###").format(this.getAlpha());
    }

    /** Returns the standardized data. */
    @Override
    public DataModel getData() {
        return this.data;
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public ICovarianceMatrix getCov() {
        throw new UnsupportedOperationException("Method not implemented.");
    }

    /** Returns a new single-element list containing the standardized data. */
    @Override
    public List<DataSet> getDataSets() {
        List<DataSet> dataSets = new LinkedList<>();
        dataSets.add(this.data);
        return dataSets;
    }

    @Override
    public int getSampleSize() {
        return this.data.getNumRows();
    }

    /** Returns {@code alpha - p} for the given result. */
    public double getScore(IndependenceResult result) {
        return this.getAlpha() - result.getPValue();
    }

    /** If true, p-values use the gamma approximation instead of simulated chi-square draws. */
    public void setApproximate(boolean approximate) {
        this.approximate = approximate;
    }

    /**
     * Sets the factor applied to every kernel bandwidth.
     *
     * @throws IllegalStateException if {@code widthMultiplier <= 0} (exception type kept for
     *                               compatibility with existing callers)
     */
    public void setWidthMultiplier(double widthMultiplier) {
        if (widthMultiplier <= 0.0) {
            throw new IllegalStateException("Width must be > 0");
        }
        this.widthMultiplier = widthMultiplier;
    }

    /**
     * Sets the number of simulated null draws.
     *
     * @throws IllegalArgumentException if {@code numBootstraps < 1}
     */
    public void setNumBootstraps(int numBootstraps) {
        if (numBootstraps < 1) {
            throw new IllegalArgumentException("Num bootstraps should be >= 1: " + numBootstraps);
        }
        this.numBootstraps = numBootstraps;
    }

    /**
     * Sets the relative eigenvalue cutoff. Eigenvalues at or below {@code threshold * max} are
     * dropped.
     *
     * @throws IllegalArgumentException if {@code threshold < 0}
     */
    public void setThreshold(double threshold) {
        if (threshold < 0.0) {
            throw new IllegalArgumentException("Threshold must be >= 0.0: " + threshold);
        }
        this.threshold = threshold;
    }

    /**
     * Sets the ridge regularizer used to residualize on the conditioning set.
     * NOTE(review): not validated; a non-positive value likely makes the conditional test
     * degenerate.
     */
    public void setEpsilon(double epsilon) {
        this.epsilon = epsilon;
    }

    @Override
    public boolean isVerbose() {
        return this.verbose;
    }

    @Override
    public void setVerbose(boolean verbose) {
        this.verbose = verbose;
    }

    /** Logs the verdict when verbose is on, and wraps p into a result judged at the current alpha. */
    private IndependenceResult report(IndependenceFact fact, double p) {
        boolean indep = p > this.getAlpha();
        if (this.verbose) {
            TetradLogger.getInstance().forceLogMessage(fact + (indep ? " INDEPENDENT p = " : " dependent p = ") + p);
        }
        return new IndependenceResult(fact, indep, p, this.getAlpha() - p);
    }

    /** Builds the result for p-value {@code p}, stores it in the cache and returns it. */
    private IndependenceResult cacheResult(IndependenceFact fact, double p) {
        boolean indep = p > this.getAlpha();
        IndependenceResult result = new IndependenceResult(fact, indep, p, this.getAlpha() - p);
        this.facts.put(fact, result);
        return result;
    }

    /** Logs a failed test, then caches and returns "dependent, p = 0" for it. */
    private IndependenceResult failure(IndependenceFact fact, Exception e) {
        TetradLogger.getInstance().forceLogMessage("KCI: exception while testing " + fact + "; treating as dependent: " + e);
        IndependenceResult result = new IndependenceResult(fact, false, 0.0, this.getAlpha());
        this.facts.put(fact, result);
        return result;
    }

    /** Returns the n x n centering matrix {@code I - (1/n) * 1 * 1'}. */
    private static Matrix centeringMatrix(int n) {
        Matrix ones = new Matrix(n, 1);
        for (int j = 0; j < n; j++) {
            ones.set(j, 0, 1.0);
        }
        return Matrix.identity(n).minus(ones.times(ones.transpose()).scalarMult(1.0 / (double) n));
    }

    /**
     * Looks up the precomputed bandwidth of each variable, in order. Any variable with a zero
     * bandwidth gets the mean of the non-zero ones. If all are zero, the previous code produced
     * 0/0 = NaN, so this falls back to the rule-of-thumb bandwidth for unit-scale data, since
     * the data are standardized.
     */
    private double[] bandwidthsFor(List<Node> vars, int n) {
        double[] bandwidths = new double[vars.size()];
        double sum = 0.0;
        int count = 0;
        for (int i = 0; i < vars.size(); i++) {
            bandwidths[i] = this.h[this.hash.get(vars.get(i))];
            if (bandwidths[i] != 0.0) {
                sum += bandwidths[i];
                count++;
            }
        }
        double fill = count > 0 ? sum / (double) count : FastMath.pow(1.3333333333333333 / (double) n, 0.2);
        for (int i = 0; i < bandwidths.length; i++) {
            if (bandwidths[i] == 0.0) {
                bandwidths[i] = fill;
            }
        }
        return bandwidths;
    }

    /**
     * Returns the upper-tail probability of {@code sta} under a gamma distribution whose mean
     * and variance match {@code mean} and {@code var}.
     */
    private static double gammaPValue(double sta, double mean, double var) {
        double shape = mean * mean / var;
        double scale = var / mean;
        return 1.0 - new GammaDistribution(shape, scale).cumulativeProbability(sta);
    }

    /**
     * Unconditional test x _||_ y using the centered kernels of x and y. Any exception thrown
     * while computing the p-value is logged, and the fact is cached as dependent with p = 0.
     */
    private IndependenceResult isIndependentUnconditional(Node x, Node y, IndependenceFact fact, double[][] _data, double[] _h, int N, Matrix H, Map<Node, Integer> index) {
        Matrix kx = center(kernelMatrix(_data, x, null, this.widthMultiplier, index, N, _h), H);
        Matrix ky = center(kernelMatrix(_data, y, null, this.widthMultiplier, index, N, _h), H);
        try {
            if (this.approximate) {
                double sta = kx.times(ky).trace();
                double mean = kx.trace() * ky.trace() / (double) N;
                double var = 2.0 * kx.times(kx).trace() * ky.times(ky).trace() / ((double) N * N);
                return cacheResult(fact, gammaPValue(sta, mean, var));
            }
            return theorem4(kx, ky, fact, N);
        } catch (Exception e) {
            return failure(fact, e);
        }
    }

    /**
     * Conditional test x _||_ y | z. The centered kernels of (x, z) and of y are sandwiched
     * between {@code Rz = epsilon * (Kz + epsilon * I)^-1} and its transpose, then symmetrized.
     * Any exception is logged, and the fact is cached as dependent with p = 0.
     */
    private IndependenceResult isIndependentConditional(Node x, Node y, Set<Node> _z, IndependenceFact fact, double[][] _data, int N, Matrix H, Matrix I, double[] _h, Map<Node, Integer> index) {
        List<Node> z = new ArrayList<>(_z);
        // Sort for a deterministic column order.
        Collections.sort(z);
        try {
            Matrix KXZ = center(kernelMatrix(_data, x, z, this.widthMultiplier, index, N, _h), H);
            Matrix Ky = center(kernelMatrix(_data, y, null, this.widthMultiplier, index, N, _h), H);
            Matrix KZ = center(kernelMatrix(_data, null, z, this.widthMultiplier, index, N, _h), H);
            Matrix Rz = KZ.plus(I.scalarMult(this.epsilon)).inverse().scalarMult(this.epsilon);
            Matrix kx = symmetrized(Rz.times(KXZ).times(Rz.transpose()));
            Matrix ky = symmetrized(Rz.times(Ky).times(Rz.transpose()));
            return proposition5(kx, ky, fact, N);
        } catch (Exception e) {
            return failure(fact, e);
        }
    }

    /**
     * Simulated p-value for the unconditional statistic {@code T = (1/N) trace(kx ky)}. The null
     * is {@code sum_ij lx_i * ly_j * chi2_ij / N^2} over the retained eigenvalues of kx and ky.
     * The result is cached.
     */
    private IndependenceResult theorem4(Matrix kx, Matrix ky, IndependenceFact fact, int N) {
        double T = 1.0 / (double) N * kx.times(ky).trace();
        List<Double> evx = new Eigendecomposition(kx).invoke().getTopEigenvalues();
        List<Double> evy = new Eigendecomposition(ky).invoke().getTopEigenvalues();
        double scale = (double) N * N;  // computed in double to avoid int overflow of N * N
        int exceed = 0;
        for (int b = 0; b < this.numBootstraps; b++) {
            double tui = 0.0;
            for (double lambdax : evx) {
                for (double lambday : evy) {
                    tui += lambdax * lambday * getChisqSample();
                }
            }
            tui /= scale;
            if (tui > T) {
                exceed++;
            }
        }
        return cacheResult(fact, (double) exceed / (double) this.numBootstraps);
    }

    /**
     * P-value for the residualized conditional statistic. It is built from the eigenvalues of
     * the matrix of columnwise products of {@code V_x D_x} and {@code V_y D_y}, and uses either
     * the gamma approximation or simulated weighted chi-square draws. The result is cached.
     */
    private IndependenceResult proposition5(Matrix kx, Matrix ky, IndependenceFact fact, int N) {
        double T = 1.0 / (double) N * kx.times(ky).trace();
        Eigendecomposition ex = new Eigendecomposition(kx).invoke();
        Eigendecomposition ey = new Eigendecomposition(ky).invoke();
        Matrix vx = ex.getV();
        Matrix vy = ey.getV();
        Matrix vdx = vx.times(ex.getD());
        Matrix vdy = vy.times(ey.getD());

        int nx = vx.getNumColumns();
        int ny = vy.getNumColumns();
        int prod = nx * ny;
        Matrix UU = new Matrix(N, prod);
        for (int i = 0; i < nx; i++) {
            for (int j = 0; j < ny; j++) {
                for (int k = 0; k < N; k++) {
                    UU.set(k, i * ny + j, vdx.get(k, i) * vdy.get(k, j));
                }
            }
        }

        // UU*UU' and UU'*UU share their non-zero eigenvalues; use whichever is smaller.
        Matrix uuprod = prod > N ? UU.times(UU.transpose()) : UU.transpose().times(UU);

        if (this.approximate) {
            double sta = kx.times(ky).trace();
            double mean = uuprod.trace();
            double var = 2.0 * uuprod.times(uuprod).trace();
            return cacheResult(fact, gammaPValue(sta, mean, var));
        }

        List<Double> eigenu = new Eigendecomposition(uuprod).invoke().getTopEigenvalues();
        int exceed = 0;
        for (int b = 0; b < this.numBootstraps; b++) {
            double s = 0.0;
            for (double lambdaStar : eigenu) {
                s += lambdaStar * getChisqSample();
            }
            s *= 1.0 / (double) N;
            if (s > T) {
                exceed++;
            }
        }
        return cacheResult(fact, (double) exceed / (double) this.numBootstraps);
    }

    /** Returns the list [0, 1, ..., size - 1]. */
    private List<Integer> series(int size) {
        List<Integer> series = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            series.add(i);
        }
        return series;
    }

    /** Returns {@code H K H}. */
    private Matrix center(Matrix K, Matrix H) {
        return H.times(K).times(H);
    }

    /** Returns one chi-square(1) draw: the square of a standard normal sample. */
    private double getChisqSample() {
        double z = this.normal.sample();
        return z * z;
    }

    /**
     * Returns the rule-of-thumb bandwidth {@code 1.4826 * MAD * (4 / (3n))^(1/5)} for column x.
     * This is zero whenever the column's median absolute deviation is zero.
     */
    private double h(Node x, double[][] _data, Map<Node, Integer> hash) {
        double[] xCol = _data[hash.get(x)];
        double median = StatUtils.median(xCol);
        double[] g = new double[xCol.length];
        for (int j = 0; j < xCol.length; j++) {
            g[j] = FastMath.abs(xCol[j] - median);
        }
        double mad = StatUtils.median(g);
        return 1.4826 * mad * FastMath.pow(1.3333333333333333 / (double) xCol.length, 0.2);
    }

    /**
     * Returns the indices, in input order, whose value is strictly greater than
     * {@code threshold * max}. The maximum is computed explicitly rather than assuming the
     * values are sorted; NaN values are ignored when finding it.
     */
    private List<Integer> getTopIndices(double[] prod, List<Integer> allIndices, double threshold) {
        double maxEig = Double.NEGATIVE_INFINITY;
        for (int i : allIndices) {
            if (prod[i] > maxEig) {
                maxEig = prod[i];
            }
        }
        List<Integer> indices = new ArrayList<>();
        for (int i : allIndices) {
            if (prod[i] > maxEig * threshold) {
                indices.add(i);
            }
        }
        return indices;
    }

    /** Returns {@code (K + K') / 2}. */
    private Matrix symmetrized(Matrix kx) {
        return kx.plus(kx.transpose()).scalarMult(0.5);
    }

    /**
     * Builds the N x N kernel matrix {@code exp(-d_ij / (widthMultiplier * h))} over the columns
     * of x (if non-null) and z (if non-null). Here {@code d_ij} is the squared distance between
     * rows i and j, and h comes from {@link #getH}.
     */
    private Matrix kernelMatrix(double[][] _data, Node x, List<Node> z, double widthMultiplier, Map<Node, Integer> hash, int N, double[] _h) {
        List<Integer> cols = new ArrayList<>();
        if (x != null) {
            cols.add(hash.get(x));
        }
        if (z != null) {
            for (Node node : z) {
                cols.add(hash.get(node));
            }
        }
        double width = widthMultiplier * getH(cols, _h);
        Matrix result = new Matrix(N, N);
        for (int i = 0; i < N; i++) {
            for (int j = i + 1; j < N; j++) {
                double k = kernelGaussian(distance(_data, cols, i, j), width);
                result.set(i, j, k);
                result.set(j, i, k);
            }
        }
        double diagonal = kernelGaussian(0.0, width);
        for (int i = 0; i < N; i++) {
            result.set(i, i, diagonal);
        }
        return result;
    }

    /** Returns the largest bandwidth among the given columns, times sqrt(number of columns). */
    private double getH(List<Integer> _z, double[] _h) {
        double h = 0.0;
        for (int c : _z) {
            if (_h[c] > h) {
                h = _h[c];
            }
        }
        return h * FastMath.sqrt(_z.size());
    }

    /** Returns {@code exp(-z / width)}. */
    private double kernelGaussian(double z, double width) {
        return FastMath.exp(-(z / width));
    }

    /** Returns the squared Euclidean distance between rows i and j over the given columns. */
    private double distance(double[][] data, List<Integer> cols, int i, int j) {
        double sum = 0.0;
        for (int col : cols) {
            double d = data[col][i] - data[col][j];
            sum += d * d;
        }
        return sum;
    }

    /** Returns all row indices of the data set, i.e. [0, numRows). */
    private List<Integer> getRows(DataSet dataSet) {
        return series(dataSet.getNumRows());
    }

    /**
     * Eigendecomposition of a kernel matrix, restricted to the eigenvalues strictly above
     * {@code threshold * max}.
     */
    private class Eigendecomposition {
        private final Matrix k;
        // Diagonal matrix of sqrt(retained eigenvalue).
        private Matrix D;
        // Retained eigenvectors as columns, in the same order as D.
        private Matrix V;
        private List<Double> topEigenvalues;

        /**
         * @throws IllegalArgumentException if k has no rows or no columns
         */
        public Eigendecomposition(Matrix k) {
            if (k.getNumRows() == 0 || k.getNumColumns() == 0) {
                throw new IllegalArgumentException("Empty matrix to decompose. Please don't do that to me.");
            }
            this.k = k;
        }

        public Matrix getD() {
            return this.D;
        }

        public Matrix getV() {
            return this.V;
        }

        public List<Double> getTopEigenvalues() {
            return this.topEigenvalues;
        }

        /** Runs the decomposition, fills D, V and topEigenvalues, and returns this object. */
        public Eigendecomposition invoke() {
            EigenDecomposition ed = new EigenDecomposition(new BlockRealMatrix(this.k.toArray()));
            double[] eigenvalues = ed.getRealEigenvalues();
            List<Integer> topIndices = Kci.this.getTopIndices(eigenvalues, Kci.this.series(eigenvalues.length), Kci.this.threshold);
            int m = topIndices.size();

            this.D = new Matrix(m, m);
            this.topEigenvalues = new ArrayList<>(m);
            for (int i = 0; i < m; i++) {
                double lambda = eigenvalues[topIndices.get(i)];
                this.D.set(i, i, FastMath.sqrt(lambda));
                this.topEigenvalues.add(lambda);
            }

            this.V = new Matrix(ed.getEigenvector(0).getDimension(), m);
            for (int i = 0; i < m; i++) {
                RealVector vector = ed.getEigenvector(topIndices.get(i));
                this.V.assignColumn(i, new Vector(vector.toArray()));
            }
            return this;
        }
    }
}

