use nalgebra::allocator::Allocator;
use nalgebra::*;
use rand::prelude::*;
use rand_distr::StandardNormal;

use super::basic::psi_3;
use super::rmm::RMM;
use super::rmm_basic::{_cov_xvax_3, _sqrt_cov_xvax_3};
use super::InfoPerStep;
use super::{ULDParam, ULDProcess};

use std::ops::Add;
/// Sums the items of a non-empty iterator with `+`, left to right.
///
/// Unlike `Iterator::sum`, this only needs `T: Add<Output = T>` (no `Sum`/zero element),
/// which is why it works for owned nalgebra vectors whose dimension is only known at runtime.
///
/// # Panics
/// If `it` yields no items, since there is no zero element to return.
pub fn custom_sum<T: Add<T, Output = T>>(it: impl Iterator<Item = T>) -> T {
    it.reduce(|acc, item| acc + item)
        .expect("custom_sum: iterator must yield at least one item")
}

/// Draws the noise of `segments` consecutive sub-steps of length `rmm.h` and combines it
/// into the noise of one step of length `segments * rmm.h`.
///
/// For every segment a fraction `a` is drawn from `rmm.a_rng`, and a `rows x 3` standard
/// normal matrix is drawn from `rmm.brownian_rng`. The matrix is multiplied by the
/// (transposed) square-root covariance returned by `_sqrt_cov_xvax_3` and by
/// `sqrt(2 * gamma / xi * temperature)`. Its three columns are then used as:
/// * column 0: the position noise at the end of the segment;
/// * column 1: the velocity noise at the end of the segment;
/// * column 2: the position noise at fraction `a` of the segment.
///
/// Random streams are consumed segment by segment (`a_rng` before `brownian_rng`), and
/// `seg_rng` is used once at the end.
/// NOTE(review): this order presumably mirrors how `RMM` consumes its streams per step
/// (see the `rmm_ms_same_rmm` test). Do not reorder without checking `RMM`.
///
/// Returns `(all_a, all_ex, all_ev, all_eax)`:
/// * `all_a`: the midpoint of the combined step in units of `h`. It is a segment index
///   `seg` drawn uniformly from `0..segments` plus that segment's fraction `a`.
/// * `all_ex`, `all_ev`: position and velocity noise at the end of the combined step. The
///   velocity noise of segment `i` is carried over the remaining `segments - i - 1`
///   segments with `psi_3(.., 1)` (into position) and `psi_3(.., 0)` (into velocity).
/// * `all_eax`: the position noise at time `all_a * h`.
///
/// # Panics
/// If `segments == 0`. It may also panic inside `D::from_usize` if `rows` is incompatible
/// with a static dimension `D`.
fn multiple_segments<N, D>(
    segments: usize,
    seg_rng: &mut StdRng,
    rmm: &mut RMM<N>,
    rows: usize,
) -> (N, OVector<N, D>, OVector<N, D>, OVector<N, D>)
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, Const<3>>,
    RMM<N>: ULDProcess<OVector<N, D>>,
{
    assert!(segments > 0, "multiple_segments: `segments` must be at least 1");

    let gamma = rmm.uldp.gamma;
    let xi = rmm.uldp.xi;
    let temperature = rmm.uldp.temperature;
    let h = rmm.h;

    // Loop-invariant quantities, hoisted out of the per-segment closure (same values as
    // recomputing them in every iteration).
    let two = convert::<_, N>(2.0);
    let noise_scale = (two * gamma / xi * temperature).sqrt();
    let d = D::from_usize(rows);

    let all_info: Vec<_> = (0..segments)
        .map(|_| {
            let a = convert::<_, N>(rmm.a_rng.gen::<_>());

            let (cov_xax, cov_vax, cov_axax) = _cov_xvax_3(h, gamma, a);
            let sqrt_cov = _sqrt_cov_xvax_3(rmm.cov_xx, rmm.cov_xv, rmm.cov_vv, cov_xax, cov_vax, cov_axax).transpose();

            let normal = StandardNormal;
            let gauss = OMatrix::<N, D, Const<3>>::from_vec_generic(
                d,
                Const::<3>,
                normal
                    .sample_iter(&mut rmm.brownian_rng)
                    .take(3 * d.value())
                    .map(|x| convert(x))
                    .collect(),
            );

            let noise = gauss * (sqrt_cov * noise_scale);

            (a, noise)
        })
        .collect();

    // Noise at the end of the combined step: each segment's velocity noise is propagated
    // over the segments that follow it.
    let all_ex: OVector<N, D> = custom_sum(all_info.iter().enumerate().map(|(i, (_, noise))| {
        noise.column(0) + noise.column(1) * psi_3(convert::<_, N>((segments - i - 1) as f64) * h, gamma, 1)
    }));
    let all_ev: OVector<N, D> = custom_sum(
        all_info
            .iter()
            .enumerate()
            .map(|(i, (_, noise))| noise.column(1) * psi_3(convert::<_, N>((segments - i - 1) as f64) * h, gamma, 0)),
    );

    // Random midpoint: pick a segment, then use that segment's own fraction `a`.
    let seg: usize = seg_rng.gen_range(0..segments);
    let all_a = convert::<_, N>(seg as f64) + all_info[seg].0;

    // Position noise at the midpoint: full segments before `seg`, propagated up to `all_a`,
    // plus the partial-segment position noise of segment `seg` itself.
    let all_eax: OVector<N, D> = if seg > 0 {
        custom_sum(all_info.iter().take(seg).enumerate().map(|(i, (_, noise))| {
            noise.column(0) + noise.column(1) * psi_3((all_a - convert::<_, N>((i + 1) as f64)) * h, gamma, 1)
        })) + all_info[seg].1.column(2)
    } else {
        all_info[seg].1.column(2).into_owned()
    };
    (all_a, all_ex, all_ev, all_eax)
}
/// Single-seed constructor for the multi-segment processes in this module.
pub trait ULDProcessNewMS<N> {
    /// Builds a process that combines `segments` sub-steps of length `step` into one step.
    ///
    /// `uldp` holds the process parameters. The implementations in this module clone `rng`
    /// into every random stream they need.
    fn new(segments: usize, step: N, uldp: ULDParam<N>, rng: StdRng) -> Self;
}

