use nalgebra::*;
use rand::prelude::*;
use std::iter::Sum;
use std::ops::{Add, AddAssign, Div, DivAssign, Sub, SubAssign};

use super::GradientEstimator;
use crate::model::WithProxyGradient;

/// Variance-reduced stochastic gradient estimator (reported under the name `"SVRG"`).
///
/// The estimator is deterministic: it never samples anything itself. It combines the
/// mini-batch handed to it by the caller with a stored snapshot (`old_x`, `old_gradient`).
/// The snapshot is refreshed whenever the accumulated query budget `count` reaches `N`.
#[derive(Clone)]
pub struct SVRG2<S> {
    // --- snapshot ---
    /// Point at which the current snapshot gradient was evaluated.
    old_x: S,
    /// `all_gradient_except_0` evaluated at `old_x`.
    old_gradient: S,

    // --- bookkeeping ---
    /// Gradient-query budget.
    ///
    /// It grows by two per sampled data point and shrinks by `N` at each snapshot refresh.
    count: usize,
    /// Running total of gradient queries charged to this estimator.
    gradient_query_num: usize,

    /// Generator supplied at construction.
    ///
    /// It is not consulted by the estimator logic in this module.
    rng: StdRng,
}

impl<S> SVRG2<S>
where
    S: Clone,
{
    /// Re-anchors the snapshot at `new_x` and charges the `N` queries this costs.
    ///
    /// After the call:
    /// - `old_x` is `new_x`.
    /// - `old_gradient` is `model.all_gradient_except_0(new_x)`.
    /// - `gradient_query_num` has grown by `N`.
    /// - `count` has shrunk by `N`.
    fn update_full_gradient<M>(&mut self, model: &M, new_x: &S)
    where
        M: WithProxyGradient<State = S>,
    {
        let gradient = model.all_gradient_except_0(new_x);
        let n = model.N();
        self.old_x = new_x.clone();
        self.old_gradient = gradient;
        self.gradient_query_num += n;
        // Callers only refresh once `count >= N` (see `check_need_update`). Saturate anyway
        // so a future caller cannot underflow the budget: that would panic in debug builds
        // and, in release builds, wrap it to a huge value that forces a refresh on every step.
        self.count = self.count.saturating_sub(n);
    }

    /// Returns `true` once the accumulated query budget has reached `sample_size`.
    ///
    /// `_batch_size` is accepted but currently unused.
    fn check_need_update(&self, sample_size: usize, _batch_size: usize) -> bool {
        self.count >= sample_size
    }
}

/// Variance-reduced gradient estimation.
///
/// For a mini-batch `B` of size `b`, the estimate at `x` is
///
/// `(1/b) * Σ_{d in B} [g_d(x) - g_d(old_x)] + old_gradient + gradient_0(x)`
///
/// where:
/// - `g_d` is `model.gradient_i_scaled(·, d)`;
/// - `old_gradient` is `all_gradient_except_0(old_x)`;
/// - the `gradient_0` term is added only when the model provides one.
impl<S, M> GradientEstimator<S, M> for SVRG2<S>
where
    M::DType: RealField,
    M: WithProxyGradient<State = S>,
    S: Clone,
    S: Sum<S> + SubAssign<S> + AddAssign<S> + DivAssign<M::DType>,
    S: Div<M::DType, Output = S> + Sub<S, Output = S> + Add<S, Output = S>,
    for<'a> S: AddAssign<&'a S> + Add<&'a S, Output = S>,
    for<'a> &'a S: Div<M::DType, Output = S>,
{
    /// Builds the estimator with its first snapshot taken at `initial_x`.
    ///
    /// Charges `N` queries for that snapshot and starts with an empty refresh budget.
    fn new(initial_x: &M::State, model: &M, rng: StdRng) -> Self {
        let gradient = model.all_gradient_except_0(initial_x);
        SVRG2 {
            old_x: initial_x.clone(),
            old_gradient: gradient,
            gradient_query_num: model.N(),
            count: 0,
            rng,
        }
    }

    /// Total number of gradient queries charged so far.
    fn gradient_query_num(&self) -> usize {
        self.gradient_query_num
    }

    /// Constant overhead factor reported for this estimator.
    // NOTE(review): `step_gradient` charges 2 queries per sample; confirm what 3.0 models.
    fn gradient_overhead(&self, _model: &M) -> f32 {
        3.0
    }

    /// Display name of the estimator.
    fn name(&self) -> String {
        "SVRG".into()
    }

    /// Returns the variance-reduced gradient estimate at `x` for the given mini-batch.
    ///
    /// The behaviour depends on the refresh budget and the batch:
    /// - If the budget has reached `N`, the snapshot is first moved to `x`. The estimate
    ///   is then the fresh full gradient, and the mini-batch is not evaluated.
    /// - If `data_points` is empty, the correction term is zero. The snapshot gradient is
    ///   returned and no queries are charged. This avoids summing an empty iterator and
    ///   dividing by zero.
    /// - Otherwise, each sample costs two gradient queries, added to both the query total
    ///   and the refresh budget.
    fn step_gradient(
        &mut self,
        model: &M,
        x: &M::State,
        data_points: &[&M::DataType],
        _is: &[usize],
    ) -> M::State {
        let mut estimate = if self.check_need_update(model.N(), data_points.len()) {
            self.update_full_gradient(model, x);
            // At the snapshot point the correction term vanishes.
            self.old_gradient.clone()
        } else if data_points.is_empty() {
            self.old_gradient.clone()
        } else {
            let batch = data_points.len();
            let gradient_diff: S = data_points
                .iter()
                .map(|d| model.gradient_i_scaled(x, d) - model.gradient_i_scaled(&self.old_x, d))
                .sum::<S>()
                / convert::<_, M::DType>(batch as f64);
            self.gradient_query_num += 2 * batch;
            self.count += 2 * batch;
            gradient_diff + &self.old_gradient
        };

        if let Some(g0) = model.gradient_0(x) {
            estimate += g0;
        }
        estimate
    }
}
