use nalgebra::*;
use rand::prelude::*;
use std::iter::Sum;
use std::ops::{AddAssign, DivAssign};

use super::GradientEstimator;
use crate::model::WithProxyGradient;

/// Plain stochastic gradient estimator.
///
/// The wrapped `usize` is a running counter of how many data points have been
/// passed through `step_gradient`; it is what `gradient_query_num` reports.
#[derive(Clone, Debug, Default)]
pub struct SGD(usize);

/// `GradientEstimator` implementation that averages the per-data-point
/// gradients of the supplied batch, with no additional estimator state beyond
/// a query counter.
impl<S, M> GradientEstimator<S, M> for SGD
where
    M::DType: RealField,
    M: WithProxyGradient<State = S>,
    S: Sum<S> + AddAssign<S> + DivAssign<M::DType>,
{
    /// Creates an estimator with its query counter at zero.
    ///
    /// The initial state, the model and the RNG are all unused here.
    fn new(_initial_x: &M::State, _model: &M, _rng: StdRng) -> Self {
        SGD(0)
    }

    /// Total number of data points processed by `step_gradient` so far.
    fn gradient_query_num(&self) -> usize {
        self.0
    }

    /// Constant overhead factor of `1.0`, independent of the model.
    fn gradient_overhead(&self, _model: &M) -> f32 {
        1.0
    }

    /// Display name of this estimator.
    fn name(&self) -> String {
        "SGD".into()
    }

    /// Estimates the gradient at `x` from the given batch.
    ///
    /// Sums `model.gradient_i_scaled(x, d)` over every `d` in `data_points`,
    /// divides by the batch size, and then adds `model.gradient_0(x)` when the
    /// model provides one (presumably a data-independent term — confirm
    /// against the `WithProxyGradient` contract). The query counter is
    /// advanced by the batch size. The index slice `_is` is not used.
    ///
    /// For an empty batch the division is skipped, so the averaged term is
    /// whatever the empty `Sum` produces rather than that value divided by
    /// zero.
    fn step_gradient(&mut self, model: &M, x: &M::State, data_points: &[&M::DataType], _is: &[usize]) -> M::State {
        let batch_size = data_points.len();
        let mut gradient: S = data_points.iter().map(|d| model.gradient_i_scaled(x, d)).sum();
        // Guard against 0/0: dividing by an empty batch size would turn the
        // result into NaN for floating-point states.
        if batch_size > 0 {
            gradient /= convert::<_, M::DType>(batch_size as f64);
        }
        if let Some(g0) = model.gradient_0(x) {
            gradient += g0;
        }
        self.0 += batch_size;
        gradient
    }
}
