use nalgebra::allocator::Allocator;
use nalgebra::*;
use rand::prelude::*;
use rand_distr::StandardNormal;

use super::basic::{_cov_xv_2, psi_3};
use super::rmm_basic::{_cov_xvax_3, _sqrt_cov_xvax_3};
use super::InfoPerStep;
use super::{ULDParam, ULDProcess, ULDProcessNew};

/// Step integrator `RMM` (presumably the randomized midpoint method — the step
/// evaluates the gradient at a point a random fraction `a` into the step) for the
/// process parameterized by `uldp`.
///
/// Quantities that depend only on the step size are computed once in `new2` and
/// cached here. The quantities that depend on the per-step random fraction `a` are
/// recomputed inside every step.
pub struct RMM<N> {
    /// Process parameters; each step reads `gamma`, `xi` and `temperature`.
    pub uldp: ULDParam<N>,
    /// Generator for the standard-normal samples that make up each step's noise.
    pub brownian_rng: StdRng,

    /// Internal step size, i.e. `uldp.scale_t(step)` for the user-facing `step`.
    pub h: N,

    /// Cached `psi_3(h, gamma, 0)`.
    psi0: N,
    /// Cached `psi_3(h, gamma, 1)`.
    psi1: N,

    /// (x, x) entry of the unscaled noise covariance from `_cov_xv_2(h, gamma)`.
    pub cov_xx: N,
    /// (x, v) entry of the unscaled noise covariance from `_cov_xv_2(h, gamma)`.
    pub cov_xv: N,
    /// (v, v) entry of the unscaled noise covariance from `_cov_xv_2(h, gamma)`.
    pub cov_vv: N,

    /// Generator for the random fraction `a` drawn once per step.
    pub a_rng: StdRng,
}
impl<N: RealField> RMM<N> {
    /// Builds an `RMM` integrator from an explicit pair of random generators.
    ///
    /// * `step` – user-facing step size; converted to the internal step `h`
    ///   with `uldp.scale_t`.
    /// * `uldp` – process parameters; `gamma` enters the cached coefficients.
    /// * `brownian_rng` – generator for the Gaussian noise of every step.
    /// * `a_rng` – generator for the random fraction `a` of every step.
    ///
    /// The step-independent coefficients (`psi_3` of orders 0 and 1, and the
    /// position/velocity covariance from `_cov_xv_2`) are computed here once.
    pub fn new2(step: N, uldp: ULDParam<N>, brownian_rng: StdRng, a_rng: StdRng) -> Self {
        let scaled_step = uldp.scale_t(step);
        let friction = uldp.gamma;
        let (cov_xx, cov_xv, cov_vv) = _cov_xv_2(scaled_step, friction);

        Self {
            psi0: psi_3(scaled_step, friction, 0),
            psi1: psi_3(scaled_step, friction, 1),
            h: scaled_step,
            cov_xx,
            cov_xv,
            cov_vv,
            uldp,
            brownian_rng,
            a_rng,
        }
    }
}

