use nalgebra::*;
use rand::prelude::*;

use crate::gradient::{FixedBatchGradientOracle, GradientEstimator};
use crate::model::WithProxyGradient;

pub mod info;
use info::{AllInfo, AllInfoBuilder, BasicInfoBuilder, InfoPerStep};
pub mod basic;
pub mod lpm_like_rmm;
pub mod multiple;
pub mod optimal;
pub mod rmm;
pub mod rmm_basic;

/// Scalar parameters shared by the ULD process implementations.
///
/// NOTE(review): "ULD" presumably stands for underdamped Langevin dynamics; the exact
/// physical meaning of each field is defined by the implementations, not here.
#[derive(Clone, Debug)]
pub struct ULDParam<N> {
    /// Temperature parameter of the dynamics.
    pub temperature: N,
    /// Factor that [`ULDParam::scale_t`] multiplies time spans by.
    pub xi: N,
    /// Additional dynamics parameter (semantics not visible in this module — confirm
    /// against the process implementations).
    pub gamma: N,
}

impl<N: RealField> ULDParam<N> {
    /// Rescales the time span `t` by this parameter set's `xi` factor.
    ///
    /// Returns `t * xi` (the multiplication order is `t` on the left).
    pub fn scale_t(&self, t: N) -> N {
        let factor = self.xi;
        t * factor
    }
}

/// A stepping scheme that advances a pair of states `(x, v)` by one step.
///
/// Implementors provide only [`ULDProcess::one_step_generic`]; the other methods are
/// convenience wrappers that fix the info builder.
pub trait ULDProcess<State> {
    /// Advances `(x, v)` by one step and returns `(next_x, next_v, info)`.
    ///
    /// `go` is a state-to-state callback the implementation may invoke during the step
    /// (in `FixedBatchSeq` it forwards to the gradient oracle's `call`). The third
    /// element of the result is whatever the info builder `IB` produces.
    fn one_step_generic<IB: InfoPerStep<State>>(
        &mut self,
        go: &mut dyn FnMut(&State) -> State,
        x: &State,
        v: &State,
    ) -> (State, State, IB::Output);

    /// Advances `(x, v)` by one step and returns only `(next_x, next_v)`.
    ///
    /// Uses `BasicInfoBuilder` and discards its output.
    fn one_step(
        &mut self,
        go: &mut dyn FnMut(&State) -> State,
        x: &State,
        v: &State,
    ) -> (State, State) {
        let (next_x, next_v, _) = self.one_step_generic::<BasicInfoBuilder<_>>(go, x, v);
        (next_x, next_v)
    }

    /// Advances `(x, v)` by one step and returns `(next_x, next_v, info)` where `info`
    /// is the [`AllInfo`] record assembled by `AllInfoBuilder`.
    fn one_step_all(
        &mut self,
        go: &mut dyn FnMut(&State) -> State,
        x: &State,
        v: &State,
    ) -> (State, State, AllInfo<State>) {
        self.one_step_generic::<AllInfoBuilder<_>>(go, x, v)
    }
}
/// Uniform constructor for ULD process implementations.
pub trait ULDProcessNew<N> {
    /// Builds a process from a step value, the dynamics parameters, and an RNG passed by
    /// value.
    ///
    /// NOTE(review): `step` is presumably the integration step size — confirm against
    /// the implementors.
    fn new(step: N, params: ULDParam<N>, rng: StdRng) -> Self;
}

use std::marker::PhantomData;
/// Iterator state that repeatedly advances `(x, v)` with an integrator `ULD`, using a
/// [`FixedBatchGradientOracle`] as the state-to-state callback on every step.
///
/// `IB` selects the [`InfoPerStep`] builder, and therefore the kind of per-step
/// information yielded; it is tracked only at the type level via `phantom`.
///
/// The `ULD: ULDProcess<S>` requirement is not placed on the struct itself (no field
/// needs it); it is stated on the `Iterator` impl, which is where it is actually used.
pub struct FixedBatchSeq<'a, M, S, ULD, IB, GE: ?Sized>
where
    // NOTE(review): kept because the `go` field's type is parameterised over `M` and `S`
    // and presumably requires this bound — confirm against `FixedBatchGradientOracle`.
    M: WithProxyGradient<State = S>,
{
    /// Current `x` state (presumably position); replaced after every step.
    pub x: S,
    /// Current `v` state (presumably velocity); replaced after every step.
    pub v: S,
    /// Integrator used to advance `(x, v)`.
    pub uldprocess: ULD,
    /// Zero-sized marker carrying the info-builder type `IB`.
    pub phantom: PhantomData<IB>,
    /// Gradient oracle handed to the integrator as its callback.
    pub go: FixedBatchGradientOracle<'a, M, S, GE>,
}
impl<'a, M, S, ULD, IB, GE: ?Sized> Iterator for FixedBatchSeq<'a, M, S, ULD, IB, GE>
where
    M: WithProxyGradient<State = S>,
    ULD: ULDProcess<S>,
    S: Clone,
    GE: GradientEstimator<S, M>,
    IB: InfoPerStep<S>,
{
    /// `(next_x, next_v, info, query_count)`: the freshly stepped state, the info built
    /// by `IB` for this step, and the gradient estimator's `gradient_query_num()` read
    /// right after the step.
    type Item = (S, S, IB::Output, usize);

    /// Performs one integrator step, stores the new state, and yields a copy of it.
    ///
    /// This never returns `None`: the sequence is unbounded, so callers must limit it
    /// themselves (e.g. with [`Iterator::take`]).
    fn next(&mut self) -> Option<Self::Item> {
        // `go` and `uldprocess` are disjoint fields, so both can be borrowed mutably at
        // the same time without routing through a helper closure.
        let oracle = &mut self.go;
        let (nx, nv, other) =
            self.uldprocess
                .one_step_generic::<IB>(&mut |x| oracle.call(x), &self.x, &self.v);
        // The previous state is not needed any more; overwrite it directly.
        self.x = nx;
        self.v = nv;
        Some((
            self.x.clone(),
            self.v.clone(),
            other,
            self.go.gradient_estimator.gradient_query_num(),
        ))
    }
}
