/*
 * Decompiled with CFR 0.152.
 */
package jgpml;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import jgpml.CSVtoMatrix;
import jgpml.covariancefunctions.CovLINone;
import jgpml.covariancefunctions.CovNoise;
import jgpml.covariancefunctions.CovSum;
import jgpml.covariancefunctions.CovarianceFunction;
import org.apache.commons.math3.util.FastMath;

/**
 * Gaussian process regression whose hyperparameters are fitted by minimizing the
 * negative log marginal likelihood of the training data.
 *
 * <p>Hyperparameters are handled in log space ({@code logtheta}) and interpreted by the
 * supplied {@link CovarianceFunction}. {@code train} runs a conjugate-gradient minimizer
 * whose constants and update rules mirror Carl Rasmussen's {@code minimize.m}, and then
 * caches the Cholesky factor {@link #L} and weights {@link #alpha} that
 * {@link #predict} and {@link #predictMean} reuse.
 *
 * <p>Instances hold mutable training state; the class performs no synchronization.
 */
public class GaussianProcess {
    /** Log-space hyperparameters chosen by the last call to {@code train}. */
    public Matrix logtheta;
    /** Training inputs; each row is one sample (the row count sizes the covariance matrix). */
    public Matrix X;
    /** Lower-triangular Cholesky factor of the training covariance K (from the last factorization). */
    public Matrix L;
    /** Weights {@code K^-1 * y} (from the last factorization); the predictive mean is {@code Kstar' * alpha}. */
    public Matrix alpha;
    /** Covariance function used to build K, its derivatives and test covariances. */
    CovarianceFunction covFunction;

    /** Do not re-evaluate within INT of the limits of the current bracket. */
    private static final double INT = 0.1;
    /** Extrapolate by at most EXT times the current step. */
    private static final double EXT = 3.0;
    /** Maximum number of likelihood evaluations per line search. */
    private static final int MAX = 20;
    /** Upper bound on the step-size growth carried into the next line search. */
    private static final double RATIO = 10.0;
    /** Slope (curvature) tolerance of the Wolfe-Powell line-search conditions. */
    private static final double SIG = 0.1;
    /** Sufficient-decrease tolerance of the Wolfe-Powell conditions; must be smaller than SIG. */
    private static final double RHO = 0.05;

    /**
     * Creates an untrained Gaussian process.
     *
     * @param covFunction covariance function used for training and prediction
     */
    public GaussianProcess(CovarianceFunction covFunction) {
        this.covFunction = covFunction;
    }

    /**
     * Trains with a budget of 100 likelihood evaluations.
     *
     * @param X         training inputs, one sample per row
     * @param y         training targets (expected to be a column vector)
     * @param logtheta0 initial log-space hyperparameters (column vector)
     */
    public void train(Matrix X, Matrix y, Matrix logtheta0) {
        train(X, y, logtheta0, -100);
    }

    /**
     * Fits the hyperparameters and caches {@link #L} / {@link #alpha} for prediction.
     *
     * @param X          training inputs, one sample per row
     * @param y          training targets (expected to be a column vector)
     * @param logtheta0  initial log-space hyperparameters (column vector)
     * @param iterations if positive, the maximum number of line searches; if negative,
     *                   {@code -iterations} is the maximum number of likelihood evaluations
     * @throws RuntimeException if the covariance matrix is not symmetric positive definite
     */
    public void train(Matrix X, Matrix y, Matrix logtheta0, int iterations) {
        System.out.println("training started...");
        this.X = X;
        this.logtheta = minimize(logtheta0, iterations, X, y);
        // The minimizer's last evaluation is frequently a rejected trial point, so L and
        // alpha must be recomputed at the hyperparameters that were actually returned.
        updatePosterior(this.logtheta, X, y);
    }

