#![allow(dead_code)]
use std::iter::Sum;

use rand::prelude::*;

use super::{SumDecomposableModel, WithProxyGradient};
use nalgebra::allocator::Allocator;
use nalgebra::*;

/// Logistic model over a fixed data set, with a squared-norm penalty weighted by `m`.
///
/// The samples are the columns of `a_s` and `ys` holds one multiplier per sample
/// (presumably the +/-1 class labels -- confirm against the caller). The products
/// `y_i * a_i` are precomputed once in `init` and cached in both orientations so
/// that per-sample and whole-data-set computations are each a single product.
#[allow(non_snake_case)]
#[derive(Debug)]
pub struct Logistic<N, D>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D, Dynamic>,
{
    /// Sample indices `0..data_size`, handed out by `all_data`.
    data_ind: Vec<usize>,
    /// Number of samples, i.e. the number of columns of `a_s`.
    data_size: usize,
    /// Weight of the penalty term `m / 2 * ||x||^2`.
    pub m: N,
    /// Data matrix; column `i` is sample `a_i`.
    a_s: OMatrix<N, D, Dynamic>,
    /// Per-sample multipliers `y_i`, one per column of `a_s`.
    ys: OVector<N, Dynamic>,
    /// Number of rows of `a_s`, i.e. the dimension of each sample.
    a_dim: usize,
    /// `a_s` with column `i` scaled by `y_i`.
    yas: OMatrix<N, D, Dynamic>,
    /// Transpose of `yas`; row `i` is `y_i * a_i`.
    yas_trans: OMatrix<N, Dynamic, D>,
}
impl<N, D> Logistic<N, D>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D, Dynamic>,
{
    /// Builds a logistic model from a data matrix, its labels and the penalty weight.
    ///
    /// * `a_s` - data matrix whose columns are the samples.
    /// * `ys` - one multiplier (label) per column of `a_s`.
    /// * `m` - weight of the `m / 2 * ||x||^2` penalty term.
    ///
    /// The label-scaled samples `y_i * a_i` and their transpose are computed once
    /// here and cached for the loss and gradient routines.
    ///
    /// # Panics
    ///
    /// Panics if `ys` does not contain exactly one entry per column of `a_s`.
    pub fn init(a_s: OMatrix<N, D, Dynamic>, ys: OVector<N, Dynamic>, m: N) -> Self {
        let (a_dim, data_size) = a_s.shape();
        // Fail fast: a silent `zip` truncation below would leave some columns of
        // `yas` unscaled and produce wrong losses/gradients later on.
        assert_eq!(
            ys.len(),
            data_size,
            "Logistic::init: expected one label per sample column"
        );

        let data_ind = (0..data_size).collect();

        // Scale every sample (column) by its label.
        let mut yas = a_s.clone();
        for (mut column, y) in yas.column_iter_mut().zip(ys.iter()) {
            column *= *y;
        }
        let yas_trans = yas.transpose();

        Self {
            data_ind,
            data_size,
            m,
            a_s,
            ys,
            a_dim,
            yas,
            yas_trans,
        }
    }
}
impl<N, D> Logistic<N, D>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D, Dynamic>,
    DefaultAllocator: Allocator<N, D>,
{
    /// initialize state variable
    pub fn init_state(&self, _rng: &mut StdRng) -> OVector<N, D> {
        &self.a_s.column(0) * N::zero()
        // self.yas.column_mean()
    }
}
impl<N, D> Logistic<N, D>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D, Dynamic>,
    DefaultAllocator: Allocator<N, D>,
{
    /// Margins of every sample for one weight vector: `[y_i * a_i . x]_i`.
    fn yaxs(&self, x: &OVector<N, D>) -> OVector<N, Dynamic> {
        &self.yas_trans * x
    }

    /// Margins of every sample for every weight vector (column) of `x`:
    /// entry `(i, j)` is `y_i * a_i . x_j`.
    fn yaxss(&self, x: &OMatrix<N, D, Dynamic>) -> OMatrix<N, Dynamic, Dynamic> {
        &self.yas_trans * x
    }

    /// `exp(y_i * a_i . x)` for sample `i`.
    fn exp_yax1(&self, x: &OVector<N, D>, i: usize) -> N {
        self.yas.column(i).dot(x).exp()
    }

    /// `[exp(y_i * a_i . x)]_i` for every sample.
    fn exp_yaxs(&self, x: &OVector<N, D>) -> OVector<N, Dynamic> {
        self.yaxs(x).map(|margin| margin.exp())
    }

    /// Softplus, `ln(1 + exp(x))`, evaluated without overflow.
    ///
    /// Uses `ln(1 + e^x) = max(x, 0) + ln(1 + e^{-|x|})`; the remaining logarithm
    /// goes through `ln_1p` so that it stays accurate when the exponential is
    /// tiny instead of being rounded away by the `1 +`.
    fn soft_plus(x: N) -> N {
        if x >= N::zero() {
            (-x).exp().ln_1p() + x
        } else {
            x.exp().ln_1p()
        }
    }

    /// Logistic function `1 / (1 + exp(-x))`.
    fn sigmoid(x: N) -> N {
        N::one() / (N::one() + (-x).exp())
    }

    /// Logistic loss of sample `i`: `ln(1 + exp(-y_i * a_i . x))`.
    fn soft_plus_n1(&self, x: &OVector<N, D>, i: usize) -> N {
        Self::soft_plus(-self.yas.column(i).dot(x))
    }

    /// Logistic loss of every sample: `[ln(1 + exp(-y_i * a_i . x))]_i`.
    fn soft_plus_ns(&self, x: &OVector<N, D>) -> OVector<N, Dynamic> {
        (-self.yaxs(x)).map(Self::soft_plus)
    }

    /// Predict probability on ground truth of one set of weights.
    ///
    /// Entry `i` is the logistic function applied to the margin of sample `i`.
    pub fn probs(&self, x: &OVector<N, D>) -> DVector<N> {
        self.yaxs(x).map(Self::sigmoid)
    }

    /// Predict probability on ground truth of an ensemble for each sample.
    ///
    /// Each column of `x` is one set of weights; entry `i` of the result is the
    /// probability of sample `i` averaged over all columns of `x`.
    pub fn probs_ensemble(&self, x: &OMatrix<N, D, Dynamic>) -> DVector<N> {
        self.yaxss(x).map(Self::sigmoid).column_mean()
    }
}

