#![allow(dead_code)]

use nalgebra::*;

use super::basic::{_Rm2, _cov_xv_3};

/// Naive implementation.
///
/// Evaluates the closed-form covariances `(cov_xax, cov_vax, cov_axax)` for a
/// step of length `h` and parameter `a`, with `delta = h * gamma`, directly
/// from exponentials such as `exp(-delta)` and `exp(-a * delta)`.
///
/// Not accurate for small h: the numerators subtract nearly equal terms
/// (e.g. `a * delta + (exp(-a * delta) - 1)`), so precision is lost by
/// cancellation as `delta -> 0`.
/// NaN for large h.
///
/// `gamma` must be nonzero: every result is divided by a power of `gamma`.
#[inline(always)]
pub fn _cov_xvax_1<N: RealField>(h: N, gamma: N, a: N) -> (N, N, N) {
    let delta = h * gamma;
    let two = convert::<_, N>(2.0);

    let eta1 = (-delta).exp();
    let aeta1 = (-a * delta).exp();

    // (a*delta + exp(-a*delta) - 1 + exp(-delta) * (1 - cosh(a*delta))) / gamma^3;
    // the last term is written as exp(-delta) minus the mean of exp(-(1 +/- a)*delta).
    let cov_xax = (a * delta
        + (aeta1 - N::one())
        + (eta1 - ((-(N::one() + a) * delta).exp() + (-(N::one() - a) * delta).exp()) / two))
        / gamma.powi(3);
    // t = exp(-delta/2) * 2*sinh(a*delta/2), so cov_vax = 2*exp(-delta)*sinh^2(a*delta/2) / gamma^2.
    let t = ((a - N::one()) / two * delta).exp() - (-(a + N::one()) / two * delta).exp();
    let cov_vax = t * t / two / gamma / gamma;

    // (1 + 2*a*delta - (2 - exp(-a*delta))^2) / (2*gamma^3)
    let cov_axax = (N::one() + two * a * delta - (two - aeta1).powi(2)) / (two * gamma.powi(3));
    (cov_xax, cov_vax, cov_axax)
}
/// More accurate.
///
/// Computes the same `(cov_xax, cov_vax, cov_axax)` as `_cov_xvax_1`, with
/// `delta = h * gamma`:
///
/// * if `a * delta > ln 2`, it simply returns `_cov_xvax_1(h, gamma, a)`;
/// * otherwise it uses expressions built from `_Rm2(a * delta)` (and, for
///   `cov_xax` with `delta <= ln 2`, also `_Rm2(delta)`), presumably chosen to
///   avoid the cancellation the naive form suffers for small arguments.
///
/// NOTE(review): `_Rm2` lives in `super::basic` and its definition is not
/// visible here; confirm its exact meaning there before changing these formulas.
///
/// `gamma` must be nonzero on the non-naive path, which divides by `gamma`.
#[inline(always)]
#[allow(non_snake_case)]
pub fn _cov_xvax_2<N: RealField>(h: N, gamma: N, a: N) -> (N, N, N) {
    let delta = h * gamma;
    let two = convert::<_, N>(2.0);

    // Large a*delta: the naive closed form is used as-is.
    if a * delta > N::ln(two) {
        _cov_xvax_1(h, gamma, a)
    } else {
        let eta1 = (-delta).exp();
        let aRm2 = _Rm2(a * delta);
        // Building blocks shared by all three outputs; t2 = 2 - ac and
        // t3 = 2 - apc appear as denominators below.
        let ac = -a * delta - aRm2;
        let apc = a * delta - aRm2;

        let acog = ac / gamma;
        let apcog = apc / gamma;
        let t2 = two - ac;
        let t3 = two - apc;
        let att = t2 * t2;

        let cov_xax;
        // cov_xax additionally switches on delta itself; only the small-delta
        // form needs _Rm2(delta).
        if delta > N::ln(two) {
            cov_xax = a * h * (eta1 * (acog - apcog) / t2 / t3 - acog / t2) / gamma;
        } else {
            let Rm2 = _Rm2(delta);
            let c = -delta - Rm2;
            let t1 = two - c;

            cov_xax = a * h * (two * aRm2 / gamma / gamma + apcog * acog - two * h * (acog - apcog) / t1) / t2 / t3;
        }
        let cov_vax = a * h * eta1 * (apcog - acog) / t2 / t3;
        let cov_axax = a * h * (acog * acog + two * aRm2 / gamma / gamma) / att;
        (cov_xax, cov_vax, cov_axax)
    }
}
/// Taylor expansion
///
/// Returns `(cov_xax, cov_vax, cov_axax)`. For `delta = h * gamma > 0.1` this
/// delegates to `_cov_xvax_2`; otherwise it evaluates truncated series in
/// `delta`, written in nested (Horner) form:
///
/// * `cov_xax  ~ a^2 h^3 / 2 * [1 - a/3 - delta/2 * (1 - delta/3 * (1 + a^2/2 * (1 - a/5)))]`
/// * `cov_vax  ~ a^2 h^2 / 2 * [1 - delta * (1 - delta * (1/2 + a^2/12))]`
/// * `cov_axax ~ (a h)^3 * [1/3 - (a delta)/4 * (1 - 7 (a delta)/15)]`
///
/// The leading terms agree with the small-`delta` expansion of the closed forms
/// in `_cov_xvax_1`. The series branch never divides by `gamma`, so
/// `gamma = 0` (hence `delta = 0`) is handled here.
#[inline(always)]
pub fn _cov_xvax_3<N: RealField>(h: N, gamma: N, a: N) -> (N, N, N) {
    let delta = h * gamma;

    if delta > convert::<_, N>(1e-1) {
        _cov_xvax_2(h, gamma, a)
    } else {
        let two = convert::<_, N>(2.0);
        // Horner evaluation from the innermost factor outwards.
        let cov_xax = {
            let t = N::one() - a / convert::<_, N>(5.0);
            let t = N::one() + a * a / two * t;
            let t = N::one() - delta / convert::<_, N>(3.0) * t;
            let t = N::one() - a / convert::<_, N>(3.0) - delta / two * t;
            h.powi(3) * a * a / two * t
        };
        let cov_vax = {
            let t = convert::<_, N>(1.0 / 2.0) + a * a * convert::<_, N>(1.0 / 12.0);
            let t = N::one() - delta * t;
            let t = N::one() - delta * t;
            let t = a * a * h * h / two * t;
            t
        };
        // This series depends on delta only through a*delta.
        let cov_axax = {
            let adelta = delta * a;
            let ah = h * a;
            let t = N::one() - adelta * convert::<_, N>(7.0 / 15.0);
            let t = convert::<_, N>(1.0 / 3.0) - t * adelta / convert::<_, N>(4.0);
            let t = ah * ah * ah * t;
            t
        };
        (cov_xax, cov_vax, cov_axax)
    }
}
/// Full symmetric 3x3 covariance of `(x, v, ax)` for a step of length `h`.
///
/// The `(x, v)` block comes from `_cov_xv_3` and the entries involving `ax`
/// from `_cov_xvax_3`; rows and columns are ordered `(x, v, ax)`. Because the
/// matrix is symmetric, the row-major argument order of `Matrix3::new` is
/// irrelevant.
#[inline(always)]
pub fn cov_xvax<N: RealField>(h: N, gamma: N, a: N) -> OMatrix<N, Const<3>, Const<3>> {
    let (xx, xv, vv) = _cov_xv_3(h, gamma);
    let (x_ax, v_ax, ax_ax) = _cov_xvax_3(h, gamma, a);
    Matrix3::new(
        xx, xv, x_ax,
        xv, vv, v_ax,
        x_ax, v_ax, ax_ax,
    )
}
/// Square root of `cov_xvax(h, gamma, a)` via singular value decomposition.
///
/// Decomposes `C = U * S * V^T` and recomposes `U * sqrt(S) * V^T`. For a
/// symmetric positive semi-definite `C` this is its symmetric square root
/// (`R * R ~ C`).
#[inline(always)]
pub fn sqrt_cov_xvax_1<N: RealField>(h: N, gamma: N, a: N) -> OMatrix<N, Const<3>, Const<3>> {
    let mut decomposition = cov_xvax(h, gamma, a).svd(true, true);
    let root_sigma = decomposition.singular_values.map(|sigma| sigma.sqrt());
    decomposition.singular_values = root_sigma;
    // Both U and V^T were requested above, so recomposition cannot fail.
    decomposition.recompose().unwrap()
}
/// Direct sqrt 3x3 matrix (closed form of Franca, 1989).
///
/// Returns the symmetric `U` with `U * U == C`, where `C = cov_xvax(h, gamma, a)`.
/// It is computed from the invariants of `C` without an eigen- or singular-value
/// decomposition:
///
/// 1. `k = I_C^2 - 3 II_C`, which is zero iff all eigenvalues are equal
///    (then `U = sqrt(I_C / 3) * I`);
/// 2. the largest eigenvalue `lambda^2` of `C`, from the trigonometric solution
///    of the characteristic cubic with `l = I_C (I_C^2 - 9/2 II_C) + 27/2 III_C`;
/// 3. the invariants `I_U`, `II_U`, `III_U` of `U`, and finally
///    `U = (I_U III_U I + (I_U^2 - II_U) C - C^2) / (I_U II_U - III_U)`.
///
/// Assumes `C` is symmetric positive (semi-)definite. Small negative values
/// caused by rounding (`k`, `det C`, the acos argument) are clamped. If `C` has
/// rank exactly one, the final denominator is zero and the result is not finite.
#[inline(always)]
#[allow(non_snake_case)]
pub fn sqrt_cov_xvax_2<N: RealField>(h: N, gamma: N, a: N) -> OMatrix<N, Const<3>, Const<3>> {
    let two = convert::<_, N>(2.0);
    let three = convert::<_, N>(3.0);

    let C = cov_xvax(h, gamma, a);
    let C2 = C * C;

    // Principal invariants of C.
    let IC = C.trace();
    let IIC = (IC * IC - C2.trace()) / two;
    let IIIC = C.determinant();

    // k = 1/2 * sum_{i<j} (lambda_i - lambda_j)^2 >= 0; it vanishes iff C = (IC/3) * I.
    // `<=` also absorbs tiny negative values produced by rounding.
    let k = IC * IC - three * IIC;
    if k <= N::zero() {
        return Matrix3::identity() * (IC / three).max(N::zero()).sqrt();
    }

    // Trigonometric root of the characteristic cubic. The acos argument is
    // clamped to [-1, 1] so rounding cannot produce NaN.
    let l = IC * (IC * IC - convert::<_, N>(9.0 / 2.0) * IIC)
        + convert::<_, N>(27.0 / 2.0) * IIIC;
    let phi = (l / k / k.sqrt()).max(-N::one()).min(N::one()).acos();
    // Square root of the largest eigenvalue of C.
    let lambda = ((IC + two * k.sqrt() * (phi / three).cos()) / three).sqrt();

    // Invariants of U = sqrt(C); det C >= 0 for PSD C, so clamp rounding noise.
    let IIIU = IIIC.max(N::zero()).sqrt();
    let IU = lambda + (-lambda * lambda + IC + two * IIIU / lambda).max(N::zero()).sqrt();
    let IIU = (IU * IU - IC) / two;

    (Matrix3::identity() * (IU * IIIU) + C * (IU * IU - IIU) - C2) / (IU * IIU - IIIU)
}
/// Translated Cholesky decomposition
/// decompose v variance first.
///
/// Given the six distinct entries of the symmetric covariance of `(x, v, ax)`,
/// returns `L` with `L * L.transpose() == cov`. The v variance is the first
/// pivot, then x, then ax:
///
/// ```text
///     [ bp  cp  0  ]
/// L = [ ap  0   0  ]
///     [ dp  ep  fp ]
/// ```
///
/// Every pivot radicand is clamped at zero, so slightly negative values caused
/// by rounding yield a zero pivot instead of NaN. A zero pivot zeroes the
/// entries that would otherwise be divided by it.
#[inline(always)]
pub fn _sqrt_cov_xvax_3<N: RealField>(
    cov_xx: N,
    cov_xv: N,
    cov_vv: N,
    cov_xax: N,
    cov_vax: N,
    cov_axax: N,
) -> OMatrix<N, Const<3>, Const<3>> {
    let (e, d, f) = (cov_xax, cov_vax, cov_axax);
    let (c, b, a) = (cov_xx, cov_xv, cov_vv);

    // Pivot 1: v. Clamped like the other two pivots; previously a negative
    // `cov_vv` made `ap`, `bp` and `dp` NaN.
    let ap = a.max(N::zero()).sqrt();
    let (bp, dp) = if ap == N::zero() {
        (N::zero(), N::zero())
    } else {
        (b / ap, d / ap)
    };

    // Pivot 2: x, after removing its v component (Schur complement).
    let cp = (c - bp * bp).max(N::zero()).sqrt();
    let ep = if cp == N::zero() {
        N::zero()
    } else {
        (e - bp * dp) / cp
    };

    // Pivot 3: remaining variance of ax.
    let fp = (f - dp * dp - ep * ep).max(N::zero()).sqrt();

    // OMatrix assume column-major 2D array by default: each inner array is a column.
    // https://nalgebra.org/docs/user_guide/vectors_and_matrices
    Matrix3::from([[bp, ap, dp], [cp, N::zero(), ep], [N::zero(), N::zero(), fp]])
}
/// Square-root factor `L` of `cov_xvax(h, gamma, a)`, with `L * L.transpose() == cov`.
///
/// Evaluates the six covariance entries and factors them with
/// `_sqrt_cov_xvax_3`, a Cholesky-type factorization that pivots on the v
/// variance first.
#[inline(always)]
pub fn sqrt_cov_xvax_3<N: RealField>(h: N, gamma: N, a: N) -> OMatrix<N, Const<3>, Const<3>> {
    let (xx, xv, vv) = _cov_xv_3(h, gamma);
    let (x_ax, v_ax, ax_ax) = _cov_xvax_3(h, gamma, a);
    _sqrt_cov_xvax_3(xx, xv, vv, x_ax, v_ax, ax_ax)
}

