/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search;

import edu.cmu.tetrad.data.ColtDataSet;
import edu.cmu.tetrad.data.ContinuousVariable;
import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Dag;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.IndTestFisherZ;
import edu.cmu.tetrad.search.IndTestHsic;
import edu.cmu.tetrad.search.kernel.Kernel;
import edu.cmu.tetrad.search.kernel.KernelGaussian;
import edu.cmu.tetrad.search.kernel.KernelUtils;
import edu.cmu.tetrad.sem.SemIm;
import edu.cmu.tetrad.sem.SemPm;
import edu.cmu.tetrad.util.TetradLogger;
import java.util.ArrayList;
import java.util.Arrays;
import junit.framework.TestCase;
import no.uib.cipr.matrix.DenseMatrix;
import no.uib.cipr.matrix.Matrices;
import no.uib.cipr.matrix.Matrix;
import no.uib.cipr.matrix.UpperSymmPackMatrix;

/**
 * Exploratory JUnit 3 tests for {@link IndTestHsic} and the kernel helpers in
 * {@link KernelUtils}.
 *
 * <p>Every test prints the quantities it computes to {@code System.out}; none of them
 * asserts on the values, so a test only fails when one of the calls throws.
 * NOTE(review): consider adding tolerance-based assertions once the expected accuracy
 * of the incomplete Cholesky approximation is agreed on.
 */
public class TestIndTestHsic extends TestCase {

    public TestIndTestHsic(String name) {
        super(name);
    }

    /**
     * Routes the Tetrad logger to {@code System.out} and switches logging on for the
     * duration of a test.
     */
    @Override
    public void setUp() throws Exception {
        TetradLogger.getInstance().addOutputStream(System.out);
        TetradLogger.getInstance().setForceLog(true);
        TetradLogger.getInstance().setLogging(true);
    }

    /**
     * Turns forced logging off again and detaches {@code System.out} from the logger.
     *
     * <p>NOTE(review): the {@code setLogging(true)} done in {@link #setUp()} is not
     * reverted here - confirm whether that is intentional.
     */
    @Override
    public void tearDown() {
        TetradLogger.getInstance().setForceLog(false);
        TetradLogger.getInstance().removeOutputStream(System.out);
    }