impl<N, D> Logistic<N, D>
where
    N: RealField,
    D: Dim,
    DefaultAllocator: Allocator<N, D, Dynamic>,
    DefaultAllocator: Allocator<N, D, D>,
    DefaultAllocator: Allocator<N, D>,
{
    /// Returns `lambda_max(A * A^T) / 4 + m`, where `A` is the data matrix.
    ///
    /// The largest eigenvalue is estimated with `100 * a_dim` rounds of power
    /// iteration. If the data set is empty or every sample is the zero vector,
    /// the eigenvalue is zero and `m` is returned.
    #[allow(non_snake_case)]
    pub fn calculate_L(&self) -> N {
        let cov = &self.a_s * &self.a_s.transpose();

        // Start from the sum of all samples. If that sum vanishes (e.g. centred
        // data), normalising it would divide by zero and turn the result into
        // NaN, so fall back to any non-zero sample: a non-zero column `a_j` is
        // never in the null space of `A * A^T`.
        let mut vec = self.a_s.column_sum();
        if vec.norm_squared() == N::zero() {
            match self
                .a_s
                .column_iter()
                .find(|col| col.norm_squared() > N::zero())
            {
                Some(col) => vec = col.into_owned(),
                // No non-zero sample at all: `A * A^T` is the zero matrix.
                None => return self.m,
            }
        }

        // Power iteration.
        for _ in 0..100 * self.a_dim {
            let next = &cov * &vec;
            let norm = next.norm();
            if norm == N::zero() {
                // The product vanished to working precision; report a zero
                // spectral norm rather than dividing by zero.
                return self.m;
            }
            vec = next / norm;
        }
        // Rayleigh quotient of the (unit-length) iterate.
        let spectral_norm = vec.dot(&(&cov * &vec));

        spectral_norm / convert::<_, N>(4.0) + self.m
    }
}