#[cfg(test)]
mod tests {
    #![allow(unused)]
    use approx::assert_relative_eq;
    use nalgebra::*;
    use rand::prelude::*;

    use super::{_cov_xvax_1, _cov_xvax_2, _cov_xvax_3, cov_xvax, sqrt_cov_xvax_1, sqrt_cov_xvax_2, sqrt_cov_xvax_3};

    /// Scalar type used by every check in this module.
    type DType = f32;

    /// Values of `gamma` swept by the tests.
    const GAMMAS: [DType; 4] = [1e-9, 0.1, 2.0, 1e4];

    /// Values of `a` swept by the tests, covering [0, 1] including both ends.
    const AS: [DType; 11] = [0.0, 1e-9, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 1.0 - 1e-9, 1.0];

    /// Calls `check(h, gamma, a)` for every `h = 2^k` with `k` in `-70..20`,
    /// every `gamma` in `GAMMAS` and every `a` in `AS` (h outermost, a innermost).
    fn for_each_case(mut check: impl FnMut(DType, DType, DType)) {
        for k in -70..20 {
            let h = DType::powi(2.0, k);
            for &gamma in GAMMAS.iter() {
                for &a in AS.iter() {
                    check(h, gamma, a);
                }
            }
        }
    }

    /// The three scalar implementations must agree: naive vs. accurate only
    /// where the naive form is usable, accurate vs. Taylor/dispatch everywhere.
    #[test]
    fn test_cov_xvax() {
        for_each_case(|h, gamma, a| {
            let naive = _cov_xvax_1::<DType>(h, gamma, a);
            let accurate = _cov_xvax_2::<DType>(h, gamma, a);
            let taylor = _cov_xvax_3::<DType>(h, gamma, a);

            if gamma > 1e-2 && h > 0.1 {
                assert_relative_eq!(naive.0, accurate.0, epsilon = 1e-4, max_relative = 0.2);
                assert_relative_eq!(naive.1, accurate.1, epsilon = 1e-4, max_relative = 0.2);
                assert_relative_eq!(naive.2, accurate.2, epsilon = 1e-4, max_relative = 0.2);
            }

            let tiny = DType::MIN_POSITIVE * 10.0;
            assert_relative_eq!(accurate.0, taylor.0, epsilon = tiny, max_relative = 1e-3);
            assert_relative_eq!(accurate.1, taylor.1, epsilon = tiny, max_relative = 1e-3);
            assert_relative_eq!(accurate.2, taylor.2, epsilon = tiny, max_relative = 1e-3);
        });
    }

    /// Each square-root variant must reproduce the covariance.
    #[test]
    fn test_sqrt_cov_xvax() {
        for_each_case(|h, gamma, a| {
            let cov = cov_xvax::<DType>(h, gamma, a);
            let svd_root = sqrt_cov_xvax_1::<DType>(h, gamma, a);
            let direct_root = sqrt_cov_xvax_2::<DType>(h, gamma, a);
            let factor = sqrt_cov_xvax_3::<DType>(h, gamma, a);

            // Symmetric square root: R * R == cov (loose tolerance).
            assert_relative_eq!(cov, svd_root * svd_root, epsilon = 1e-1, max_relative = 1e0);
            // Could be NaN
            // assert_relative_eq!(cov, direct_root * direct_root, epsilon = DType::MIN_POSITIVE * 6e3, max_relative = 9e1);
            // Cholesky-type factor: L * L^T == cov.
            assert_relative_eq!(
                cov,
                factor * factor.transpose(),
                epsilon = DType::MIN_POSITIVE * 10.0,
                max_relative = 1e-3
            );
        });
    }
}
