use itertools::Itertools;
use rand::prelude::*;
use rand_distr::{Distribution, Normal, StandardNormal, Uniform};

use super::{SumDecomposableModel, WithProxyGradient};
use nalgebra::allocator::Allocator;
use nalgebra::*;

/// Quadratic ("Gaussian") finite-sum model built from randomly generated samples.
///
/// Each per-sample loss is `(x - d_i)ᵀ · inv_sigma · (x - d_i) / 2`. The
/// `SumDecomposableModel` impl also provides a variant divided by the number of samples.
#[allow(non_snake_case)] // needed for the field `L`
#[derive(Debug)]
pub struct GaussianModel<N, D>
where
    N: RealField,
    D: DimName,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, D>,
{
    /// Sample points `d_i`, generated randomly by `init`.
    data: Vec<OVector<N, D>>,
    /// Symmetric weighting matrix `Q · diag(d) · Qᵀ` used by every quadratic term.
    /// `generate_sigma` maps the entries of `d` into the range between `1` and `L`.
    inv_sigma: OMatrix<N, D, D>,
    /// `Q · diag(1/√d) · Qᵀ`, whose square is the inverse of `inv_sigma`.
    /// `init_state` uses it to scale standard-normal noise.
    sqrt_sigma: OMatrix<N, D, D>,
    /// The `L` passed to `init`: the upper end of the range the eigenvalues of `inv_sigma`
    /// are mapped to. NOTE(review): stored but not read by the code visible here.
    L: N,
    /// Empirical mean of `data`. It is the point where the quadratic part of `all_loss` vanishes.
    pub mean_data: OVector<N, D>,
    /// Sum of `loss_i(mean_data, d_i)` over all samples, i.e. the value of `all_loss` at
    /// `mean_data`.
    pub optimal: N,
}

impl<N, D> SumDecomposableModel for GaussianModel<N, D>
where
    N: RealField + std::iter::Sum,
    D: DimName,
    DefaultAllocator: Allocator<N, D> + Allocator<N, D, D>,
{
    type DType = N;
    type DataType = OVector<N, D>;
    type State = Self::DataType;

    /// Number of stored samples.
    fn N(&self) -> usize {
        self.data.len()
    }

    /// Borrow every stored sample.
    fn all_data(&self) -> &[Self::DataType] {
        self.data.as_slice()
    }

    /// Per-sample loss, already averaged over the dataset:
    /// `(state - one_data)ᵀ · inv_sigma · (state - one_data) / (2 · N)`.
    fn loss_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::DType {
        let delta = state - one_data;
        let weighted = &self.inv_sigma * &delta;
        let denominator: N = convert((2 * self.N()) as f64);
        delta.dot(&weighted) / denominator
    }

    /// Per-sample loss without the `1 / N` factor:
    /// `(state - one_data)ᵀ · inv_sigma · (state - one_data) / 2`.
    fn loss_i_scaled(&self, state: &Self::State, one_data: &Self::DataType) -> Self::DType {
        let delta = state - one_data;
        let weighted = &self.inv_sigma * &delta;
        let two: N = convert(2.0);
        delta.dot(&weighted) / two
    }

    /// This model has no sample-independent loss term.
    fn loss_0(&self, _state: &Self::State) -> Option<Self::DType> {
        None
    }

    /// Total loss in closed form. The sum of `loss_i` equals the quadratic around
    /// `mean_data` plus the constant `optimal`.
    fn all_loss(&self, state: &Self::State) -> Self::DType {
        let offset = state - &self.mean_data;
        let weighted = &self.inv_sigma * &offset;
        let two: N = convert(2.0);
        offset.dot(&weighted) / two + self.optimal
    }

    /// Gradient of `loss_i` with respect to `state`: `inv_sigma · (state - one_data) / N`.
    fn gradient_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::State {
        let count: N = convert(self.N() as f64);
        (&self.inv_sigma * (state - one_data)) / count
    }

    /// Gradient of `loss_i_scaled`: `inv_sigma · (state - one_data)`.
    fn gradient_i_scaled(&self, state: &Self::State, one_data: &Self::DataType) -> Self::State {
        let delta = state - one_data;
        &self.inv_sigma * delta
    }

    /// No sample-independent term, hence no gradient for it.
    fn gradient_0(&self, _state: &Self::State) -> Option<Self::State> {
        None
    }

    /// Full gradient. It is identical to `all_gradient_except_0` because `gradient_0` is absent.
    fn all_gradient(&self, state: &Self::State) -> Self::State {
        self.all_gradient_except_0(state)
    }

    /// Closed-form sum of `gradient_i`: `inv_sigma · (state - mean_data)`.
    fn all_gradient_except_0(&self, state: &Self::State) -> Self::State {
        let offset = state - &self.mean_data;
        &self.inv_sigma * offset
    }
}