    /**
     * Prints, side by side, traces computed from the full Gram matrices ("true") and
     * from the products G * G^T of the incomplete Cholesky factors ("appr") for data
     * simulated from the DAG A -> B <- C.
     *
     * <p>The comparison covers the per-variable Gram matrices, the product of the X and
     * Y matrices, the product of their doubly centred versions, and the inverse of the
     * Z matrix after adding a small constant to its diagonal.
     */
    public void testIncompleteCholesky() {
        int m = 500;
        double precision = 1.0E-8;
        ContinuousVariable A = new ContinuousVariable("A");
        ContinuousVariable B = new ContinuousVariable("B");
        ContinuousVariable C = new ContinuousVariable("C");
        DataSet data = TestIndTestHsic.simulateCollider(A, B, C, m);

        // One Gaussian kernel per variable, with the bandwidth chosen from the data.
        ArrayList<Kernel> kernels = new ArrayList<Kernel>();
        for (int i = 0; i < 3; ++i) {
            KernelGaussian k = new KernelGaussian(1.0);
            k.setDefaultBw(data, data.getVariable(i));
            kernels.add(k);
        }

        // Variable 0 ("X"): full Gram matrix versus G * G^T.
        UpperSymmPackMatrix Kx = KernelUtils.constructGramMatrix(kernels, data, Arrays.asList(data.getVariable(0)));
        Matrix Gx = KernelUtils.incompleteCholeskyGramMatrix(kernels, data, Arrays.asList(data.getVariable(0)), precision);
        DenseMatrix GGx = TestIndTestHsic.timesOwnTranspose(Gx, m);
        System.out.println("X true: " + TestIndTestHsic.trace(Kx, m));
        System.out.println("X appr: " + TestIndTestHsic.trace(GGx, m));

        // Variable 1 ("Y").
        UpperSymmPackMatrix Ky = KernelUtils.constructGramMatrix(kernels, data, Arrays.asList(data.getVariable(1)));
        Matrix Gy = KernelUtils.incompleteCholeskyGramMatrix(kernels, data, Arrays.asList(data.getVariable(1)), precision);
        DenseMatrix GGy = TestIndTestHsic.timesOwnTranspose(Gy, m);
        System.out.println("Y true: " + TestIndTestHsic.trace(Ky, m));
        System.out.println("Y appr: " + TestIndTestHsic.trace(GGy, m));

        // Product of the X and Y matrices, exact and approximated.
        DenseMatrix KxKy = TestIndTestHsic.multiply(Kx, Ky, m);
        DenseMatrix GxGy = TestIndTestHsic.multiply(GGx, GGy, m);
        System.out.println("XY true: " + TestIndTestHsic.trace(KxKy, m));
        System.out.println("XY appr: " + TestIndTestHsic.trace(GxGy, m));

        // Same comparison after multiplying by H on both sides.
        UpperSymmPackMatrix H = KernelUtils.constructH(m);
        DenseMatrix HKxH = TestIndTestHsic.sandwich(H, Kx, m);
        DenseMatrix HKyH = TestIndTestHsic.sandwich(H, Ky, m);
        DenseMatrix HGxH = TestIndTestHsic.sandwichFromFactor(H, Gx, m);
        DenseMatrix HGyH = TestIndTestHsic.sandwichFromFactor(H, Gy, m);
        DenseMatrix HKxHHKyH = TestIndTestHsic.multiply(HKxH, HKyH, m);
        DenseMatrix HGxHHGyH = TestIndTestHsic.multiply(HGxH, HGyH, m);
        System.out.println("HXHHYH True: " + TestIndTestHsic.trace(HKxHHKyH, m));
        System.out.println("HXHHYH Appr: " + TestIndTestHsic.trace(HGxHHGyH, m));

        // Variable 2 ("Z").
        UpperSymmPackMatrix Kz = KernelUtils.constructGramMatrix(kernels, data, Arrays.asList(data.getVariable(2)));
        Matrix Gz = KernelUtils.incompleteCholeskyGramMatrix(kernels, data, Arrays.asList(data.getVariable(2)), precision);
        DenseMatrix GGz = TestIndTestHsic.timesOwnTranspose(Gz, m);
        System.out.println("Z true: " + TestIndTestHsic.trace(Kz, m));
        System.out.println("Z appr: " + TestIndTestHsic.trace(GGz, m));

        // Inverses of the Z matrices after adding ep to every diagonal entry.
        double ep = 1.0E-4;
        DenseMatrix Kzinv = TestIndTestHsic.shiftedInverse(Kz, ep, m);
        System.out.println("Z-1 true: " + TestIndTestHsic.trace(Kzinv, m));
        DenseMatrix GGzinv = TestIndTestHsic.shiftedInverse(GGz, ep, m);
        System.out.println("Z-1 appr: " + TestIndTestHsic.trace(GGzinv, m));
    }

