use nalgebra::allocator::Allocator;
use nalgebra::*;
use rand::prelude::*;
use rand_distr::StandardNormal;

use super::basic::{_cov_xv_2, psi_3};
use super::rmm_basic::{_cov_xvax_3, _sqrt_cov_xvax_3};
use super::InfoPerStep;
use super::{ULDParam, ULDProcess, ULDProcessNew};

/// Underdamped Langevin integrator whose per-step noise is built from a three-component
/// covariance (`_cov_xvax_3`, depending on a random value `a` drawn each step), while the
/// state update evaluates the force callback only once, at the starting position.
///
/// NOTE(review): the name and the `like_rmm` test suggest the position/velocity noise is
/// meant to coincide with `RMM` for the same seed — confirm against `RMM`.
pub struct LPMLikeRMM<N>
where
    N: RealField,
{
    /// Langevin parameters; `gamma`, `xi` and `temperature` are read while stepping.
    uldp: ULDParam<N>,
    /// Stream of standard-normal samples that drive the Brownian noise.
    brownian_rng: StdRng,

    /// Step size after `ULDParam::scale_t` has been applied to the caller's step.
    h: N,

    /// `psi_3(h, gamma, 0)`: coefficient of `v` in the velocity update.
    psi0: N,
    /// `psi_3(h, gamma, 1)`: coefficient of `v` in the position update, and (divided by
    /// `xi`) of the force in the velocity update.
    psi1: N,
    /// `psi_3(h, gamma, 2)`: coefficient (divided by `xi`) of the force in the position update.
    psi2: N,

    /// Entries of the (x, v) covariance returned by `_cov_xv_2(h, gamma)`.
    cov_xx: N,
    cov_xv: N,
    cov_vv: N,

    /// Stream for the per-step value `a` that is passed to `_cov_xvax_3`.
    a_rng: StdRng,
}
impl<N> LPMLikeRMM<N>
where
    N: RealField,
{
    /// Creates an integrator with separate random streams for the Brownian noise
    /// (`brownian_rng`) and for the per-step value `a` (`a_rng`).
    ///
    /// `step` is first converted with `uldp.scale_t`. Everything that depends only on that
    /// scaled step and on `gamma` is computed once here instead of on every step: the three
    /// `psi_3` coefficients and the (x, v) covariance from `_cov_xv_2`.
    pub fn new2(step: N, uldp: ULDParam<N>, brownian_rng: StdRng, a_rng: StdRng) -> Self {
        let gamma = uldp.gamma;
        let h = uldp.scale_t(step);

        let (psi0, psi1, psi2) = (psi_3(h, gamma, 0), psi_3(h, gamma, 1), psi_3(h, gamma, 2));
        let (cov_xx, cov_xv, cov_vv) = _cov_xv_2(h, gamma);

        Self {
            uldp,
            brownian_rng,
            a_rng,
            h,
            psi0,
            psi1,
            psi2,
            cov_xx,
            cov_xv,
            cov_vv,
        }
    }
}

impl<N: RealField> ULDProcessNew<N> for LPMLikeRMM<N> {
    /// Builds the integrator from a single RNG by giving an identical copy of it to both
    /// the Brownian-noise stream and the `a` stream (see `LPMLikeRMM::new2`).
    ///
    /// NOTE(review): both streams start from the same state, so they produce the same raw
    /// random sequence. The value `a` and the first Gaussian samples are therefore derived
    /// from the same bits. Presumably this mirrors how `RMM` consumes its RNG, so the
    /// `like_rmm` test can compare the two. Confirm this before assuming `a` is independent
    /// of the noise.
    fn new(step: N, uldp: ULDParam<N>, rng: StdRng) -> Self {
        let brownian_rng = rng.clone();
        Self::new2(step, uldp, brownian_rng, rng)
    }
}
impl<N, D> ULDProcess<OVector<N, D>> for LPMLikeRMM<N>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, Const<3>>,
{
    fn one_step_generic<IB: InfoPerStep<OVector<N, D>>>(
        &mut self,
        go: &mut dyn FnMut(&OVector<N, D>) -> OVector<N, D>,
        x: &OVector<N, D>,
        v: &OVector<N, D>,
    ) -> (OVector<N, D>, OVector<N, D>, IB::Output) {
        let gamma = self.uldp.gamma;
        let xi = self.uldp.xi;
        let temperature = self.uldp.temperature;
        let h = self.h;

        let two = convert::<_, N>(2.0);

        let a = convert::<_, N>(self.a_rng.gen::<_>());
        let (cov_xax, cov_vax, cov_axax) = _cov_xvax_3(h, gamma, a);
        let sqrt_cov = _sqrt_cov_xvax_3(self.cov_xx, self.cov_xv, self.cov_vv, cov_xax, cov_vax, cov_axax).transpose();

        let normal = StandardNormal;
        let d = D::from_usize(x.nrows());
        let gauss = OMatrix::<N, D, Const<3>>::from_vec_generic(
            d,
            Const::<3>,
            normal
                .sample_iter(&mut self.brownian_rng)
                .take(3 * d.value())
                .map(|x| convert(x))
                .collect(),
        );

        let noise = gauss * (sqrt_cov * (two * gamma / xi * temperature).sqrt());

        let ex = noise.column(0);
        let ev = noise.column(1);

        let g = go(x);

        let nx = x + v * self.psi1 - &g * (self.psi2 / xi) + ex;
        let nv = v * self.psi0 - &g * (self.psi1 / xi) + ev;

        let mut ib = IB::default();
        if ib.require_gs() {
            ib.push_g(x.clone(), g);
        }
        if ib.require_noise() {
            ib.push_noise(ex.into_owned(), ev.into_owned());
        }

        (nx, nv, ib.build())
    }
}

#[cfg(test)]
mod tests {
    #![allow(unused)] // shared test imports/bindings; not every one is used by every test
    use approx::assert_relative_eq;
    use nalgebra::*;
    use rand::prelude::*;

    use crate::uld::{lpm_like_rmm::LPMLikeRMM, rmm::RMM, ULDParam, ULDProcess, ULDProcessNew};

    /// Builds `RMM` and `LPMLikeRMM` from clones of the same seeded RNG, then takes one step
    /// with each using a force callback that always returns zeros. Checks that the reported
    /// position noise (`ex`) and velocity noise (`ev`) agree.
    #[test]
    fn like_rmm() {
        type DType = f32;
        let uldp = ULDParam::<DType> {
            xi: 100.0,
            temperature: 10.0,
            gamma: 2.0,
        };
        let rng = StdRng::seed_from_u64(1234);
        let step: DType = 0.1;

        // 2^20 components, written as an integer shift. A bare `2.0.powi(20)` does not
        // compile, because the literal's float type is ambiguous (E0689).
        let d: usize = 1 << 20;

        let mut t1 = RMM::new(step, uldp.clone(), rng.clone());
        let mut t2 = LPMLikeRMM::new(step, uldp.clone(), rng.clone());

        let x = DVector::repeat(d, 1.0 as DType);
        let v = DVector::repeat(d, 2.0 as DType);
        let mut go: Box<dyn FnMut(&_) -> _> = Box::new(move |_| DVector::repeat(d, 0 as DType));

        let (_nx1, nv1, allinfo1) = t1.one_step_all(&mut go, &x, &v);
        let (_nx2, nv2, allinfo2) = t2.one_step_all(&mut go, &x, &v);
        assert_relative_eq!(allinfo1.ex, allinfo2.ex);
        assert_relative_eq!(allinfo1.ev, allinfo2.ev);
    }
}