impl<N, D> WithProxyGradient for GaussianModel<N, D>
where
    N: RealField,
    N: std::iter::Sum,
    D: DimName,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, D>,
    <DefaultAllocator as Allocator<N, D>>::Buffer: Copy + Send,
{
    type GradientProxy = Self::State;
    fn proxy_gradient_i(&self, state: &Self::State, one_data: &Self::DataType) -> Self::GradientProxy {
        self.gradient_i(state, one_data)
    }
    fn to_real_gradient(&self, proxy: &Self::GradientProxy, _one_data: &Self::DataType) -> Self::State {
        *proxy
    }
    fn to_real_gradient_scaled(&self, proxy: &Self::GradientProxy, _one_data: &Self::DataType) -> Self::State {
        *proxy * convert::<_, N>(self.N() as f64)
    }
}

impl<N, D> GaussianModel<N, D>
where
    N: RealField,
    D: DimName,
    DefaultAllocator: Allocator<N, D> + Allocator<N, D, D>,
{
    /// Draws `sample_size` random data points.
    ///
    /// Every coordinate is an independent draw from a normal distribution with mean `2.0` and
    /// standard deviation `2.0`. Draws are taken from `rng` in order and grouped into
    /// consecutive vectors of length `D::dim()`.
    fn generate_data(sample_size: usize, rng: &mut StdRng) -> Vec<OVector<N, D>> {
        let dim = D::dim();
        let draws = Normal::<f64>::new(2.0, 2.0)
            .unwrap()
            .sample_iter(rng)
            .take(sample_size * dim);
        let mut points = Vec::with_capacity(sample_size);
        for coords in &draws.chunks(dim) {
            points.push(OVector::<N, D>::from_iterator(coords.map(|x: f64| convert(x))));
        }
        points
    }

    /// Draws a random starting state `sqrt_sigma · z + mean_data`.
    ///
    /// `z` has `D::dim()` independent standard-normal entries, so the result is centred on
    /// `mean_data` with covariance `sqrt_sigma · sqrt_sigmaᵀ`.
    pub fn init_state(&self, rng: &mut StdRng) -> OVector<N, D> {
        let noise = OVector::<N, D>::from_iterator(
            StandardNormal
                .sample_iter(rng)
                .take(D::dim())
                .map(|x: f64| convert(x)),
        );
        &self.sqrt_sigma * noise + &self.mean_data
    }
}

