#![allow(dead_code)]

use nalgebra::*;

/// Evaluates `R(r) - 2`, where `R(r) = r*(exp(r)+1)/(exp(r)-1)`.
///
/// The value is approximated by a polynomial in `r^2`:
/// `r^2*(P1 + r^2*(P2 + r^2*(P3 + r^2*(P4 + r^2*P5))))`.
/// With `c = 2 - R + r` this gives `exp(r) = 1 + r + r*c/(2-c) = 1 + 2*r/(2-c)`.
///
/// The coefficients are only valid for `abs(r)` smaller than `ln 2`.
/// ref: FreeBSD /usr/src/lib/msun/src/e_exp.c
#[inline(always)]
#[allow(non_snake_case)]
pub fn _Rm2<N: RealField>(r: N) -> N {
    // Coefficients copied verbatim from FreeBSD's e_exp.c (names kept to match).
    let P1 = convert::<_, N>(1.66666666666666019037e-01);
    let P2 = convert::<_, N>(-2.77777777770155933842e-03);
    let P3 = convert::<_, N>(6.61375632143793436117e-05);
    let P4 = convert::<_, N>(-1.65339022054652515390e-06);
    let P5 = convert::<_, N>(4.13813679705723846039e-08);

    let r2 = r * r;
    // Horner scheme, innermost coefficient first.
    let mut acc = P5;
    acc = P4 + r2 * acc;
    acc = P3 + r2 * acc;
    acc = P2 + r2 * acc;
    acc = P1 + r2 * acc;
    r2 * acc
}