impl<N, D> SumDecomposableModel for Logistic<N, D>
where
    N: RealField + Sum<N>,
    D: Dim,
    DefaultAllocator: Allocator<N, D, Dynamic>,
    DefaultAllocator: Allocator<N, D>,
{
    type DType = N;
    type DataType = usize;
    type State = OVector<N, D>;

    /// Number of samples in the data set.
    fn N(&self) -> usize {
        self.data_size
    }

    /// Indices of all samples.
    fn all_data(&self) -> &[Self::DataType] {
        self.data_ind.as_slice()
    }

    /// Logistic loss of a single sample.
    fn loss_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::DType {
        let index = *one_data;
        self.soft_plus_n1(state, index)
    }

    /// Logistic loss of a single sample, multiplied by the number of samples.
    fn loss_i_scaled(&self, state: &Self::State, one_data: &Self::DataType) -> Self::DType {
        let sample_count: N = convert(self.data_size as f64);
        self.loss_i(state, one_data) * sample_count
    }

    /// Penalty term `m / 2 * ||state||^2`.
    fn loss_0(&self, state: &Self::State) -> Option<Self::DType> {
        let half_m = self.m / convert::<_, N>(2.0);
        Some(half_m * state.norm_squared())
    }

    /// Penalty term plus the logistic loss summed over every sample.
    fn all_loss(&self, state: &Self::State) -> Self::DType {
        let penalty = self.m / convert::<_, N>(2.0) * state.norm_squared();
        let data_loss = self.soft_plus_ns(state).sum();
        penalty + data_loss
    }

    /// Gradient of the logistic loss of a single sample:
    /// `-y_i * a_i / (1 + exp(y_i * a_i . state))`.
    fn gradient_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::State {
        let denominator = N::one() + self.exp_yax1(state, *one_data);
        let negated_sample = -self.yas.column(*one_data);
        negated_sample / denominator
    }

    /// Single-sample gradient, multiplied by the number of samples.
    fn gradient_i_scaled(&self, state: &Self::State, one_data: &Self::DataType) -> Self::State {
        let sample_count: N = convert(self.data_size as f64);
        self.gradient_i(state, one_data) * sample_count
    }

    /// Gradient of the penalty term: `m * state`.
    fn gradient_0(&self, state: &Self::State) -> Option<Self::State> {
        let penalty_gradient = state * self.m;
        Some(penalty_gradient)
    }

    /// Gradient of the full objective (penalty plus all samples).
    fn all_gradient(&self, state: &Self::State) -> Self::State {
        (state * self.m) + self.all_gradient_except_0(state)
    }

    /// Sum of the single-sample gradients, without the penalty term.
    fn all_gradient_except_0(&self, state: &Self::State) -> Self::State {
        // Per-sample weight: -y_i / (1 + exp(y_i * a_i . state)).
        let denominators = self.exp_yaxs(state).map(|e| e + N::one());
        let weights = -self.ys.component_div(&denominators);

        // Accumulate the weighted samples, starting from a zero vector shaped like `state`.
        self.a_s
            .column_iter()
            .zip(weights.iter())
            .fold(state.map(|_| N::zero()), |acc, (sample, weight)| {
                acc + sample * *weight
            })
    }
}
impl<N, D> WithProxyGradient for Logistic<N, D>
where
    N: RealField + Sum<N>,
    D: Dim,
    DefaultAllocator: Allocator<N, D, Dynamic>,
    DefaultAllocator: Allocator<N, D>,
{
    type GradientProxy = N;

    /// Scalar proxy for the gradient of one sample: `exp(y_i * a_i . state)`.
    fn proxy_gradient_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::GradientProxy {
        let index = *one_data;
        self.exp_yax1(state, index)
    }

    /// Rebuilds the single-sample gradient from its proxy:
    /// `-y_i * a_i / (1 + proxy)`.
    fn to_real_gradient(&self, proxy: &Self::GradientProxy, one_data: &Self::DataType) -> Self::State {
        let denominator = N::one() + *proxy;
        let negated_sample = -self.yas.column(*one_data);
        negated_sample / denominator
    }

    /// Rebuilds the single-sample gradient from its proxy and multiplies it by
    /// the number of samples.
    fn to_real_gradient_scaled(&self, proxy: &Self::GradientProxy, one_data: &Self::DataType) -> Self::State {
        let sample_count: N = convert(self.N() as f64);
        self.to_real_gradient(proxy, one_data) * sample_count
    }
}