/// Similar to [`RMM`], except that each step spans `segments` sub-steps and the noise terms
/// are generated segment by segment.
///
/// Because the noise comes from the wrapped `RMM`'s random streams one segment at a time,
/// a trajectory can be kept consistent with another `RMM` instance that takes `segments`
/// times as many shorter steps.
pub struct RMMMS<N> {
    /// Number of sub-steps combined into one step.
    segments: usize,
    /// Selects the segment that contains the random midpoint.
    seg_rng: StdRng,
    /// Per-segment process; supplies `h`, the covariances and the `a` / Brownian streams.
    rmm: RMM<N>,

    /// Process parameters (`gamma`, `xi`, `temperature`).
    uldp: ULDParam<N>,

    /// Length of one segment after `ULDParam::scale_t`.
    h: N,
    /// Length of the whole step, `segments * h`.
    all_h: N,

    /// Cached `psi_3(all_h, gamma, 0)`.
    psi0: N,
    /// Cached `psi_3(all_h, gamma, 1)`.
    psi1: N,
}
impl<N: RealField> RMMMS<N> {
    /// Creates the process with an explicit generator for each random stream.
    ///
    /// `step` is the unscaled length of one segment. The segment-selection stream is
    /// `seg_rng`; `brownian_rng` and `a_rng` are handed to the inner `RMM`.
    pub fn new2(
        segments: usize,
        step: N,
        uldp: ULDParam<N>,
        seg_rng: StdRng,
        brownian_rng: StdRng,
        a_rng: StdRng,
    ) -> Self {
        let segment_h = uldp.scale_t(step);
        let step_h = convert::<_, N>(segments as f64) * segment_h;
        let friction = uldp.gamma;

        let inner = RMM::new2(step, uldp.clone(), brownian_rng, a_rng);

        Self {
            psi0: psi_3(step_h, friction, 0),
            psi1: psi_3(step_h, friction, 1),
            h: segment_h,
            all_h: step_h,
            rmm: inner,
            segments,
            seg_rng,
            uldp,
        }
    }
}
impl<N: RealField> ULDProcessNewMS<N> for RMMMS<N> {
    /// Builds the process with the segment, Brownian and `a` streams all cloned from `rng`,
    /// so all three start from the same generator state.
    fn new(segments: usize, step: N, uldp: ULDParam<N>, rng: StdRng) -> Self {
        let seg_rng = rng.clone();
        let brownian_rng = rng.clone();
        Self::new2(segments, step, uldp, seg_rng, brownian_rng, rng)
    }
}