/// Naive closed-form evaluation of the `order`-th psi function.
///
/// With `delta = h * gamma`:
/// * order 0: `exp(-delta)`
/// * order 1: `(1 - exp(-delta)) / gamma`
/// * order 2: `(delta - 1 + exp(-delta)) / gamma^2`
///
/// Not accurate for small `h`. Orders 1 and 2 subtract nearly equal quantities
/// and then divide by powers of `gamma`.
///
/// # Panics
/// Panics (`unimplemented!`) for `order > 2`.
#[inline(always)]
pub fn psi_1<N: RealField>(h: N, gamma: N, order: usize) -> N {
    let one = N::one();
    let delta = h * gamma;
    let decay = (-delta).exp();
    match order {
        0 => decay,
        1 => (one - decay) / gamma,
        2 => (delta - one + decay) / gamma / gamma,
        _ => unimplemented!(),
    }
}
/// Accurate evaluation of the `order`-th psi function.
///
/// These are the same quantities as [`psi_1`], with `delta = h * gamma`:
/// * order 0: `exp(-delta)`
/// * order 1: `(1 - exp(-delta)) / gamma`
/// * order 2: `(delta - 1 + exp(-delta)) / gamma^2`
///
/// When `abs(delta) <= ln 2`, `exp(-delta)` is rewritten as `1 - 2*delta/(2 - c)`
/// using [`_Rm2`]. Substituting this into the closed forms removes the
/// subtraction of nearly equal terms. Outside that range the naive [`psi_1`] is used.
///
/// # Panics
/// Panics (`unimplemented!`) for `order > 2`.
#[inline(always)]
#[allow(non_snake_case)]
pub fn psi_2<N: RealField>(h: N, gamma: N, order: usize) -> N {
    let delta = h * gamma;
    let two = convert::<_, N>(2.0);
    // `_Rm2` is only valid for |r| < ln 2. Testing the magnitude (not just the
    // upper bound) keeps large negative `delta` off the polynomial path.
    if delta.abs() > N::ln(two) {
        psi_1(h, gamma, order)
    } else {
        // `_Rm2` depends only on r*r, so `c` is FreeBSD's c(r) at r = -delta and
        // exp(-delta) = 1 - 2*delta/(2 - c).
        let Rm2 = _Rm2(delta);
        let c = -delta - Rm2;
        match order {
            0 => N::one() - two * delta / (two - c),
            1 => two * h / (two - c),
            2 => -h * (c / gamma) / (two - c),
            _ => unimplemented!(),
        }
    }
}
/// Taylor-expanded evaluation of the `order`-th psi function.
///
/// These are the same quantities as [`psi_1`]. With `delta = h * gamma` and
/// `abs(delta) <= 0.1`, the truncated series below are evaluated in nested form:
/// * order 0: `1 - delta + delta^2/2 - delta^3/6 + delta^4/24`
/// * order 1: `h * (1 - delta/2 + delta^2/6 - delta^3/24)`
/// * order 2: `h^2/2 * (1 - delta/3 + delta^2/12)`
///
/// Larger `abs(delta)` falls back to [`psi_2`]. The branch is chosen on the
/// magnitude of `delta`, so a negative `h` or `gamma` of large size does not
/// reach the series.
///
/// The 0.1 cutoff targets about 0.1% relative accuracy. A smaller cutoff such as
/// `N::default_epsilon().powf(1/3)` would instead aim for accuracy down to the
/// smallest representable difference around 1.0.
///
/// # Panics
/// Panics (`unimplemented!`) for `order > 2`.
#[inline(always)]
pub fn psi_3<N: RealField>(h: N, gamma: N, order: usize) -> N {
    let delta = h * gamma;
    if delta.abs() > convert::<_, N>(1e-1) {
        return psi_2(h, gamma, order);
    }
    let two = convert::<_, N>(2.0);
    let three = convert::<_, N>(3.0);
    let four = convert::<_, N>(4.0);
    match order {
        0 => {
            let t = N::one() - delta / four;
            let t = N::one() - t * delta / three;
            let t = N::one() - t * delta / two;
            N::one() - t * delta
        }
        1 => {
            let t = N::one() - delta / four;
            let t = N::one() - t * delta / three;
            let t = N::one() - t * delta / two;
            h * t
        }
        2 => {
            let t = N::one() - delta / four;
            let t = N::one() - t * delta / three;
            h * h * t / two
        }
        _ => unimplemented!(),
    }
}
/// Naive closed-form `(xx, xv, vv)` covariance entries.
///
/// With `delta = h * gamma`:
/// * xx: `(1 + 2*delta - (2 - exp(-delta))^2) / (2 * gamma^3)`
/// * xv: `(1 - exp(-delta))^2 / (2 * gamma^2)`
/// * vv: `(1 - exp(-2*delta)) / (2 * gamma)`
///
/// Not accurate for small `h`. The numerators cancel and are then divided by
/// powers of `gamma`.
#[inline(always)]
pub fn _cov_xv_1<N: RealField>(h: N, gamma: N) -> (N, N, N) {
    let two = convert::<_, N>(2.0);
    let delta = h * gamma;
    let decay = (-delta).exp();
    let decay_twice = (convert::<_, N>(-2.0) * delta).exp();

    (
        (N::one() + two * delta - (two - decay).powi(2)) / (two * gamma.powi(3)),
        (N::one() - decay).powi(2) / (two * gamma.powi(2)),
        (N::one() - decay_twice) / (two * gamma),
    )
}
/// More accurate `(xx, xv, vv)` covariance entries.
///
/// These are the same quantities as [`_cov_xv_1`]. When `abs(delta) <= ln 2`
/// (with `delta = h * gamma`), `exp(-delta)` is written as `1 - 2*delta/t`, where
/// `t = 2 - c`, `c = -delta - Rm2` and `Rm2 = _Rm2(delta)`. Substituting this into
/// the closed forms gives:
/// * xx: `h * ((c/gamma)^2 + 2*Rm2/gamma^2) / t^2`
/// * xv: `2 * h^2 / t^2`
/// * vv: `2 * h * (Rm2 + 2) / t^2`
///
/// Outside that range the naive [`_cov_xv_1`] is used.
#[inline(always)]
#[allow(non_snake_case)]
pub fn _cov_xv_2<N: RealField>(h: N, gamma: N) -> (N, N, N) {
    let delta = h * gamma;
    let two = convert::<_, N>(2.0);
    // `_Rm2` is only valid for |r| < ln 2. Testing the magnitude also routes
    // large negative `delta` (negative `h` or `gamma`) to the naive formulas.
    if delta.abs() > N::ln(two) {
        _cov_xv_1(h, gamma)
    } else {
        let Rm2 = _Rm2(delta);
        let c = -delta - Rm2;

        let t = two - c;
        let tt = t * t;
        let cog = c / gamma;
        let cov_xx = h * (cog * cog + two * Rm2 / gamma / gamma) / tt;
        let cov_xv = two * h * h / tt;
        let cov_vv = two * h * (Rm2 + two) / tt;
        (cov_xx, cov_xv, cov_vv)
    }
}
/// Taylor-expanded `(xx, xv, vv)` covariance entries.
///
/// These are the same quantities as [`_cov_xv_1`]. With `delta = h * gamma` and
/// `abs(delta) < 0.1`, the truncated series below are evaluated in nested form:
/// * xx: `h^3 * (1/3 - delta/4 + 7*delta^2/60)`
/// * xv: `h^2/2 * (1 - delta + 7*delta^2/12)`
/// * vv: `h * (1 - delta + 2*delta^2/3)`
///
/// Otherwise [`_cov_xv_2`] is used. The branch is chosen on the magnitude of
/// `delta`, so a large negative `delta` never reaches the series.
#[inline(always)]
pub fn _cov_xv_3<N: RealField>(h: N, gamma: N) -> (N, N, N) {
    let delta = h * gamma;
    if delta.abs() < convert::<_, N>(1e-1) {
        let cov_xx = {
            let t = N::one() - delta * convert::<_, N>(7.0 / 15.0);
            let t = convert::<_, N>(1.0 / 3.0) - t * delta / convert::<_, N>(4.0);
            h * h * h * t
        };
        let cov_xv = {
            let t = N::one() - delta * convert::<_, N>(7.0 / 12.0);
            let t = N::one() - t * delta;
            h * h * t / convert::<_, N>(2.0)
        };
        let cov_vv = {
            let t = N::one() - delta * convert::<_, N>(2.0 / 3.0);
            let t = N::one() - t * delta;
            h * t
        };
        (cov_xx, cov_xv, cov_vv)
    } else {
        _cov_xv_2(h, gamma)
    }
}
/// Assembles the symmetric 2x2 covariance matrix `[[xx, xv], [xv, vv]]` from the
/// entries returned by [`_cov_xv_3`].
#[inline(always)]
pub fn cov_xv<N: RealField>(h: N, gamma: N) -> OMatrix<N, Const<2>, Const<2>> {
    let (xx, xv, vv) = _cov_xv_3(h, gamma);
    // `Matrix2::new` takes entries in row-major order. The matrix is symmetric,
    // so this equals the column-major `Matrix2::from` form.
    Matrix2::new(xx, xv, xv, vv)
}
/// Square root of [`cov_xv`] via its singular value decomposition.
///
/// The matrix is decomposed with both `U` and `V^T` computed. The singular values
/// are replaced by their square roots, and `U * sqrt(S) * V^T` is recomposed.
///
/// # Panics
/// Unwraps the result of `recompose`. Both factors are requested, so this
/// presumably cannot fail; verify against nalgebra's `SVD::recompose`.
#[inline(always)]
pub fn sqrt_cov_xv_1<N: RealField>(h: N, gamma: N) -> OMatrix<N, Const<2>, Const<2>> {
    let mut decomposition = cov_xv(h, gamma).svd(true, true);
    let roots = decomposition.singular_values.map(|sigma| sigma.sqrt());
    decomposition.singular_values = roots;
    decomposition.recompose().unwrap()
}
/// Closed-form square root of the 2x2 covariance `[[a, b], [b, c]]`.
///
/// With `s = sqrt(max(a*c - b^2, 0))` and `t = sqrt(a + c + 2*s)`, the result is
/// `[[a + s, b], [b, c + s]] / t`. When `t` is zero, the zero matrix is returned.
#[inline(always)]
pub fn sqrt_cov_xv_2<N: RealField>(h: N, gamma: N) -> OMatrix<N, Const<2>, Const<2>> {
    let (a, b, c) = _cov_xv_3(h, gamma);
    let two = convert::<_, N>(2.0);

    // Determinant clamped at 0 to absorb rounding.
    let root_det = (a * c - b * b).max(N::zero()).sqrt();
    let norm = (a + c + two * root_det).sqrt();
    if norm == N::zero() {
        return Matrix2::zeros();
    }
    Matrix2::from([[a + root_det, b], [b, c + root_det]]) / norm
}
/// Cholesky-style factor `L` of the 2x2 covariance, with `L * L^T == cov`.
///
/// The v variance `c` is decomposed first. Read row by row, the returned matrix is
/// `[[b/sqrt(c), sqrt(max(a - b^2/c, 0))], [sqrt(c), 0]]`. If `c == 0`, the
/// `b/sqrt(c)` entry is 0.
#[inline(always)]
pub fn sqrt_cov_xv_3<N: RealField>(h: N, gamma: N) -> OMatrix<N, Const<2>, Const<2>> {
    let (a, b, c) = _cov_xv_3(h, gamma);

    // Factor the v variance first; sqrt(c) becomes the bottom-left entry.
    let sqrt_c = c.sqrt();
    // With zero v variance there is no pivot to divide by, so use 0 for the coupling.
    let coupling = if c == N::zero() { N::zero() } else { b / sqrt_c };
    // Remaining x variance, clamped at 0 to absorb rounding.
    let residual = (a - coupling * coupling).max(N::zero()).sqrt();

    // `Matrix2::from` takes an array of columns (column-major by default).
    // https://nalgebra.org/docs/user_guide/vectors_and_matrices
    Matrix2::from([[coupling, sqrt_c], [residual, N::zero()]])
}