    /**
     * Evaluates the negative log marginal likelihood and its gradient.
     *
     * <p>Side effect: refreshes {@link #L} and {@link #alpha} for {@code logtheta}.
     *
     * @param logtheta log-space hyperparameters
     * @param x        inputs, one sample per row
     * @param y        targets (expected n x 1)
     * @param df0      output; row {@code i} receives the derivative with respect to
     *                 hyperparameter {@code i}. Must be pre-sized by the caller.
     * @return {@code 0.5 y' K^-1 y + 0.5 log|K| + (n/2) log(2 pi)}
     * @throws RuntimeException if the covariance matrix is not symmetric positive definite
     */
    public double negativeLogLikelihood(Matrix logtheta, Matrix x, Matrix y, Matrix df0) {
        int n = x.getRowDimension();
        updatePosterior(logtheta, x, y);

        double nll = y.transpose().times(this.alpha).times(0.5).get(0, 0);
        for (int i = 0; i < this.L.getRowDimension(); ++i) {
            // sum of log(diag(L)) equals 0.5 * log|K|
            nll += FastMath.log(this.L.get(i, i));
        }
        nll += 0.5 * (double) n * FastMath.log(2.0 * Math.PI);

        // W = K^-1 - alpha * alpha'. Assuming each dK is symmetric, the gradient entry is
        // 0.5 * tr(W * dK), i.e. half the elementwise sum of W .* dK.
        Matrix kInverse = bSubstitutionWithTranspose(this.L, fSubstitution(this.L, Matrix.identity(n, n)));
        Matrix W = kInverse.minus(this.alpha.times(this.alpha.transpose()));
        for (int i = 0; i < df0.getRowDimension(); ++i) {
            Matrix dK = this.covFunction.computeDerivatives(logtheta, x, i);
            df0.set(i, 0, sum(W.arrayTimes(dK)) / 2.0);
        }
        return nll;
    }

    /**
     * Factorizes K(logtheta, x) and stores its Cholesky factor in {@link #L} and
     * {@code K^-1 y} in {@link #alpha}.
     *
     * @throws RuntimeException if the covariance matrix is not symmetric positive definite
     */
    private void updatePosterior(Matrix logtheta, Matrix x, Matrix y) {
        Matrix K = this.covFunction.compute(logtheta, x);
        CholeskyDecomposition cd = K.chol();
        if (!cd.isSPD()) {
            throw new RuntimeException("The covariance Matrix is not SDP, check your covariance function (maybe you mess the noise term..)");
        }
        this.L = cd.getL();
        // alpha = L' \ (L \ y) = K^-1 y
        this.alpha = bSubstitutionWithTranspose(this.L, fSubstitution(this.L, y));
    }

    /**
     * Predicts mean and variance at the test inputs.
     *
     * @param xstar test inputs, one sample per row, same column count as {@link #X}
     * @return {@code {mean, variance}}, each with one row per test input. Whether the
     *         variance includes a noise term depends on the test self-covariance returned
     *         by the covariance function.
     * @throws IllegalStateException    if the process has not been trained
     * @throws IllegalArgumentException if {@code xstar} has the wrong number of columns
     */
    public Matrix[] predict(Matrix xstar) {
        Matrix[] star = testCovariances(xstar);
        Matrix Kss = star[0];
        Matrix Kstar = star[1];
        Matrix ystar = Kstar.transpose().times(this.alpha);
        // variance = Kss - diag(Kstar' K^-1 Kstar) = Kss - column sums of (L \ Kstar).^2
        Matrix v = fSubstitution(this.L, Kstar);
        v.arrayTimesEquals(v);
        Matrix Sstar = Kss.minus(sumColumns(v).transpose());
        return new Matrix[]{ystar, Sstar};
    }

    /**
     * Predicts only the mean at the test inputs.
     *
     * @param xstar test inputs, one sample per row, same column count as {@link #X}
     * @return predictive mean, one row per test input
     * @throws IllegalStateException    if the process has not been trained
     * @throws IllegalArgumentException if {@code xstar} has the wrong number of columns
     */
    public Matrix predictMean(Matrix xstar) {
        Matrix Kstar = testCovariances(xstar)[1];
        return Kstar.transpose().times(this.alpha);
    }