impl<N, D> ULDProcess<OVector<N, D>> for RMMMS<N>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, Const<3>>,
{
    /// Advances `(x, v)` by one step of `segments` sub-steps using a random midpoint.
    ///
    /// The gradient is evaluated at `x` (to estimate the midpoint position) and at that
    /// midpoint estimate, whose gradient alone drives the full-step update. Returns
    /// `(nx, nv, info)`.
    fn one_step_generic<IB: InfoPerStep<OVector<N, D>>>(
        &mut self,
        go: &mut dyn FnMut(&OVector<N, D>) -> OVector<N, D>,
        x: &OVector<N, D>,
        v: &OVector<N, D>,
    ) -> (OVector<N, D>, OVector<N, D>, IB::Output) {
        let (gamma, xi) = (self.uldp.gamma, self.uldp.xi);
        let (h, all_h) = (self.h, self.all_h);
        let (psi0, psi1) = (self.psi0, self.psi1);
        let segment_count = convert::<_, N>(self.segments as f64);

        let (a, ex, ev, eax) =
            multiple_segments::<N, D>(self.segments, &mut self.seg_rng, &mut self.rmm, x.nrows());

        // Elapsed time up to the midpoint, and time left from the midpoint to the step's end.
        let to_mid = a * h;
        let from_mid = (segment_count - a) * h;
        let mid_psi1 = psi_3(to_mid, gamma, 1);
        let mid_psi2 = psi_3(to_mid, gamma, 2);
        let rest_psi0 = psi_3(from_mid, gamma, 0);
        let rest_psi1 = psi_3(from_mid, gamma, 1);

        let g_start = go(x);
        let x_mid = x + v * mid_psi1 - &g_start * (mid_psi2 / xi) + eax;
        let g_mid = go(&x_mid);
        let x_next = x + v * psi1 - &g_mid * (all_h * rest_psi1 / xi) + &ex;
        let v_next = v * psi0 - &g_mid * (all_h * rest_psi0 / xi) + &ev;

        let mut info = IB::default();
        if info.require_gs() {
            info.push_g(x.clone(), g_start);
            info.push_g(x_mid, g_mid);
        }
        if info.require_noise() {
            info.push_noise(ex, ev);
        }
        (x_next, v_next, info.build())
    }
}

/// Similar to OM, except that each step spans `segments` sub-steps and the noise terms are
/// generated segment by segment.
///
/// The midpoint position is extrapolated from `x` and `v` without a gradient term. The
/// noise is drawn from the wrapped `RMM`'s streams one segment at a time, so it stays
/// consistent with another `RMM` instance taking the sub-steps individually.
pub struct OMMS<N> {
    /// Number of sub-steps combined into one step.
    segments: usize,
    /// Selects the segment that contains the random midpoint.
    seg_rng: StdRng,
    /// Per-segment process; supplies `h`, the covariances and the `a` / Brownian streams.
    rmm: RMM<N>,

    /// Process parameters (`gamma`, `xi`, `temperature`).
    uldp: ULDParam<N>,

    /// Length of one segment after `ULDParam::scale_t`.
    h: N,
    /// Length of the whole step, `segments * h`.
    all_h: N,

