/*
 * Decompiled with CFR 0.152.
 */
package edu.cmu.tetrad.search.kernel;

import edu.cmu.tetrad.data.DataSet;
import edu.cmu.tetrad.graph.Node;
import edu.cmu.tetrad.search.kernel.Kernel;
import edu.cmu.tetrad.util.Matrix;
import java.util.List;
import org.apache.commons.math3.util.FastMath;

/**
 * Static helpers for building kernel (Gram) matrices over the rows of a {@link DataSet}.
 *
 * <p>When several nodes are supplied, the kernel value for a pair of rows is the product of the
 * per-node kernel values, where {@code kernels.get(k)} is applied to the column of
 * {@code nodes.get(k)}.
 */
public class KernelUtils {

    /**
     * Slack used when picking the next pivot in the incomplete Cholesky factorization: a later
     * candidate replaces the current best only if it exceeds {@code best / PIVOT_SLACK}.
     */
    private static final double PIVOT_SLACK = 0.99;

    /**
     * Builds the full, symmetric Gram matrix for the given nodes.
     *
     * <p>Entry {@code (i, j)} is the product over all nodes of
     * {@code kernels.get(k).eval(x_i, x_j)}, where {@code x_i} and {@code x_j} are the values of
     * rows {@code i} and {@code j} in the column of {@code nodes.get(k)}. Each kernel is evaluated
     * once per unordered pair and the value is mirrored, so the kernels are assumed to be
     * symmetric in their two arguments.
     *
     * @param kernels one kernel per node, in the same order as {@code nodes}
     * @param dataset the data whose rows are compared
     * @param nodes   the variables (columns of {@code dataset}) to include
     * @return an {@code m x m} matrix, {@code m = dataset.getNumRows()}; all zeros if
     *         {@code nodes} is empty
     */
    public static Matrix constructGramMatrix(List<Kernel> kernels, DataSet dataset, List<Node> nodes) {
        int m = dataset.getNumRows();
        Matrix gram = new Matrix(m, m);
        for (int k = 0; k < nodes.size(); ++k) {
            int col = dataset.getColumn(nodes.get(k));
            Kernel kernel = kernels.get(k);
            for (int i = 0; i < m; ++i) {
                double xi = dataset.getDouble(i, col);
                for (int j = i; j < m; ++j) {
                    double keval = kernel.eval(xi, dataset.getDouble(j, col));
                    if (k != 0) {
                        // Later nodes multiply into the product accumulated so far.
                        keval *= gram.get(i, j);
                    }
                    gram.set(i, j, keval);
                    if (j != i) {
                        // Mirror into the lower triangle so the result is a complete matrix.
                        gram.set(j, i, keval);
                    }
                }
            }
        }
        return gram;
    }

    /**
     * Builds the centred Gram matrix {@code H * K * H}, where {@code K} is the result of
     * {@link #constructGramMatrix} and {@code H} is the result of {@link #constructH}.
     *
     * @param kernels one kernel per node, in the same order as {@code nodes}
     * @param dataset the data whose rows are compared
     * @param nodes   the variables (columns of {@code dataset}) to include
     * @return the {@code m x m} centred Gram matrix, {@code m = dataset.getNumRows()}
     */
    public static Matrix constructCentralizedGramMatrix(List<Kernel> kernels, DataSet dataset, List<Node> nodes) {
        int m = dataset.getNumRows();
        Matrix gram = KernelUtils.constructGramMatrix(kernels, dataset, nodes);
        Matrix H = KernelUtils.constructH(m);
        return H.times(gram.times(H));
    }

    /**
     * Builds the {@code m x m} centering matrix {@code H = I - (1/m) * 1 * 1^T}: every diagonal
     * entry is {@code 1 - 1/m} and every off-diagonal entry is {@code -1/m}.
     *
     * @param m the matrix dimension
     * @return the centering matrix
     */
    public static Matrix constructH(int m) {
        Matrix H = new Matrix(m, m);
        double offDiagonal = -1.0 / (double) m;
        double diagonal = offDiagonal + 1.0;
        for (int i = 0; i < m; ++i) {
            H.set(i, i, diagonal);
            for (int j = i + 1; j < m; ++j) {
                // H is symmetric; fill both triangles.
                H.set(i, j, offDiagonal);
                H.set(j, i, offDiagonal);
            }
        }
        return H;
    }