    /**
     * Validates the trained state and input shape.
     *
     * @return {@code covFunction.compute(logtheta, X, xstar)}: element 0 is the test
     *         self-covariance (one row per test input), element 1 the train/test
     *         cross-covariance (one row per training input)
     */
    private Matrix[] testCovariances(Matrix xstar) {
        if (this.alpha == null || this.L == null || this.X == null || this.logtheta == null) {
            throw new IllegalStateException("GP needs to be trained first..");
        }
        if (xstar.getColumnDimension() != this.X.getColumnDimension()) {
            throw new IllegalArgumentException("Wrong size of the input " + xstar.getColumnDimension() + " instead of " + this.X.getColumnDimension());
        }
        return this.covFunction.compute(this.logtheta, this.X, xstar);
    }

    /** Returns the 1 x m row vector holding the column sums of {@code a}. */
    private static Matrix sumColumns(Matrix a) {
        int cols = a.getColumnDimension();
        Matrix total = new Matrix(1, cols);
        for (int r = 0; r < a.getRowDimension(); ++r) {
            total.plusEquals(a.getMatrix(r, r, 0, cols - 1));
        }
        return total;
    }

    /** Returns the sum of all entries of {@code a}. */
    private static double sum(Matrix a) {
        double total = 0.0;
        for (int r = 0; r < a.getRowDimension(); ++r) {
            for (int c = 0; c < a.getColumnDimension(); ++c) {
                total += a.get(r, c);
            }
        }
        return total;
    }

    /** Solves {@code L * X = B} by forward substitution; {@code L} must be lower triangular. */
    private static Matrix fSubstitution(Matrix L, Matrix B) {
        double[][] l = L.getArray();
        double[][] b = B.getArray();
        int n = B.getRowDimension();
        int m = B.getColumnDimension();
        double[][] x = new double[n][m];
        for (int c = 0; c < m; ++c) {
            for (int k = 0; k < n; ++k) {
                double acc = b[k][c];
                for (int j = 0; j < k; ++j) {
                    acc -= l[k][j] * x[j][c];
                }
                x[k][c] = acc / l[k][k];
            }
        }
        return new Matrix(x);
    }

    /**
     * Solves {@code U * X = B} by back substitution; {@code U} must be upper triangular.
     * Currently unused within this class.
     */
    private static Matrix bSubstitution(Matrix U, Matrix B) {
        double[][] u = U.getArray();
        double[][] b = B.getArray();
        int n = B.getRowDimension();
        int m = B.getColumnDimension();
        double[][] x = new double[n][m];
        for (int c = 0; c < m; ++c) {
            for (int k = n - 1; k >= 0; --k) {
                double acc = b[k][c];
                for (int j = n - 1; j > k; --j) {
                    acc -= u[k][j] * x[j][c];
                }
                x[k][c] = acc / u[k][k];
            }
        }
        return new Matrix(x);
    }

    /**
     * Solves {@code L' * X = B} by back substitution without forming the transpose;
     * {@code L} must be lower triangular.
     */
    private static Matrix bSubstitutionWithTranspose(Matrix L, Matrix B) {
        double[][] l = L.getArray();
        double[][] b = B.getArray();
        int n = B.getRowDimension();
        int m = B.getColumnDimension();
        double[][] x = new double[n][m];
        for (int c = 0; c < m; ++c) {
            for (int k = n - 1; k >= 0; --k) {
                double acc = b[k][c];
                for (int j = n - 1; j > k; --j) {
                    acc -= l[j][k] * x[j][c];
                }
                x[k][c] = acc / l[k][k];
            }
        }
        return new Matrix(x);
    }