impl<N> ULDProcessNew<N> for RMM<N>
where
    N: RealField,
{
    fn new(step: N, uldp: ULDParam<N>, rng: StdRng) -> Self {
        Self::new2(step, uldp, rng.clone(), rng)
    }
}
/// One integration step for vector-valued states `OVector<N, D>`.
///
/// The extra `Allocator<N, D, Const<3>>` bound is needed for the `d × 3` matrix
/// that holds the three correlated noise components: position, velocity, and
/// position at the random intermediate time.
impl<N, D> ULDProcess<OVector<N, D>> for RMM<N>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, Const<3>>,
{
    /// Advances `(x, v)` by one step of internal size `self.h` and returns
    /// `(next_x, next_v, info)`.
    ///
    /// * `go` – called exactly twice per step, first at `x` and then at the
    ///   intermediate point `ax`. Presumably the potential-gradient (force)
    ///   oracle — TODO confirm against callers.
    /// * `x`, `v` – current position and velocity.
    ///
    /// `info` comes from a fresh `IB` collector. If it requires gradients, it
    /// receives `(x, go(x))` and `(ax, go(ax))`. If it requires noise, it
    /// receives the position and velocity noise `ex`, `ev`.
    fn one_step_generic<IB: InfoPerStep<OVector<N, D>>>(
        &mut self,
        go: &mut dyn FnMut(&OVector<N, D>) -> OVector<N, D>,
        x: &OVector<N, D>,
        v: &OVector<N, D>,
    ) -> (OVector<N, D>, OVector<N, D>, IB::Output) {
        let gamma = self.uldp.gamma;
        let xi = self.uldp.xi;
        let temperature = self.uldp.temperature;
        let h = self.h;

        let two = convert::<_, N>(2.0);

        // Random fraction of the step at which the intermediate gradient is
        // evaluated. It comes from the dedicated `a_rng` stream; rand's `Standard`
        // distribution gives floats in [0, 1).
        let a = convert::<_, N>(self.a_rng.gen::<_>());
        // The covariances involving the intermediate noise depend on `a`, so the
        // 3×3 square-root factor is rebuilt every step. It is transposed because
        // the samples below are stored as rows of a `d × 3` matrix.
        let (cov_xax, cov_vax, cov_axax) = _cov_xvax_3(h, gamma, a);
        let sqrt_cov = _sqrt_cov_xvax_3(self.cov_xx, self.cov_xv, self.cov_vv, cov_xax, cov_vax, cov_axax).transpose();

        // `d × 3` matrix of i.i.d. standard normals from `brownian_rng`
        // (nalgebra fills it column by column).
        let normal = StandardNormal;
        let d = D::from_usize(x.nrows());
        let gauss = OMatrix::<N, D, Const<3>>::from_vec_generic(
            d,
            Const::<3>,
            normal
                .sample_iter(&mut self.brownian_rng)
                .take(3 * d.value())
                .map(|x| convert(x))
                .collect(),
        );

        // Assuming `_sqrt_cov_xvax_3` returns S with S·Sᵀ equal to the unscaled
        // covariance, every row of `noise` has that covariance multiplied by
        // 2 * gamma / xi * temperature.
        let noise = gauss * (sqrt_cov * (two * gamma / xi * temperature).sqrt());

        // Noise added to the new position, the new velocity and the intermediate
        // position `ax`, respectively.
        let ex = noise.column(0);
        let ev = noise.column(1);
        let eax = noise.column(2);

        // psi_3 coefficients for the full step (cached in `new2`), for the
        // sub-step `a * h`, and for the remainder `(1 - a) * h`.
        let psi0 = self.psi0;
        let psi1 = self.psi1;
        let apsi1 = psi_3(a * h, gamma, 1);
        let apsi2 = psi_3(a * h, gamma, 2);
        let dapsi0 = psi_3((N::one() - a) * h, gamma, 0);
        let dapsi1 = psi_3((N::one() - a) * h, gamma, 1);

        let g = go(x);
        // Intermediate position at time `a * h`, driven by the gradient at `x`.
        let ax = x + v * apsi1 - &g * (apsi2 / xi) + eax;
        let ag = go(&ax);
        // The full-step update uses only the gradient at `ax`, weighted by `h`.
        let nx = x + v * psi1 - &ag * (h * dapsi1 / xi) + ex;
        let nv = v * psi0 - &ag * (h * dapsi0 / xi) + ev;

        let mut ib = IB::default();
        if ib.require_gs() {
            ib.push_g(x.clone(), g);
            ib.push_g(ax, ag);
        }
        if ib.require_noise() {
            ib.push_noise(ex.into_owned(), ev.into_owned());
        }

        (nx, nv, ib.build())
    }
}

#[cfg(test)]
mod tests {
    #![allow(unused)]
    use approx::assert_relative_eq;
    use nalgebra::*;
    use rand::prelude::*;

    use super::RMM;
    use crate::uld::basic::cov_xv;
    use crate::uld::AllInfo;
    use crate::uld::{ULDParam, ULDProcess, ULDProcessNew};

    /// Checks that the empirical covariance of the position/velocity noise from
    /// one step matches `cov_xv(h, gamma)` scaled by `2 * gamma / xi * temperature`.
    ///
    /// Each of the `d` coordinates contributes one noise sample. A large `d`
    /// keeps the sampling error small relative to the 0.5% relative tolerance.
    #[test]
    fn rmm() {
        type DType = f32;
        // 2^20 coordinates. An integer shift is used because a method call on an
        // untyped float literal (`2.0.powi(..)`) is rejected by the compiler (E0689).
        let d: usize = 1 << 20;
        let uldp = ULDParam::<DType> {
            xi: 100.0,
            temperature: 10.0,
            gamma: 2.0,
        };
        let rng = StdRng::seed_from_u64(1234);
        let step: DType = 0.1;

        let mut t = RMM::new(step, uldp.clone(), rng.clone());

        let x = DVector::repeat(d, 1.0 as DType);
        let v = DVector::repeat(d, 2.0 as DType);
        let mut go: Box<dyn FnMut(&_) -> _> = Box::new(move |_| DVector::repeat(d, 1 as DType));

        let (_, _, allinfo) = t.one_step_all(&mut go, &x, &v);
        let AllInfo { ex, ev, .. } = allinfo;
        let em_cov_xx = ex.dot(&ex) / d as DType;
        let em_cov_xv = ex.dot(&ev) / d as DType;
        let em_cov_vv = ev.dot(&ev) / d as DType;
        let em_cov = Matrix2::from([[em_cov_xx, em_cov_xv], [em_cov_xv, em_cov_vv]]);

        let cov = cov_xv::<DType>(uldp.scale_t(step), uldp.gamma) * 2.0 * uldp.gamma / uldp.xi * uldp.temperature;
        assert_relative_eq!(em_cov, cov, epsilon = 0.0, max_relative = 5e-3);
    }
}