    /// Cached `psi_3(all_h, gamma, 0)`.
    psi0: N,
    /// Cached `psi_3(all_h, gamma, 1)`.
    psi1: N,
}
impl<N: RealField> OMMS<N> {
    /// Creates the process with an explicit generator for each random stream.
    ///
    /// `step` is the unscaled length of one segment. The segment-selection stream is
    /// `seg_rng`; `brownian_rng` and `a_rng` are handed to the inner `RMM`.
    pub fn new2(
        segments: usize,
        step: N,
        uldp: ULDParam<N>,
        seg_rng: StdRng,
        brownian_rng: StdRng,
        a_rng: StdRng,
    ) -> Self {
        let segment_h = uldp.scale_t(step);
        let step_h = convert::<_, N>(segments as f64) * segment_h;
        let friction = uldp.gamma;

        let inner = RMM::new2(step, uldp.clone(), brownian_rng, a_rng);

        Self {
            psi0: psi_3(step_h, friction, 0),
            psi1: psi_3(step_h, friction, 1),
            h: segment_h,
            all_h: step_h,
            rmm: inner,
            segments,
            seg_rng,
            uldp,
        }
    }
}
impl<N: RealField> ULDProcessNewMS<N> for OMMS<N> {
    /// Builds the process with the segment, Brownian and `a` streams all cloned from `rng`,
    /// so all three start from the same generator state.
    fn new(segments: usize, step: N, uldp: ULDParam<N>, rng: StdRng) -> Self {
        let seg_rng = rng.clone();
        let brownian_rng = rng.clone();
        Self::new2(segments, step, uldp, seg_rng, brownian_rng, rng)
    }
}

impl<N, D> ULDProcess<OVector<N, D>> for OMMS<N>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, Const<3>>,
{
    /// Advances `(x, v)` by one step of `segments` sub-steps.
    ///
    /// The midpoint position uses only `x`, `v` and the midpoint noise (no gradient term).
    /// The gradient at that midpoint drives the full-step update. `go(x)` is evaluated
    /// only when the info collector requests gradients. Returns `(nx, nv, info)`.
    fn one_step_generic<IB: InfoPerStep<OVector<N, D>>>(
        &mut self,
        go: &mut dyn FnMut(&OVector<N, D>) -> OVector<N, D>,
        x: &OVector<N, D>,
        v: &OVector<N, D>,
    ) -> (OVector<N, D>, OVector<N, D>, IB::Output) {
        let (gamma, xi) = (self.uldp.gamma, self.uldp.xi);
        let (h, all_h) = (self.h, self.all_h);
        let (psi0, psi1) = (self.psi0, self.psi1);
        let segment_count = convert::<_, N>(self.segments as f64);

        let (a, ex, ev, eax) =
            multiple_segments::<N, D>(self.segments, &mut self.seg_rng, &mut self.rmm, x.nrows());

        // Time left from the random midpoint to the end of the step.
        let from_mid = (segment_count - a) * h;
        let mid_psi1 = psi_3(a * h, gamma, 1);
        let rest_psi0 = psi_3(from_mid, gamma, 0);
        let rest_psi1 = psi_3(from_mid, gamma, 1);

        let x_mid = x + v * mid_psi1 + eax;
        let g_mid = go(&x_mid);
        let x_next = x + v * psi1 - &g_mid * (all_h * rest_psi1 / xi) + &ex;
        let v_next = v * psi0 - &g_mid * (all_h * rest_psi0 / xi) + &ev;

        let mut info = IB::default();
        if info.require_gs() {
            // Extra gradient evaluation, only for reporting.
            let g_start = go(x);
            info.push_g(x.clone(), g_start);
            info.push_g(x_mid, g_mid);
        }
        if info.require_noise() {
            info.push_noise(ex, ev);
        }
        (x_next, v_next, info.build())
    }
}