    /**
     * Minimizes {@link #negativeLogLikelihood} over the hyperparameters.
     *
     * <p>The method uses Polack-Ribiere conjugate gradients. Each line search has an
     * extrapolation phase (cubic extrapolation) and an interpolation phase
     * (quadratic/cubic interpolation), and is accepted under the Wolfe-Powell conditions
     * set by SIG and RHO.
     *
     * @param params starting hyperparameters (column vector)
     * @param length if positive, the maximum number of line searches; if negative,
     *               {@code -length} is the maximum number of likelihood evaluations
     * @param in     training inputs
     * @param out    training targets
     * @return the hyperparameters reached when the budget is exhausted or two
     *         consecutive line searches fail
     */
    private Matrix minimize(Matrix params, int length, Matrix in, Matrix out) {
        final int nCycles = FastMath.abs(length);
        int i = 0;                 // counts line searches (length > 0) or evaluations (length < 0)
        boolean lsFailed = false;  // whether the previous line search failed

        Matrix df0 = new Matrix(params.getRowDimension(), 1);
        double f0 = negativeLogLikelihood(params, in, out, df0);
        if (length < 0) {
            ++i;
        }

        Matrix s = df0.times(-1.0);                      // steepest-descent direction
        double d0 = df0.transpose().times(s).get(0, 0);  // slope along s, equal to -s's
        double x3 = 1.0 / (1.0 - d0);                    // initial step 1 / (1 + s's)

        while (i < nCycles) {
            if (length > 0) {
                ++i;
            }
            // Best point seen during this line search, restored if the search fails.
            Matrix X0 = params.copy();
            double F0 = f0;
            Matrix dF0 = df0.copy();
            int M = length > 0 ? MAX : FastMath.min(MAX, -length - i);

            // ---- Extrapolation phase ----
            double x2;
            double f2;
            double d2;
            double f3;
            double d3;
            Matrix df3;
            while (true) {
                x2 = 0.0;
                f2 = f0;
                d2 = d0;
                f3 = f0;
                df3 = df0.copy();
                boolean success = false;
                while (!success && M > 0) {
                    --M;
                    if (length < 0) {
                        ++i;
                    }
                    f3 = negativeLogLikelihood(params.plus(s.times(x3)), in, out, df3);
                    if (Double.isNaN(f3) || Double.isInfinite(f3)
                            || hasInvalidNumbers(df3.getRowPackedCopy())) {
                        x3 = (x2 + x3) / 2.0;  // numerical trouble: bisect and retry
                    } else {
                        success = true;
                    }
                }
                if (f3 < F0) {
                    X0 = s.times(x3).plus(params);
                    F0 = f3;
                    // Copy: df3 is overwritten in place by later evaluations.
                    dF0 = df3.copy();
                }
                d3 = df3.transpose().times(s).get(0, 0);
                if (d3 > SIG * d0 || f3 > f0 + x3 * RHO * d0 || M == 0) {
                    break;  // done extrapolating
                }
                // Move point 2 to point 1 and point 3 to point 2 BEFORE fitting: the cubic
                // extrapolation needs the previous (x2, f2, d2) as point 1.
                double x1 = x2;
                double f1 = f2;
                double d1 = d2;
                x2 = x3;
                f2 = f3;
                d2 = d3;
                double A = 6.0 * (f1 - f2) + 3.0 * (d2 + d1) * (x2 - x1);
                double B = 3.0 * (f2 - f1) - (2.0 * d1 + d2) * (x2 - x1);
                // May produce NaN/infinity; handled by the first branch below.
                x3 = x1 - d1 * (x2 - x1) * (x2 - x1) / (B + FastMath.sqrt(B * B - A * d1 * (x2 - x1)));
                if (Double.isNaN(x3) || Double.isInfinite(x3) || x3 < 0.0) {
                    x3 = x2 * EXT;              // numerical problem or wrong sign
                } else if (x3 > x2 * EXT) {
                    x3 = x2 * EXT;              // beyond the extrapolation limit
                } else if (x3 < x2 + INT * (x2 - x1)) {
                    x3 = x2 + INT * (x2 - x1);  // too close to the previous point
                }
            }

            // ---- Interpolation phase ----
            double x4 = 0.0;
            double f4 = 0.0;
            double d4 = 0.0;
            while ((FastMath.abs(d3) > -SIG * d0 || f3 > f0 + x3 * RHO * d0) && M > 0) {
                if (d3 > 0.0 || f3 > f0 + x3 * RHO * d0) {
                    x4 = x3;  // point 3 replaces point 4
                    f4 = f3;
                    d4 = d3;
                } else {
                    x2 = x3;  // point 3 replaces point 2
                    f2 = f3;
                    d2 = d3;
                }
                if (f4 > f0) {
                    // quadratic interpolation
                    x3 = x2 - 0.5 * d2 * (x4 - x2) * (x4 - x2) / (f4 - f2 - d2 * (x4 - x2));
                } else {
                    // cubic interpolation
                    double A = 6.0 * (f2 - f4) / (x4 - x2) + 3.0 * (d4 + d2);
                    double B = 3.0 * (f4 - f2) - (2.0 * d2 + d4) * (x4 - x2);
                    x3 = x2 + (FastMath.sqrt(B * B - A * d2 * (x4 - x2) * (x4 - x2)) - B) / A;
                }
                if (Double.isNaN(x3) || Double.isInfinite(x3)) {
                    x3 = (x2 + x4) / 2.0;  // numerical problem: bisect
                }
                // Keep the new point at least INT of the interval away from either end.
                x3 = FastMath.max(FastMath.min(x3, x4 - INT * (x4 - x2)), x2 + INT * (x4 - x2));
                Matrix trial = s.times(x3).plus(params);
                f3 = negativeLogLikelihood(trial, in, out, df3);
                if (f3 < F0) {
                    X0 = trial.copy();
                    F0 = f3;
                    dF0 = df3.copy();
                }
                --M;
                if (length < 0) {
                    ++i;
                }
                d3 = df3.transpose().times(s).get(0, 0);
            }

            if (FastMath.abs(d3) < -SIG * d0 && f3 < f0 + x3 * RHO * d0) {
                // Line search succeeded: accept the step and build the next direction.
                params = s.times(x3).plus(params);
                f0 = f3;
                System.out.println("Function evaluation " + i + " Value " + f0);
                // Polack-Ribiere conjugate direction
                double num = df3.transpose().times(df3).minus(df0.transpose().times(df3)).get(0, 0);
                double den = df0.transpose().times(df0).get(0, 0);
                s = s.times(num / den).minus(df3);
                df0 = df3;
                double previousSlope = d0;
                d0 = df0.transpose().times(s).get(0, 0);
                if (d0 > 0.0) {
                    // Not a descent direction: fall back to steepest descent.
                    s = df0.times(-1.0);
                    d0 = df0.transpose().times(s).get(0, 0);
                }
                // Rescale the step by the slope ratio, capped at RATIO; subtracting
                // MIN_VALUE keeps the denominator non-zero.
                x3 *= FastMath.min(RATIO, previousSlope / (d0 - Double.MIN_VALUE));
                lsFailed = false;
            } else {
                // Line search failed: restore the best point seen during it.
                params = X0;
                f0 = F0;
                df0 = dF0;
                if (lsFailed || i > nCycles) {
                    break;  // failed twice in a row, or budget exhausted
                }
                s = df0.times(-1.0);
                d0 = df0.transpose().times(s).get(0, 0);
                x3 = 1.0 / (1.0 - d0);
                lsFailed = true;
            }
        }
        return params;
    }

    /** Returns {@code true} if any element of {@code array} is NaN or infinite. */
    private static boolean hasInvalidNumbers(double[] array) {
        for (double a : array) {
            if (Double.isNaN(a) || Double.isInfinite(a)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Demo program.
     *
     * <p>Fits a {@code CovLINone + CovNoise} GP to {@code ../armdata.csv}, loaded with
     * {@code CSVtoMatrix.load(..., 6, 1)}. It then prints the predictive mean and variance
     * for the inputs read from {@code ../armdatastar.csv}.
     */
    public static void main(String[] args) {
        CovSum covFunc = new CovSum(6, new CovLINone(), new CovNoise());
        GaussianProcess gp = new GaussianProcess(covFunc);
        Matrix params0 = new Matrix(new double[][]{{0.1}, {FastMath.log(0.1)}});
        Matrix[] data = CSVtoMatrix.load("../armdata.csv", 6, 1);
        gp.train(data[0], data[1], params0, -20);
        Matrix[] datastar = CSVtoMatrix.load("../armdatastar.csv", 6, 1);
        Matrix[] res = gp.predict(datastar[0]);
        res[0].print(res[0].getColumnDimension(), 16);
        res[1].print(res[1].getColumnDimension(), 16);
    }
}

