pub mod gaussian_model;
pub mod logistic;

/// Zeroth-order oracle: a model that can only report the value of its loss `f`.
pub trait ZerothOrderModel {
    /// Type of the point at which the loss is evaluated.
    type State;
    /// Type of the loss value returned by [`all_loss`](ZerothOrderModel::all_loss).
    type DType;

    /// Evaluates the full loss `f` at `state`.
    fn all_loss(&self, state: &Self::State) -> Self::DType;
}

impl<T: SumDecomposableModel> ZerothOrderModel for T {
    type DType = <T as SumDecomposableModel>::DType;
    type State = <T as SumDecomposableModel>::State;
    fn all_loss(&self, state: &Self::State) -> Self::DType {
        <T as SumDecomposableModel>::all_loss(self, state)
    }
}

/// Loss function with a sum structure: `f = f_0 + \sum_{i>=1} f_i`.
///
/// Each data term `f_i` (`i >= 1`) is evaluated on one element of
/// [`all_data`](SumDecomposableModel::all_data); the extra term `f_0` takes no
/// datum and may be absent (see [`loss_0`](SumDecomposableModel::loss_0)).
/// No structure is assumed on the state (the optimization variable) beyond
/// this sum decomposition.
pub trait SumDecomposableModel {
    /// Type of loss values.
    type DType;
    /// Type of a single data element.
    type DataType;
    /// Type of the optimization variable; gradients are returned as this type too.
    type State;

    /// Number of data terms `N` (presumably `all_data().len()` — defined by the implementor).
    // Named `N` to mirror the mathematical notation; renaming would break implementors.
    #[allow(non_snake_case)]
    fn N(&self) -> usize;
    /// All data elements, one per data term `f_i`.
    fn all_data(&self) -> &[Self::DataType];
    /// Data term `f_i` (`i >= 1`) at `state`, for the datum `one_data`.
    fn loss_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::DType;
    /// Scaled data term `N * f_i` at `state`, for the datum `one_data`.
    fn loss_i_scaled(&self, state: &Self::State, one_data: &Self::DataType) -> Self::DType;
    /// Extra term `f_0` at `state`; `None` presumably means the model has no `f_0` term.
    fn loss_0(&self, state: &Self::State) -> Option<Self::DType>;
    /// Full loss `f` at `state`.
    fn all_loss(&self, state: &Self::State) -> Self::DType;
    /// Gradient `\nabla f_i` of one data term at `state`.
    fn gradient_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::State;
    /// Scaled gradient `N * \nabla f_i` of one data term at `state`.
    fn gradient_i_scaled(&self, state: &Self::State, one_data: &Self::DataType) -> Self::State;
    /// Gradient `\nabla f_0` at `state`; `None` presumably means there is no `f_0` term.
    fn gradient_0(&self, state: &Self::State) -> Option<Self::State>;
    /// Gradient `\nabla f` of the full loss at `state`.
    fn all_gradient(&self, state: &Self::State) -> Self::State;
    /// Gradient `\nabla \sum_{i>=1} f_i` of the data terms only, i.e. excluding `f_0`.
    fn all_gradient_except_0(&self, state: &Self::State) -> Self::State;
}

/// A [`SumDecomposableModel`] that obtains the per-datum gradient `\nabla f_i` in two
/// steps. It first computes an intermediate proxy value. It then converts that proxy,
/// together with the same datum, into a real gradient of type `Self::State`.
pub trait WithProxyGradient: SumDecomposableModel {
    /// Intermediate value returned by
    /// [`proxy_gradient_i`](WithProxyGradient::proxy_gradient_i).
    type GradientProxy: Clone;

    /// Computes the proxy for `\nabla f_i` at `state`, where `f_i` is the term for `one_data`.
    fn proxy_gradient_i(
        &self,
        state: &Self::State,
        one_data: &Self::DataType,
    ) -> Self::GradientProxy;

    /// Converts `proxy` into the scaled gradient `N * \nabla f_i` for `one_data`.
    fn to_real_gradient_scaled(
        &self,
        proxy: &Self::GradientProxy,
        one_data: &Self::DataType,
    ) -> Self::State;

    /// Converts `proxy` into the real gradient `\nabla f_i` for `one_data`.
    fn to_real_gradient(
        &self,
        proxy: &Self::GradientProxy,
        one_data: &Self::DataType,
    ) -> Self::State;
}