/// Multi-segment process whose update uses only the gradient at the current position (no
/// midpoint evaluation).
///
/// The noise is still generated segment by segment through the same `multiple_segments`
/// call as [`RMMMS`], including the midpoint draws it does not use. Its random streams are
/// therefore consumed the same way, keeping it consistent with other RMM-based instances.
pub struct LPMLikeRMMMS<N> {
    /// Number of sub-steps combined into one step.
    segments: usize,
    /// Segment-selection stream (consumed but unused by the update).
    seg_rng: StdRng,
    /// Per-segment process; supplies `h`, the covariances and the `a` / Brownian streams.
    rmm: RMM<N>,

    /// Process parameters (`gamma`, `xi`, `temperature`).
    uldp: ULDParam<N>,

    /// Cached `psi_3(segments * h, gamma, 0)`.
    psi0: N,
    /// Cached `psi_3(segments * h, gamma, 1)`.
    psi1: N,
    /// Cached `psi_3(segments * h, gamma, 2)`.
    psi2: N,
}
impl<N: RealField> LPMLikeRMMMS<N> {
    /// Creates the process with an explicit generator for each random stream.
    ///
    /// `step` is the unscaled length of one segment. `brownian_rng` and `a_rng` are handed
    /// to the inner `RMM`.
    pub fn new2(
        segments: usize,
        step: N,
        uldp: ULDParam<N>,
        seg_rng: StdRng,
        brownian_rng: StdRng,
        a_rng: StdRng,
    ) -> Self {
        let step_h = convert::<_, N>(segments as f64) * uldp.scale_t(step);
        let friction = uldp.gamma;

        let inner = RMM::new2(step, uldp.clone(), brownian_rng, a_rng);

        Self {
            psi0: psi_3(step_h, friction, 0),
            psi1: psi_3(step_h, friction, 1),
            psi2: psi_3(step_h, friction, 2),
            rmm: inner,
            segments,
            seg_rng,
            uldp,
        }
    }
}
impl<N: RealField> ULDProcessNewMS<N> for LPMLikeRMMMS<N> {
    /// Builds the process with the segment, Brownian and `a` streams all cloned from `rng`,
    /// so all three start from the same generator state.
    fn new(segments: usize, step: N, uldp: ULDParam<N>, rng: StdRng) -> Self {
        let seg_rng = rng.clone();
        let brownian_rng = rng.clone();
        Self::new2(segments, step, uldp, seg_rng, brownian_rng, rng)
    }
}

impl<N, D> ULDProcess<OVector<N, D>> for LPMLikeRMMMS<N>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, Const<3>>,
{
    /// Advances `(x, v)` by one step of `segments` sub-steps using only the gradient at `x`.
    ///
    /// Returns `(nx, nv, info)`. When requested, `info` receives `(x, g(x))` and the
    /// end-of-step noise `(ex, ev)`.
    fn one_step_generic<IB: InfoPerStep<OVector<N, D>>>(
        &mut self,
        go: &mut dyn FnMut(&OVector<N, D>) -> OVector<N, D>,
        x: &OVector<N, D>,
        v: &OVector<N, D>,
    ) -> (OVector<N, D>, OVector<N, D>, IB::Output) {
        let xi = self.uldp.xi;
        let (psi0, psi1, psi2) = (self.psi0, self.psi1, self.psi2);

        // The midpoint and its noise are unused here. The call is still needed so the random
        // streams advance exactly as for the other multi-segment processes.
        let (_mid, ex, ev, _mid_noise) =
            multiple_segments::<N, D>(self.segments, &mut self.seg_rng, &mut self.rmm, x.nrows());

        let g_start = go(x);
        let x_next = x + v * psi1 - &g_start * (psi2 / xi) + &ex;
        let v_next = v * psi0 - &g_start * (psi1 / xi) + &ev;

        let mut info = IB::default();
        if info.require_gs() {
            info.push_g(x.clone(), g_start);
        }
        if info.require_noise() {
            info.push_noise(ex, ev);
        }
        (x_next, v_next, info.build())
    }
}