    /**
     * Runs the unconditional HSIC test and the Fisher Z test on a fixed 160-row data
     * set with columns X and Y and prints both p-values.
     */
    public void testUnconditionalTest1() {
        ContinuousVariable X = new ContinuousVariable("X");
        ContinuousVariable Y = new ContinuousVariable("Y");
        int m = 160;
        double[] dataX = new double[]{9.6097, 9.5734, 9.6119, 9.65, 9.6467, 9.5675, 9.6127, 9.5662, 9.5577, 9.5815, 9.5905, 9.6488, 9.7119, 9.7658, 9.7814, 9.8079, 9.745, 9.747, 9.7644, 9.7007, 9.6996, 9.653, 9.6794, 9.6753, 9.6472, 9.6502, 9.6724, 9.6427, 9.5945, 9.5925, 9.6271, 9.6066, 9.7105, 9.7205, 9.7491, 9.7005, 9.7174, 9.7699, 9.7305, 9.758, 9.6953, 9.6892, 9.6939, 9.6392, 9.6838, 9.6438, 9.6749, 9.6771, 9.7014, 9.6352, 9.5624, 9.5488, 9.6184, 9.6343, 9.6572, 9.66, 9.6483, 9.6785, 9.7317, 9.7123, 9.7244, 9.7259, 9.6708, 9.589, 9.5632, 9.6279, 9.5785, 9.595, 9.5867, 9.5896, 9.5637, 9.5849, 9.6503, 9.7249, 9.6835, 9.6981, 9.7252, 9.6965, 9.7029, 9.687, 9.6454, 9.6177, 9.6176, 9.5261, 9.4988, 9.5184, 9.519, 9.5164, 9.5755, 9.6533, 9.6525, 9.5359, 9.5871, 9.6765, 9.7384, 9.7407, 9.7726, 9.7488, 9.744, 9.7602, 9.7204, 9.7315, 9.7178, 9.6471, 9.6016, 9.5629, 9.6143, 9.5382, 9.5847, 9.5789, 9.5852, 9.6261, 9.7179, 9.7832, 9.7592, 9.7289, 9.7363, 9.7234, 9.7751, 9.701, 9.7126, 9.7155, 9.6535, 9.6309, 9.5462, 9.5149, 9.5107, 9.6194, 9.5682, 9.539, 9.4894, 9.5402, 9.5876, 9.6345, 9.6908, 9.6962, 9.6318, 9.728, 9.7331, 9.6725, 9.6949, 9.7017, 9.683, 9.6156, 9.6932, 9.645, 9.7147, 9.6494, 9.5806, 9.5469, 9.556, 9.5647, 9.5961, 9.6443, 9.6466, 9.6939, 9.7393, 9.7166, 9.7222, 9.6618};
        double[] dataY = new double[]{-0.1154, -0.1328, -0.1498, -0.1663, -0.1821, -0.1972, -0.2115, -0.225, -0.2377, -0.2497, -0.2601, -0.1686, 0.1334, 0.4161, 0.5641, 0.6074, 0.5952, 0.5617, 0.5267, 0.4981, 0.476, 0.3586, 0.0369, -0.2604, -0.4208, -0.4765, -0.477, -0.4545, -0.4282, -0.4074, -0.3938, -0.2855, 0.0284, 0.3201, 0.4755, 0.5252, 0.5191, 0.4912, 0.4615, 0.4378, 0.4203, 0.3073, -0.0103, -0.3038, -0.4605, -0.5128, -0.5101, -0.4846, -0.4556, -0.4323, -0.4164, -0.3059, 0.0099, 0.3033, 0.4604, 0.5116, 0.5068, 0.4801, 0.4515, 0.4289, 0.4123, 0.3001, -0.0168, -0.3096, -0.4658, -0.5175, -0.5143, -0.4884, -0.459, -0.4353, -0.4191, -0.3084, 0.0078, 0.3015, 0.4588, 0.5104, 0.5058, 0.4794, 0.451, 0.4285, 0.4121, 0.3, -0.0167, -0.3094, -0.4656, -0.5174, -0.5142, -0.4883, -0.4589, -0.4353, -0.4191, -0.3084, 0.0078, 0.3015, 0.4588, 0.5104, 0.5058, 0.4794, 0.451, 0.4285, 0.4122, 0.3002, -0.0165, -0.3093, -0.4655, -0.5174, -0.5143, -0.4886, -0.4594, -0.4359, -0.4199, -0.3094, 0.0066, 0.3002, 0.4574, 0.5089, 0.5042, 0.4776, 0.4491, 0.4264, 0.4098, 0.2974, -0.0197, -0.3128, -0.4695, -0.5219, -0.5194, -0.4943, -0.4657, -0.443, -0.4278, -0.3182, -0.0033, 0.2891, 0.4449, 0.4948, 0.4885, 0.46, 0.4294, 0.4046, 0.3855, 0.2705, -0.0494, -0.3455, -0.5054, -0.5611, -0.5622, -0.5411, -0.5169, -0.4989, -0.4888, -0.3846, -0.0757, 0.21, 0.3585, 0.4003, 0.3852, 0.3475, 0.307, 0.2716};

        // Column 0 holds X, column 1 holds Y.
        ColtDataSet dataSet = new ColtDataSet(m, Arrays.asList(X, Y));
        for (int i = 0; i < m; ++i) {
            dataSet.setDouble(i, 0, dataX[i]);
            dataSet.setDouble(i, 1, dataY[i]);
        }

        IndTestHsic test = new IndTestHsic(dataSet, 0.05);
        test.isIndependent((Node)X, (Node)Y, new ArrayList<Node>());
        System.out.println("HSIC P-value: " + test.getPValue());

        IndTestFisherZ test2 = new IndTestFisherZ(dataSet, 0.05);
        test2.isIndependent((Node)X, (Node)Y, new ArrayList<Node>());
        System.out.println("Fisher Z P-value: " + test2.getPValue());
    }

    /**
     * Simulates 500 rows from the DAG A -> B <- C, tests A against C given B with both
     * HSIC (100 permutations, incomplete Cholesky enabled) and Fisher Z, and prints the
     * two p-values.
     */
    public void testConditionalTest1() {
        // NOTE(review): 1.0E-18 is below double machine epsilon (about 2.2E-16);
        // confirm this precision is what the incomplete Cholesky routine expects.
        double precision = 1.0E-18;
        ContinuousVariable A = new ContinuousVariable("A");
        ContinuousVariable B = new ContinuousVariable("B");
        ContinuousVariable C = new ContinuousVariable("C");
        DataSet data = TestIndTestHsic.simulateCollider(A, B, C, 500);

        IndTestHsic test = new IndTestHsic(data, 0.05);
        test.setPerms(100);
        test.setIncompleteCholesky(precision);
        test.isIndependent((Node)A, (Node)C, Arrays.asList(B));
        System.out.println("HSIC P-value: " + test.getPValue());

        IndTestFisherZ test2 = new IndTestFisherZ(data, 0.05);
        test2.isIndependent((Node)A, (Node)C, Arrays.asList(B));
        System.out.println("Fisher Z P-value: " + test2.getPValue());
    }