// Consistency tests: the naive, accurate and Taylor variants are compared with
// each other in f32 over h = 2^-70 ..= 2^19 and a spread of gamma values.
#[cfg(test)]
mod tests {
    #![allow(unused)]
    use approx::assert_relative_eq;
    use nalgebra::*;
    use rand::prelude::*;

    use super::{_cov_xv_1, _cov_xv_2, _cov_xv_3, cov_xv, sqrt_cov_xv_1, sqrt_cov_xv_2, sqrt_cov_xv_3};
    use super::{psi_1, psi_2, psi_3};

    /// Compares psi_1 / psi_2 / psi_3 for orders 0..=2.
    #[test]
    fn test_psi() {
        type DType = f32;
        for h in (-70..20).map(|x| (2.0).powi(x)) {
            for gamma in &[1e-9, 0.1, 2.0, 1e4] {
                for order in (0..=2) {
                    // The naive implementation is inaccurate for small gamma or h,
                    // so it is only compared (with a loose 20% tolerance) where it
                    // is usable.
                    if (*gamma > 1e-2 && h > 0.005) {
                        assert_relative_eq!(
                            psi_1::<DType>(h, *gamma, order),
                            psi_2(h, *gamma, order),
                            epsilon = 0.0,
                            max_relative = 0.2
                        );
                    }
                    // 0.1% accuracy is enough.
                    assert_relative_eq!(
                        psi_2::<DType>(h, *gamma, order),
                        psi_3(h, *gamma, order),
                        epsilon = 0.0,
                        max_relative = 0.001
                    );
                }
            }
        }
    }
    /// Compares the three (xx, xv, vv) covariance implementations.
    #[test]
    fn test_cov_xv() {
        type DType = f32;
        for h in (-70..20).map(|x| (2.0).powi(x)) {
            for gamma in [1e-9, 0.1, 2.0, 1e4].iter().cloned() {
                let r1 = _cov_xv_1::<DType>(h, gamma);
                let r2 = _cov_xv_2::<DType>(h, gamma);
                let r3 = _cov_xv_3::<DType>(h, gamma);
                // Naive formulas are only checked (loosely) away from small gamma / h.
                if (gamma > 1e-2 && h > 0.1) {
                    assert_relative_eq!(r1.0, r2.0, epsilon = 0.0, max_relative = 0.2);
                    assert_relative_eq!(r1.1, r2.1, epsilon = 0.0, max_relative = 0.2);
                    assert_relative_eq!(r1.2, r2.2, epsilon = 0.0, max_relative = 0.2);
                }
                // cov_xx scales like h^3, which underflows f32 for the smallest h
                // (e.g. 2^-210), hence the small absolute slack on the first entry.
                assert_relative_eq!(r2.0, r3.0, epsilon = DType::MIN_POSITIVE * 10.0, max_relative = 1e-3);
                assert_relative_eq!(r2.1, r3.1, epsilon = 0.0, max_relative = 1e-3);
                assert_relative_eq!(r2.2, r3.2, epsilon = 0.0, max_relative = 1e-3);
            }
        }
    }
    /// Checks that each square-root variant reproduces the covariance matrix.
    #[test]
    fn test_sqrt_cov_xv() {
        type DType = f32;
        for h in (-70..20).map(|x| (2.0).powi(x)) {
            for gamma in [1e-9, 0.1, 2.0, 1e4].iter().cloned() {
                let cov = cov_xv::<DType>(h, gamma);
                let r1 = sqrt_cov_xv_1::<DType>(h, gamma);
                let r2 = sqrt_cov_xv_2::<DType>(h, gamma);
                let r3 = sqrt_cov_xv_3::<DType>(h, gamma);

                // r1 and r2 are square roots (S * S == cov); r3 is a factor
                // (L * L^T == cov).
                // NOTE(review): the r1 / r2 tolerances are very loose
                // (max_relative 1.0 and 0.9); only r3 is checked tightly.
                assert_relative_eq!(cov, r1 * r1, epsilon = 1e-1, max_relative = 1.0);
                assert_relative_eq!(cov, r2 * r2, epsilon = 0.0, max_relative = 9e-1);
                assert_relative_eq!(cov, r3 * r3.transpose(), epsilon = 0.0, max_relative = 1e-6);
            }
        }
    }
}
