use nalgebra::*;
use rand::prelude::*;
use std::iter::Sum;
use std::ops::{AddAssign, Div, SubAssign};

use super::GradientEstimator;
use crate::model::WithProxyGradient;

/// State of the SAGA variance-reduced gradient estimator.
///
/// Holds one stored proxy gradient per data point, an aggregate of the real gradients of
/// those stored proxies, and a running count of per-data-point gradient queries.
#[derive(Clone)]
pub struct SAGA<S, M: WithProxyGradient<State = S>> {
    /// Most recently recorded proxy gradient for each data point, in the same order as
    /// `model.all_data()` (that is how the estimator's constructor fills it).
    proxy_gradients: Vec<M::GradientProxy>,
    /// Aggregate of the stored proxies' real gradients, kept up to date incrementally on
    /// every step rather than recomputed from the whole table.
    sum_gradient: S,
    /// Total number of per-data-point gradient evaluations performed so far.
    gradient_query_num: usize,
}

impl<S, M> GradientEstimator<S, M> for SAGA<S, M>
where
    M::DType: RealField,
    M: WithProxyGradient<State = S>,
    S: Sum<S> + SubAssign<S> + AddAssign<S> + Div<M::DType, Output = S>,
    for<'a> S: AddAssign<&'a S>,
    for<'a> &'a S: Div<M::DType, Output = S>,
{
    /// Builds the estimator by evaluating the proxy gradient of every data point at
    /// `initial_x`.
    ///
    /// The aggregate `sum_gradient` starts as the sum of `to_real_gradient` over all stored
    /// proxies. The query counter starts at the number of data points, since each one was
    /// evaluated exactly once here. The random number generator is not used.
    fn new(initial_x: &M::State, model: &M, _rng: StdRng) -> Self {
        let proxy_gradients = model
            .all_data()
            .iter()
            .map(|d| model.proxy_gradient_i(initial_x, d))
            .collect::<Vec<_>>();
        let sum_gradient = proxy_gradients
            .iter()
            .zip(model.all_data().iter())
            .map(|(p, d)| model.to_real_gradient(p, d))
            .sum();
        let gradient_query_num = proxy_gradients.len();
        SAGA {
            proxy_gradients,
            sum_gradient,
            gradient_query_num,
        }
    }

    /// Number of per-data-point gradient evaluations performed so far, including the
    /// initial full pass in `new`.
    fn gradient_query_num(&self) -> usize {
        self.gradient_query_num
    }

    /// Per-step overhead factor reported for SAGA; always `1.0`, independent of the model.
    fn gradient_overhead(&self, _model: &M) -> f32 {
        1.0
    }

    /// Display name of this estimator.
    fn name(&self) -> String {
        "SAGA".into()
    }

    /// Computes the SAGA gradient estimate at `x` for a mini-batch, then refreshes the
    /// stored table.
    ///
    /// `data_points[k]` must be the data point stored at index `is[k]` of the proxy table.
    /// The returned estimate has three parts:
    /// - the batch mean of (new scaled gradient − stored scaled gradient),
    /// - plus `sum_gradient`,
    /// - plus `model.gradient_0(x)` when the model provides one.
    ///
    /// Afterwards, `sum_gradient` is moved by the summed correction divided by the dataset
    /// size, the sampled proxy slots are overwritten with the fresh proxies, and the query
    /// counter grows by the batch size.
    ///
    /// NOTE(review): the `sum_gradient` update is only consistent with `new` if
    /// `to_real_gradient_scaled` equals `all_data().len()` times `to_real_gradient`.
    /// TODO: confirm against the model trait.
    ///
    /// NOTE(review): if `is` may contain repeated indices (sampling with replacement), each
    /// repeat contributes to the `sum_gradient` update while the table slot is replaced only
    /// once, so the aggregate drifts from the table. Confirm the sampler draws distinct
    /// indices.
    ///
    /// With an empty batch, the estimate divides by a batch size of zero. Summing an empty
    /// iterator depends on `S`'s `Sum` implementation.
    fn step_gradient(
        &mut self,
        model: &M,
        x: &M::State,
        data_points: &[&M::DataType],
        is: &[usize],
    ) -> M::State {
        // `zip` below would silently truncate on a mismatch and corrupt the table update.
        debug_assert_eq!(
            data_points.len(),
            is.len(),
            "SAGA::step_gradient: `data_points` and `is` must have the same length"
        );

        // Fresh proxies at the current iterate; kept so they can replace the stored ones.
        let new_pgs = data_points
            .iter()
            .map(|d| model.proxy_gradient_i(x, d))
            .collect::<Vec<_>>();

        // Sum over the batch of (new - stored) scaled gradients. Each pair is converted and
        // folded immediately, so no Vec of full gradients is materialised.
        let diff: S = is
            .iter()
            .zip(data_points.iter())
            .zip(new_pgs.iter())
            .map(|((i, d), p)| {
                let mut g = model.to_real_gradient_scaled(p, d);
                g -= model.to_real_gradient_scaled(&self.proxy_gradients[*i], d);
                g
            })
            .sum();

        // Variance-reduced estimate: batch mean of the corrections plus the stored aggregate.
        let mut vr_gradient = &diff / convert::<_, M::DType>(data_points.len() as f64);
        vr_gradient += &self.sum_gradient;

        // Keep the aggregate in sync with the table entries about to be overwritten.
        self.sum_gradient += diff / convert::<_, M::DType>(model.all_data().len() as f64);
        for (i, p) in is.iter().zip(new_pgs) {
            self.proxy_gradients[*i] = p;
        }
        self.gradient_query_num += data_points.len();

        // Add the model's optional extra gradient term, if it provides one.
        if let Some(g0) = model.gradient_0(x) {
            vr_gradient += g0;
        }
        vr_gradient
    }
}