#[cfg(test)]
mod tests {
    // Several imports are only used by one of the tests; silence the warnings for the rest.
    #![allow(unused)]
    use approx::assert_relative_eq;
    use nalgebra::*;
    use rand::prelude::*;

    use super::RMMMS;
    use crate::uld::basic::cov_xv;
    use crate::uld::rmm::RMM;
    use crate::uld::AllInfo;
    use crate::uld::{ULDParam, ULDProcess, ULDProcessNew};

    /// The combined noise `(ex, ev)` of one RMMMS step must have the covariance of a single
    /// step of the full length `step`.
    #[test]
    fn rmm_ms() {
        type DType = f32;
        // Many independent coordinates so the empirical covariance is tight.
        // (`2.0.powi(20)` does not compile: E0689, ambiguous float literal.)
        let d: usize = 1 << 20;
        let uldp = ULDParam::<DType> {
            xi: 100.0,
            temperature: 10.0,
            gamma: 2.0,
        };
        let rng = StdRng::seed_from_u64(1234);
        let step: DType = 0.1;

        let segments = 10;

        let mut t = RMMMS::new2(
            segments,
            step / (segments as DType),
            uldp.clone(),
            rng.clone(),
            rng.clone(),
            rng.clone(),
        );

        let x = DVector::repeat(d, 1.0 as DType);
        let v = DVector::repeat(d, 2.0 as DType);
        let mut go: Box<dyn FnMut(&_) -> _> = Box::new(move |_| DVector::repeat(d, 3 as DType));

        let (_, _, allinfo) = t.one_step_all(&mut go, &x, &v);
        let AllInfo { ex, ev, .. } = allinfo;
        let em_cov_xx = ex.dot(&ex) / d as DType;
        let em_cov_xv = ex.dot(&ev) / d as DType;
        let em_cov_vv = ev.dot(&ev) / d as DType;
        let em_cov = Matrix2::from([[em_cov_xx, em_cov_xv], [em_cov_xv, em_cov_vv]]);

        let cov = cov_xv::<DType>(uldp.scale_t(step), uldp.gamma) * 2.0 * uldp.gamma / uldp.xi * uldp.temperature;
        assert_relative_eq!(em_cov, cov, epsilon = 0.0, max_relative = 5e-3);
    }

    /// With a zero gradient, `steps` RMMMS steps must match `steps * segments` RMM steps
    /// driven by the same random streams.
    #[test]
    fn rmm_ms_same_rmm() {
        // f64 only: with f32 the truncation error breaks the comparison.
        type DType = f64;
        let d: usize = 1 << 5;
        let uldp = ULDParam::<DType> {
            xi: 100.0,
            temperature: 10.0,
            gamma: 2.0,
        };
        let rng = StdRng::seed_from_u64(1234);
        let step: DType = 0.1;

        let segments = 10;
        let steps = 100000;

        let mut t = RMMMS::new2(
            segments,
            step / (segments as DType),
            uldp.clone(),
            rng.clone(),
            rng.clone(),
            rng.clone(),
        );
        let mut t2 = RMM::new2(step / (segments as DType), uldp.clone(), rng.clone(), rng.clone());

        let mut x = DVector::repeat(d, 1.0 as DType);
        let mut v = DVector::repeat(d, 2.0 as DType);
        let mut x2 = x.clone();
        let mut v2 = v.clone();
        let mut go: Box<dyn FnMut(&_) -> _> = Box::new(move |_| DVector::repeat(d, 0 as DType));

        for _ in 0..steps {
            let (nx, nv, _) = t.one_step_all(&mut go, &x, &v);
            x = nx;
            v = nv;
        }
        for _ in 0..steps * segments {
            let (nx2, nv2, _) = t2.one_step_all(&mut go, &x2, &v2);
            x2 = nx2;
            v2 = nv2;
        }

        assert_relative_eq!(x, x2, epsilon = 0.0, max_relative = 5e-9);
        assert_relative_eq!(v, v2, epsilon = 0.0, max_relative = 5e-9);
    }
}
