use rand::prelude::*;

use crate::model::WithProxyGradient;
use crate::tool::simple_type_name;

pub mod gd;
pub mod saga;
pub mod sgd;
pub mod svrg2;

/// A gradient estimator for a model `M` whose optimisation state has type `S`.
///
/// Implementations produce an estimate of the gradient at a point from a batch of
/// data points and keep count of how many per-sample gradients they have evaluated.
pub trait GradientEstimator<S, M: WithProxyGradient<State = S>> {
    /// Initializes the algorithm with the start point `initial_x` and the model's data.
    ///
    /// Some algorithms (mostly the variance-reduced ones) need to compute a full
    /// gradient in this first step.
    fn new(initial_x: &M::State, model: &M, rng: StdRng) -> Self
    where
        Self: Sized;

    /// Sets the recommended batch size.
    ///
    /// The default implementation ignores both arguments and does nothing.
    /// SARGE needs this information at initialization.
    fn set_batch_size(&mut self, _batch: usize, _model: &M) {}

    /// Number of gradients `\nabla f_i` that have been computed so far.
    fn gradient_query_num(&self) -> usize;

    /// Average number of gradient queries needed for one iteration, divided by the
    /// batch size.
    fn gradient_overhead(&self, model: &M) -> f32;

    /// Name of this estimator.
    fn name(&self) -> String;

    /// Performs one step of gradient estimation at `x`.
    ///
    /// `data_points` is the batch of samples and `is` the indices of those samples;
    /// `FixedBatchGradientOracle::call` passes them so that `data_points[k]` is the
    /// model's data entry at index `is[k]`.
    fn step_gradient(
        &mut self,
        model: &M,
        x: &M::State,
        data_points: &[&M::DataType],
        is: &[usize],
    ) -> M::State;
}

/// Couples a gradient estimator with a model and a batch sampler.
///
/// The bound on `M` has to stay on the struct itself because the `ds` field names
/// the associated type `M::DataType`.
pub struct FixedBatchGradientOracle<'a, M, S, GE>
where
    M: WithProxyGradient<State = S>,
    GE: ?Sized,
{
    /// Model that supplies the data and is handed to the estimator.
    pub model: &'a M,
    /// Random generator used to draw the batch indices.
    pub batch_rng: StdRng,
    /// Number of indices requested for each batch.
    pub batch_size: usize,
    /// Indices of the current batch.
    /// Only kept as a field to avoid repeated allocation.
    is: Vec<usize>,
    /// Data points of the current batch, borrowed from the model.
    /// Only kept as a field to avoid repeated allocation.
    ds: Vec<&'a M::DataType>,
    /// The wrapped estimator; declared last so that it may be an unsized type.
    pub gradient_estimator: GE,
}
impl<'a, M, S, GE: ?Sized> FixedBatchGradientOracle<'a, M, S, GE>
where
    M: WithProxyGradient<State = S>,
    GE: GradientEstimator<S, M>,
{
    /// Returns the estimator's gradient estimate at `x`.
    ///
    /// A fresh batch is drawn before every call unless the estimator's simple type
    /// name is `"GD"`; in that case the stored batch is passed on untouched (it is
    /// empty unless something else filled it).
    pub fn call(&mut self, x: &S) -> S {
        // NOTE(review): presumably GD works on the full data set and ignores the
        // batch arguments, which is why sampling is skipped — confirm in `gd`.
        let needs_batch = simple_type_name::<GE>() != "GD";
        if needs_batch {
            // Pass the fields one by one so the borrows of `self` stay disjoint.
            Self::update_batch(
                &mut self.is,
                &mut self.ds,
                self.batch_size,
                &mut self.batch_rng,
                self.model,
            );
        }
        self.gradient_estimator.step_gradient(self.model, x, &self.ds, &self.is)
    }

    /// Replaces the contents of `is` and `ds` with a newly sampled batch.
    ///
    /// Up to `batch_size` distinct indices are chosen from `0..model.N()` using
    /// `batch_rng` (per rand's `choose_multiple` documentation, fewer are returned
    /// when the range is shorter than `batch_size`). Afterwards `ds[k]` refers to
    /// the model's data entry at index `is[k]`. Both vectors are cleared and
    /// refilled in place so their allocations are reused.
    pub fn update_batch(
        is: &mut Vec<usize>,
        ds: &mut Vec<&'a M::DataType>,
        batch_size: usize,
        batch_rng: &mut StdRng,
        model: &'a M,
    ) {
        is.clear();
        ds.clear();
        is.extend((0..model.N()).choose_multiple(batch_rng, batch_size));
        // Fetch the data collection once instead of once per sampled index.
        let data = model.all_data();
        ds.extend(is.iter().map(|&i| &data[i]));
    }

    /// Creates an oracle around `gradient_estimator` with empty batch buffers.
    ///
    /// No batch is sampled here; the first one is drawn by `call`.
    pub fn new(model: &'a M, batch_rng: StdRng, batch_size: usize, gradient_estimator: GE) -> Self
    where
        GE: Sized,
    {
        Self {
            model,
            batch_rng,
            batch_size,
            is: Vec::new(),
            ds: Vec::new(),
            gradient_estimator,
        }
    }
}
