/*
 * Decompiled with CFR 0.152.
 */
package jgpml.covariancefunctions;

import Jama.Matrix;
import java.util.Arrays;
import jgpml.covariancefunctions.CovarianceFunction;
import jgpml.covariancefunctions.MatrixOperations;
import org.apache.commons.math3.util.FastMath;

/**
 * Squared-exponential covariance function with Automatic Relevance Determination (ARD),
 * i.e. one independent length-scale per input dimension:
 *
 * <pre>
 *   k(x, x') = sf2 * exp(-0.5 * sum_d ((x_d - x'_d) / ell_d)^2)
 * </pre>
 *
 * <p>Hyperparameters are supplied as a {@code (D + 1) x 1} column of logarithms: rows
 * {@code 0..D-1} hold {@code log(ell_d)} and row {@code D} holds the log signal magnitude,
 * so that {@code sf2 = exp(2 * loghyper[D])}. Input matrices hold one point per row and
 * must have exactly {@code D} columns.
 *
 * <p>The training covariance built for a given {@code (loghyper, X)} pair is cached so
 * that {@link #computeDerivatives} can reuse it. Because of this mutable cache, instances
 * are not thread-safe.
 */
public class CovSEard
implements CovarianceFunction {
    /** Number of input dimensions (and of length-scale hyperparameters). */
    private final int D;
    /** Total number of hyperparameters: D length-scales plus the signal magnitude. */
    private final int numParameters;
    /** Cached training covariance matrix, or {@code null} when nothing is cached. */
    private Matrix K;
    /** Hyperparameters that {@link #K} was built from; part of the cache key. */
    private double[] cachedLoghyper;
    /** Copy of the training inputs that {@link #K} was built from; part of the cache key. */
    private double[][] cachedX;

    /**
     * Creates an ARD squared-exponential covariance for inputs with the given number of columns.
     *
     * @param inputDimension number of input dimensions D; must be non-negative
     * @throws IllegalArgumentException if {@code inputDimension} is negative
     */
    public CovSEard(int inputDimension) {
        if (inputDimension < 0) {
            throw new IllegalArgumentException("The input dimension must be non-negative, got " + inputDimension);
        }
        this.D = inputDimension;
        this.numParameters = this.D + 1;
    }

    /**
     * Returns the square matrix of squared Euclidean distances between every pair of
     * columns of {@code a} (each column is one point, each row one coordinate).
     */
    private static Matrix squareDist(Matrix a) {
        return CovSEard.squareDist(a, a);
    }

    /**
     * Returns the m-by-n matrix whose (i, j) entry is the squared Euclidean distance between
     * column i of {@code a} (m columns) and column j of {@code b} (n columns). Both matrices
     * are expected to have the same number of rows, one per coordinate.
     */
    private static Matrix squareDist(Matrix a, Matrix b) {
        int m = a.getColumnDimension();
        int n = b.getColumnDimension();
        int d = a.getRowDimension();
        Matrix C = new Matrix(m, n);
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                double z = 0.0;
                for (int k = 0; k < d; ++k) {
                    double t = a.get(k, i) - b.get(k, j);
                    z += t * t;
                }
                C.set(i, j, z);
            }
        }
        return C;
    }

    @Override
    public int numParameters() {
        return this.numParameters;
    }

    /**
     * Computes the n-by-n training covariance matrix between the rows of {@code X}.
     *
     * <p>The result is cached for later {@link #computeDerivatives} calls with equal
     * arguments. A copy is returned so that callers that modify the matrix in place
     * cannot corrupt the cache.
     *
     * @param loghyper {@code (D + 1) x 1} column of log hyperparameters
     * @param X n-by-D matrix of inputs, one point per row
     * @return the n-by-n covariance matrix
     * @throws IllegalArgumentException if {@code X} or {@code loghyper} has the wrong shape
     */
    @Override
    public Matrix compute(Matrix loghyper, Matrix X) {
        this.checkArguments(loghyper, X);
        this.updateCache(loghyper, X);
        return this.K.copy();
    }

    /**
     * Computes the covariances needed to predict at test points.
     *
     * @param loghyper {@code (D + 1) x 1} column of log hyperparameters
     * @param X n-by-D matrix of training inputs, one point per row
     * @param Xstar nStar-by-D matrix of test inputs, one point per row
     * @return a two-element array: {@code [0]} is the nStar-by-1 column of self-covariances
     *     k(x*, x*), which all equal sf2; {@code [1]} is the n-by-nStar cross-covariance
     *     between training and test points
     * @throws IllegalArgumentException if any argument has the wrong shape
     */
    @Override
    public Matrix[] compute(Matrix loghyper, Matrix X, Matrix Xstar) {
        this.checkArguments(loghyper, X);
        if (Xstar.getColumnDimension() != this.D) {
            throw new IllegalArgumentException("The number of dimensions specified on the covariance function " + this.D + " must agree with the size of the test input vector " + Xstar.getColumnDimension());
        }
        Matrix ell = this.lengthScales(loghyper);
        double sf2 = this.signalVariance(loghyper);
        // exp(0) == 1, so every test point's covariance with itself is just sf2.
        Matrix selfCovariance = new Matrix(Xstar.getRowDimension(), 1, sf2);
        Matrix crossCovariance = CovSEard.covariance(this.scaledTranspose(X, ell), this.scaledTranspose(Xstar, ell), sf2);
        return new Matrix[]{selfCovariance, crossCovariance};
    }

    /**
     * Computes the element-wise derivative of the training covariance matrix with respect to
     * {@code loghyper[index]}.
     *
     * <p>For {@code index < D} the result is {@code K .* ((x_i[index] - x_j[index]) / ell_index)^2}.
     * For {@code index == D} (the log signal magnitude) it is {@code 2 * K}. The cached
     * covariance is reused only if it was built from equal {@code loghyper} and {@code X}.
     * Otherwise it is recomputed, so calling this method without a prior
     * {@link #compute(Matrix, Matrix)} is safe. The cache is cleared after the
     * {@code index == D} derivative, releasing the cached n-by-n matrix.
     *
     * @param loghyper {@code (D + 1) x 1} column of log hyperparameters
     * @param X n-by-D matrix of inputs, one point per row
     * @param index hyperparameter index in {@code [0, D]}
     * @return the n-by-n derivative matrix
     * @throws IllegalArgumentException if an argument has the wrong shape or {@code index}
     *     is out of range
     */
    @Override
    public Matrix computeDerivatives(Matrix loghyper, Matrix X, int index) {
        this.checkArguments(loghyper, X);
        if (index < 0) {
            throw new IllegalArgumentException("Wrong hyperparameters index " + index + " it should be non-negative");
        }
        if (index > this.numParameters() - 1) {
            throw new IllegalArgumentException("Wrong hyperparameters index " + index + " it should be smaller or equal to " + (this.numParameters() - 1));
        }
        if (!this.isCached(loghyper, X)) {
            this.updateCache(loghyper, X);
        }
        Matrix derivative;
        if (index < this.D) {
            Matrix ell = this.lengthScales(loghyper);
            // Row vector of this dimension's coordinates, scaled by its length-scale.
            Matrix column = X.getMatrix(0, X.getRowDimension() - 1, index, index).transpose().times(1.0 / ell.get(index, 0));
            derivative = this.K.arrayTimes(CovSEard.squareDist(column));
        } else {
            derivative = this.K.times(2.0);
            this.clearCache();
        }
        return derivative;
    }

    /** Validates that {@code X} has D columns and {@code loghyper} is a (D + 1) x 1 column. */
    private void checkArguments(Matrix loghyper, Matrix X) {
        if (X.getColumnDimension() != this.D) {
            throw new IllegalArgumentException("The number of dimensions specified on the covariance function " + this.D + " must agree with the size of the input vector " + X.getColumnDimension());
        }
        if (loghyper.getColumnDimension() != 1 || loghyper.getRowDimension() != this.numParameters) {
            throw new IllegalArgumentException("Wrong number of hyperparameters, " + loghyper.getRowDimension() + " instead of " + this.numParameters);
        }
    }

    /** Returns the D x 1 column of length-scales, exp(loghyper[0..D-1]). */
    private Matrix lengthScales(Matrix loghyper) {
        return MatrixOperations.exp(loghyper.getMatrix(0, this.D - 1, 0, 0));
    }

    /** Returns the signal variance sf2 = exp(2 * loghyper[D]). */
    private double signalVariance(Matrix loghyper) {
        return FastMath.exp(2.0 * loghyper.get(this.D, 0));
    }

    /**
     * Returns the D-by-n matrix whose column j is row j of {@code X} with each coordinate d
     * divided by {@code ell_d}.
     */
    private Matrix scaledTranspose(Matrix X, Matrix ell) {
        int n = X.getRowDimension();
        Matrix scaled = new Matrix(this.D, n);
        for (int d = 0; d < this.D; ++d) {
            double inverseScale = 1.0 / ell.get(d, 0);
            for (int j = 0; j < n; ++j) {
                scaled.set(d, j, X.get(j, d) * inverseScale);
            }
        }
        return scaled;
    }

    /** Returns sf2 * exp(-0.5 * squared distance) between the columns of {@code a} and {@code b}. */
    private static Matrix covariance(Matrix a, Matrix b, double sf2) {
        return MatrixOperations.exp(CovSEard.squareDist(a, b).times(-0.5)).times(sf2);
    }

    /** Recomputes the training covariance for (loghyper, X) and records the cache key. */
    private void updateCache(Matrix loghyper, Matrix X) {
        Matrix scaled = this.scaledTranspose(X, this.lengthScales(loghyper));
        Matrix covariance = CovSEard.covariance(scaled, scaled, this.signalVariance(loghyper));
        this.cachedLoghyper = loghyper.getColumnPackedCopy();
        this.cachedX = X.getArrayCopy();
        this.K = covariance;
    }

    /** Returns true if {@link #K} was built from hyperparameters and inputs equal to these. */
    private boolean isCached(Matrix loghyper, Matrix X) {
        return this.K != null
                && Arrays.equals(this.cachedLoghyper, loghyper.getColumnPackedCopy())
                && Arrays.deepEquals(this.cachedX, X.getArray());
    }

    /** Drops the cached covariance and its key. */
    private void clearCache() {
        this.K = null;
        this.cachedLoghyper = null;
        this.cachedX = null;
    }
}

