use rand::prelude::*;

use super::GradientEstimator;
use crate::model::WithProxyGradient;

/// Gradient estimator that always queries the model's full gradient.
///
/// Tuple fields:
/// - `.0`: running count of gradient evaluations charged so far.
/// - `.1`: batch size, used as the denominator of the gradient overhead ratio.
#[derive(Clone)]
pub struct GD(usize, usize);

impl Default for GD {
    /// Starts with zero recorded gradient evaluations and a batch size of 1.
    ///
    /// This matches the estimator's `new` constructor. A derived `Default`
    /// would give a batch size of 0, which makes the overhead ratio divide by zero.
    fn default() -> Self {
        GD(0, 1)
    }
}

impl<S, M> GradientEstimator<S, M> for GD
where
    M: WithProxyGradient<State = S>,
{
    fn new(_initial_x: &M::State, _model: &M, _rng: StdRng) -> Self {
        GD(0, 1)
    }
    fn set_batch_size(&mut self, batch: usize, _model: &M) {
        self.1 = batch
    }
    fn gradient_query_num(&self) -> usize {
        self.0
    }
    fn gradient_overhead(&self, model: &M) -> f32 {
        (model.N() as f32) / (self.1 as f32)
    }
    fn name(&self) -> String {
        "GD".into()
    }
    fn step_gradient(&mut self, model: &M, x: &M::State, _data_points: &[&M::DataType], _is: &[usize]) -> M::State {
        let gradient = model.all_gradient(x);
        self.0 += model.N();

        gradient
    }
}