    /**
     * Builds the DAG {@code parent1 -> child <- parent2} over the node list
     * (parent1, child, parent2) and simulates data from a SEM instantiated on it.
     *
     * @param parent1 first parent of {@code child}
     * @param child the common child
     * @param parent2 second parent of {@code child}
     * @param sampleSize number of rows requested from {@code simulateData}
     * @return the simulated data set ({@code simulateData} is called with {@code false}
     *         as its second argument)
     */
    private static DataSet simulateCollider(ContinuousVariable parent1, ContinuousVariable child, ContinuousVariable parent2, int sampleSize) {
        Dag dag = new Dag(Arrays.asList(parent1, child, parent2));
        dag.addDirectedEdge(parent1, child);
        dag.addDirectedEdge(parent2, child);
        SemPm sem = new SemPm(dag);
        SemIm im = new SemIm(sem);
        return im.simulateData(sampleSize, false);
    }

    /**
     * Computes {@code factor * factor^T} for a factor with {@code m} rows.
     *
     * @param factor matrix with {@code m} rows and any number of columns
     * @param m number of rows of {@code factor}; the result is {@code m x m}
     * @return a new dense matrix holding the product
     */
    private static DenseMatrix timesOwnTranspose(Matrix factor, int m) {
        DenseMatrix transposed = new DenseMatrix(factor.numColumns(), m);
        factor.transpose(transposed);
        DenseMatrix product = new DenseMatrix(m, m);
        factor.mult(transposed, product);
        return product;
    }

    /**
     * Multiplies two matrices into a freshly allocated {@code m x m} dense matrix.
     *
     * @param left left operand
     * @param right right operand
     * @param m row and column count of the result
     * @return {@code left * right}
     */
    private static DenseMatrix multiply(Matrix left, Matrix right, int m) {
        DenseMatrix product = new DenseMatrix(m, m);
        left.mult(right, product);
        return product;
    }

    /**
     * Computes {@code H * K * H} for {@code m x m} matrices.
     *
     * @param H matrix applied on both sides
     * @param K matrix in the middle
     * @param m dimension of the matrices
     * @return a new dense matrix holding {@code H * K * H}
     */
    private static DenseMatrix sandwich(Matrix H, Matrix K, int m) {
        DenseMatrix HK = TestIndTestHsic.multiply(H, K, m);
        return TestIndTestHsic.multiply(HK, H, m);
    }

    /**
     * Computes {@code (H * G) * (H * G)^T} without forming {@code G * G^T} first.
     *
     * @param H {@code m x m} matrix applied to the factor
     * @param G factor with {@code m} rows
     * @param m number of rows of {@code G}
     * @return a new {@code m x m} dense matrix holding the product
     */
    private static DenseMatrix sandwichFromFactor(Matrix H, Matrix G, int m) {
        DenseMatrix HG = new DenseMatrix(m, G.numColumns());
        H.mult(G, HG);
        return TestIndTestHsic.timesOwnTranspose(HG, m);
    }

    /**
     * Copies {@code K}, adds {@code ep} to each of its first {@code m} diagonal entries
     * and solves the shifted copy against the identity matrix.
     *
     * @param K square matrix to invert; it is not modified
     * @param ep constant added to the diagonal of the copy
     * @param m dimension of {@code K}
     * @return the solution of {@code (K + ep * I) * X = I}
     */
    private static DenseMatrix shiftedInverse(Matrix K, double ep, int m) {
        Matrix shifted = K.copy();
        for (int i = 0; i < m; ++i) {
            shifted.set(i, i, shifted.get(i, i) + ep);
        }
        DenseMatrix inverse = new DenseMatrix(m, m);
        shifted.solve(Matrices.identity(m), inverse);
        return inverse;
    }

    /**
     * Sums the first {@code m} diagonal entries of {@code A}.
     *
     * @param A matrix to read
     * @param m number of diagonal entries to add up
     * @return the sum of {@code A.get(i, i)} for {@code i} in {@code [0, m)}
     */
    private static double trace(Matrix A, int m) {
        double trace = 0.0;
        for (int i = 0; i < m; ++i) {
            trace += A.get(i, i);
        }
        return trace;
    }
}