    /**
     * Computes a pivoted incomplete Cholesky factor of the Gram matrix without materialising the
     * Gram matrix itself.
     *
     * <p>At each step the row with the largest remaining diagonal residual (subject to the
     * {@link #PIVOT_SLACK} preference for the current row) is swapped into position {@code k} and
     * one more column of the factor is computed. The factorization stops as soon as the selected
     * residual falls below {@code precision}; only the columns completed up to that point are
     * returned.
     *
     * <p>Rows of the returned factor are in pivot order, not in the original row order of
     * {@code dataset}; the permutation is not returned.
     * NOTE(review): callers that need the original row order cannot recover it from the result;
     * confirm this is what downstream code expects.
     *
     * @param kernels   one kernel per node, in the same order as {@code nodes}
     * @param dataset   the data whose rows are compared
     * @param nodes     the variables (columns of {@code dataset}) to include; must be non-empty
     * @param precision stopping threshold on the pivot residual; must be {@code > 0}
     * @return an {@code m x cols} matrix holding the first {@code cols} columns of the factor,
     *         where {@code cols} is the number of columns completed before stopping
     * @throws IllegalArgumentException if {@code precision <= 0}
     */
    public static Matrix incompleteCholeskyGramMatrix(List<Kernel> kernels, DataSet dataset, List<Node> nodes, double precision) {
        if (precision <= 0.0) {
            throw new IllegalArgumentException("Precision must be > 0");
        }
        int m = dataset.getNumRows();
        Matrix G = new Matrix(m, m);
        // Remaining diagonal residuals, kept in pivot order.
        double[] Dadv = new double[m];
        // p[r] is the original data row currently sitting at (pivoted) position r.
        int[] p = new int[m];
        for (int i = 0; i < m; ++i) {
            Dadv[i] = KernelUtils.evaluate(kernels, dataset, nodes, i, i);
            p[i] = i;
        }
        int cols = m;
        for (int k = 0; k < m; ++k) {
            // Choose the pivot: keep row k unless a later row is clearly larger.
            double best = Dadv[k];
            int bestInd = k;
            for (int j = k + 1; j < m; ++j) {
                if (Dadv[j] > best / PIVOT_SLACK) {
                    best = Dadv[j];
                    bestInd = j;
                }
            }
            if (best < precision) {
                // Columns 0..k-1 are complete at this point, i.e. exactly k columns.
                cols = k;
                break;
            }

            // Swap the pivot into position k: permutation, residuals and computed factor rows.
            int pk = p[k];
            p[k] = p[bestInd];
            p[bestInd] = pk;
            double dk = Dadv[k];
            Dadv[k] = Dadv[bestInd];
            Dadv[bestInd] = dk;
            for (int c = 0; c < k; ++c) {
                double gk = G.get(k, c);
                G.set(k, c, G.get(bestInd, c));
                G.set(bestInd, c, gk);
            }

            double diag = FastMath.sqrt(Dadv[k]);
            G.set(k, k, diag);
            for (int j = k + 1; j < m; ++j) {
                double s = 0.0;
                for (int i = 0; i < k; ++i) {
                    s += G.get(j, i) * G.get(k, i);
                }
                double g = (KernelUtils.evaluate(kernels, dataset, nodes, p[j], p[k]) - s) / diag;
                G.set(j, k, g);
                // Remove the part of row j's residual explained by the new column.
                Dadv[j] -= g * g;
            }
            Dadv[k] = 0.0;
        }

        // Keep only the columns that were actually computed.
        Matrix Gm = new Matrix(m, cols);
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < cols; ++j) {
                Gm.set(i, j, G.get(i, j));
            }
        }
        return Gm;
    }

    /**
     * Evaluates the product kernel for data rows {@code i} and {@code j}: the product over all
     * {@code vars} of {@code kernels.get(k).eval(...)} applied to the two rows' values in the
     * column of {@code vars.get(k)}.
     *
     * @param kernels one kernel per variable, in the same order as {@code vars}
     * @param dataset the data to read values from
     * @param vars    the variables to include; must contain at least one element
     * @param i       first row index
     * @param j       second row index
     * @return the product of the per-variable kernel values
     */
    private static double evaluate(List<Kernel> kernels, DataSet dataset, List<Node> vars, int i, int j) {
        int col = dataset.getColumn(vars.get(0));
        double keval = kernels.get(0).eval(dataset.getDouble(i, col), dataset.getDouble(j, col));
        for (int k = 1; k < vars.size(); ++k) {
            col = dataset.getColumn(vars.get(k));
            keval *= kernels.get(k).eval(dataset.getDouble(i, col), dataset.getDouble(j, col));
        }
        return keval;
    }
}