impl<N, D> GaussianModel<N, D>
where
    N: RealField,
    N: std::iter::Sum,
    D: DimName,
    D: DimMin<D>,
    DefaultAllocator: Allocator<N, D>,
    DefaultAllocator: Allocator<N, D, D>,
    DefaultAllocator: Allocator<N, <D as DimMin<D>>::Output>,
    DefaultAllocator: Allocator<N, D, DimMinimum<D, D>>,
{
    /// Builds a random Gaussian model.
    ///
    /// Steps:
    /// - Draw `sample_size` data points (see `generate_data`).
    /// - Compute their empirical mean.
    /// - Draw a random weighting matrix `inv_sigma` whose eigenvalues are mapped into the range
    ///   between `1` and `L` (see `generate_sigma`).
    /// - Record `optimal`, the total loss at the empirical mean.
    ///
    /// # Panics
    /// Panics if any of these hold:
    /// - `D::dim() < 2`.
    /// - `sample_size == 0`: the mean and every per-sample loss would divide by zero and
    ///   silently become NaN.
    /// - `L` is not strictly positive (or is NaN): some eigenvalue of `inv_sigma` would be
    ///   `<= 0`, so `sqrt_sigma` would contain NaN/inf.
    #[allow(non_snake_case)] // parameter `L`
    pub fn init(mut rng: &mut StdRng, sample_size: usize, L: N) -> Self {
        // Validate everything before touching `rng` so valid calls consume the same draws.
        assert!(D::dim() >= 2);
        assert!(sample_size > 0, "GaussianModel::init requires at least one sample");
        assert!(L > N::zero(), "GaussianModel::init requires L > 0");

        let data = Self::generate_data(sample_size, &mut rng);
        let mean_data = data.iter().sum::<OVector<N, D>>() / convert::<_, N>(sample_size as f64);
        let (inv_sigma, sqrt_sigma) = Self::generate_sigma(&mut rng, L);
        let mut model = Self {
            data,
            inv_sigma,
            sqrt_sigma,
            L,
            mean_data,
            optimal: N::zero(),
        };
        // `optimal` needs `loss_i`, which only exists once the model is assembled.
        model.optimal = model.data.iter().map(|d| model.loss_i(&model.mean_data, d)).sum();
        model
    }

    /// Draws a random symmetric weighting matrix and its companion square-root matrix.
    ///
    /// How the factors are built:
    /// - A random orthogonal `Q` comes from the QR decomposition of a matrix whose entries are
    ///   i.i.d. normal draws (mean `2.0`, std-dev `1.0`).
    /// - `D::dim()` uniform draws on `[0, 100)` are affinely rescaled so that their minimum maps
    ///   to `1` and their maximum maps to `L`. These become the eigenvalues `d`.
    ///
    /// Returns `(inv_sigma, sqrt_sigma)` with `inv_sigma = Q·diag(d)·Qᵀ` and
    /// `sqrt_sigma = Q·diag(1/√d)·Qᵀ`, so that `sqrt_sigma² = inv_sigma⁻¹`.
    #[allow(non_snake_case)] // parameter `L`
    fn generate_sigma(mut rng: &mut StdRng, L: N) -> (OMatrix<N, D, D>, OMatrix<N, D, D>) {
        let m = convert(1.0);
        let randomm = OMatrix::<N, D, D>::from_vec(
            Normal::new(2.0, 1.0)
                .unwrap()
                .sample_iter(&mut rng)
                .take(D::dim() * D::dim())
                .map(|x| convert(x))
                .collect(),
        );
        let qr = randomm.qr();
        let q = qr.q();
        // `q()` is typed D × min(D, D); copy it into a square D × D matrix.
        let q = OMatrix::<N, D, D>::from_column_slice(q.as_slice());
        let q_t = q.transpose();
        let randomv = OVector::<N, D>::from_vec(
            Uniform::from(0.0..100.0)
                .sample_iter(&mut rng)
                .take(D::dim())
                .map(|x| convert(x))
                .collect(),
        );
        let max = randomv.max();
        let min = randomv.min();
        // Affine map sending min -> m (= 1) and max -> L.
        let diag = (randomv - OVector::<N, D>::repeat(min)) / (max - min) * (L - m) + OVector::<N, D>::repeat(m);
        let inv_sigma = &q * &OMatrix::<N, D, D>::from_diagonal(&diag) * &q_t;
        let sqrt_sigma =
            &q * &OMatrix::<N, D, D>::from_diagonal(&diag.map(|x| convert::<_, N>(1.0) / x.sqrt())) * &q_t;
        (inv_sigma, sqrt_sigma)
    }
}
