use rand::Rng;
use rayon::prelude::*;
use std::sync::{Arc, Mutex};

use rayon::iter::IntoParallelRefIterator;

// pub struct WeightedKmeansPP<P, C> {
//     pub points: Vec<P>,
//     pub weights: Vec<usize>,
//     pub average_fn: fn(pnt_idx: &[usize], pnts: &[P], w8s: &[usize]) -> C,
//     pub distance_fn: fn(pnt: &P, ctrd: &C) -> f64,
// } 

// impl<P, C> WeightedKmeansPP<P, C> 
//     where C: Send+Sync,
//           P: Send+Sync
// {
//     pub fn kmeanspp_process(&self, centroids_size: usize, max_iter: usize) -> (Vec<usize>, Vec<C>) {

//         // pp
//         let mut assign = vec![centroids_size; self.points.len()];
//         let mut centroid_pointidcs:Vec<Arc<Mutex<Vec<usize>>>> = (0..centroids_size).into_par_iter().map(|_| Arc::new(Mutex::new(vec![]))).collect();
//         let mut mindis = vec![f64::MAX; self.points.len()];
//         let mut centroids: Vec<C> = vec![];
//         for centroid_cnt in 0..centroids_size {
//             // 找点
//             let new_point_centroid = if centroid_cnt == 0 {
//                 rand::thread_rng().gen_range(0..self.points.len())
//             } else {
//                 mindis
//                     .par_iter()
//                     .enumerate()
//                     .max_by(|(_, dis_a), (_, dis_b)| {
//                         dis_a.partial_cmp(dis_b).unwrap()
//                     })
//                     .map(|(pointidx, _)| pointidx)
//                     .unwrap()
//             };
//             assign[new_point_centroid] = centroid_cnt;
//             centroid_pointidcs[centroid_cnt].lock().unwrap().push(new_point_centroid);

//             // 构建centroid
//             let new_centorid = (self.average_fn)(centroid_pointidcs[centroid_cnt].lock().unwrap().as_ref(), self.points.as_ref(), self.weights.as_ref());
//             centroids.push(new_centorid);

//             // 找最小点
//             mindis
//                 .par_iter_mut()
//                 .enumerate()
//                 .for_each(|(idx, dis)| {
//                     let update_dis = (self.distance_fn)(&self.points[idx], centroids.last().unwrap());
//                     *dis = dis.min(update_dis);
//                 })
//         }

//         // cluster
//         let mut new_assign = vec![centroids_size; self.points.len()];
//         let mut is_conv = false;
//         for itercnt in 0..max_iter {
//             // point new_assign to centroid
//             self.points
//                 .par_iter()
//                 .zip(new_assign.par_iter_mut())
//                 .for_each(|(point, point2centroid)|{
//                     *point2centroid = centroids
//                         .iter()
//                         .enumerate()
//                         .min_by(|(_, ca), (_,cb)|{
//                             (self.distance_fn)(point, ca)
//                                 .partial_cmp(&(self.distance_fn)(point, cb))
//                                 .unwrap()
//                         })
//                         .map(|(centroid_idx, _)| centroid_idx)
//                         .unwrap();
//                 });

//             // update centroid
//             centroid_pointidcs
//                 .par_iter_mut()
//                 .for_each(|v| v.lock().unwrap().clear());
//             new_assign
//                 .par_iter()
//                 .enumerate()
//                 .for_each(|(pntidx, &centroid_idx)| {
//                     centroid_pointidcs[centroid_idx].lock().unwrap().push(pntidx);
//                 });
//             centroids
//                 .par_iter_mut()
//                 .enumerate()
//                 .for_each(|(centroid_idx, centroid)| {
//                     *centroid = (self.average_fn)(centroid_pointidcs[centroid_idx].lock().unwrap().as_ref(), self.points.as_ref(), self.weights.as_ref());
//                 });

//             // is_conv
//             is_conv = new_assign.par_iter().zip(assign.par_iter()).all(|(a,b)| a == b);
//             if is_conv {
//                 println!("update iter:{:?}--is_conv", itercnt);
//                 break;
//             }
//             assign = new_assign.clone();
//         }

//         if !is_conv {
//             println!("is not converge until {} iteration", max_iter);
//         }

//         (assign, centroids)
//     }
// }


// pub struct WeightedKmeansPP<P, C, F1, F2> 
//     where F1: Fn(pntidcs: &[usize], pnts: &[P], w8s: &[usize]) -> C,
//           F2: Fn(pntidcs: &[usize], pnts: &[P], ctrd: &C) -> Vec<f64>,
// {
//     pub points: Vec<P>,
//     pub weights: Vec<usize>,
//     pub average_fn: F1,
//     pub distance_batch_fn: F2,
// } 

// pub struct WeightedKmeansPP<P, C> {
//     pub points: Vec<P>,
//     pub weights: Vec<usize>,
//     pub average_fn: Box<dyn Fn(&[usize], &[P], &[usize]) -> C>,
//     pub distance_batch_fn: Box<dyn Fn(&[usize], &[P], &C) -> Vec<f64>>,
// }
/// Input of a weighted k-means run over points of type `P` producing centroids of type `C`.
///
/// All geometry is delegated to the two boxed callbacks. They are required to be
/// `Send + Sync` because `kmeanspp_process` invokes them from rayon parallel iterators.
pub struct WeightedKmeansPP<P, C> {
    /// The points to cluster; clusters refer to them by index.
    pub points: Vec<P>,
    /// Per-point weights; `kmeanspp_process` only forwards them to `average_fn`.
    pub weights: Vec<usize>,
    /// `average_fn(pntidcs, pnts, w8s)` builds the centroid of the points `pnts[i]`
    /// for every `i` in `pntidcs`; `w8s` is the whole `weights` slice.
    pub average_fn: Box<dyn Fn(&[usize], &[P], &[usize]) -> C + Send + Sync>,
    /// `distance_batch_fn(pntidcs, pnts, ctrd)` returns the distance from `pnts[i]`
    /// to `ctrd` for every `i` in `pntidcs`, one entry per index and in the same order.
    pub distance_batch_fn: Box<dyn Fn(&[usize], &[P], &C) -> Vec<f64> + Send + Sync>,
}

impl<P, C> WeightedKmeansPP<P, C> 
    where C: Send+Sync,
          P: Send+Sync
{
    pub fn kmeanspp_process(&self, centroids_size: usize, max_iter: usize, fix_mindis_point: bool) -> (Vec<usize>, Vec<Vec<usize>>, Vec<C>) {
        // assist
        let total_pntidcs: Vec<usize> = (0..self.points.len()).collect();

        // pp
        let mut assign = vec![centroids_size; self.points.len()];
        let mut centroid_pointidcs:Vec<Arc<Mutex<Vec<usize>>>> = (0..centroids_size).into_par_iter().map(|_| Arc::new(Mutex::new(vec![]))).collect();
        let mut mindis = vec![f64::MAX; self.points.len()];
        let mut centroids: Vec<C> = vec![];
        let mut centroid_mindis_point = if fix_mindis_point {
            Some(vec![])
        } else {
            None
        };

        for centroid_cnt in 0..centroids_size {
            // 找点
            let new_point_centroid = if centroid_cnt == 0 {
                rand::thread_rng().gen_range(0..self.points.len())
            } else {
                mindis
                    .par_iter()
                    .enumerate()
                    .max_by(|(_, dis_a), (_, dis_b)| {
                        dis_a.partial_cmp(dis_b).unwrap()
                    })
                    .map(|(pointidx, dis)| {assert!(*dis >0.0); pointidx})
                    .unwrap()
            };
            assign[new_point_centroid] = centroid_cnt;
            mindis[new_point_centroid] = 0.0; // 以免distance(a, a) != 0的情况
            centroid_pointidcs[centroid_cnt].lock().unwrap().push(new_point_centroid);
            if let Some(centroid_mindis_point) = &mut centroid_mindis_point {
                centroid_mindis_point.push(new_point_centroid);
            };

            // 构建centroid
            let new_centorid = (self.average_fn)(centroid_pointidcs[centroid_cnt].lock().unwrap().as_ref(), self.points.as_ref(), self.weights.as_ref());
            centroids.push(new_centorid);

            // 找最小点           
            (self.distance_batch_fn)(total_pntidcs.as_ref(), self.points.as_ref(), centroids.last().unwrap())
                .into_par_iter()
                .zip(mindis.par_iter_mut())
                .for_each(|(new_dis, dis)| {
                    *dis = dis.min(new_dis);
                });
        }

        // cluster
        let mut new_assign = vec![centroids_size; self.points.len()];
        let mut is_conv = false;
        for itercnt in 0..max_iter {
            // init
            mindis
                .par_iter_mut()
                .zip(new_assign.par_iter_mut())
                .for_each(|(dis, centroid_idx)| {
                    *dis = f64::MAX;
                    *centroid_idx = centroids_size;
                });

            // point new_assign to centroid
            // 加了额外的逻辑，归到centroid的点中距离这个centroid最近的点不会改变分类类型
            centroids
                .iter()
                .enumerate()
                .for_each(|(centroid_idx, centroid)| {
                    (self.distance_batch_fn)(total_pntidcs.as_ref(), self.points.as_ref(), centroid)
                        .into_par_iter()
                        .zip(mindis.par_iter_mut())
                        .zip(new_assign.par_iter_mut())
                        .for_each(|((newdis, dis), point2centroid)| {
                            if newdis < *dis {
                                *dis = newdis;
                                *point2centroid = centroid_idx;
                            }
                        })
                });
            if let Some(centroid_mindis_point) = &centroid_mindis_point {
                centroid_mindis_point
                    .iter()
                    .enumerate()
                    .for_each(|(centroid_idx, &pnt_idx)|{
                        new_assign[pnt_idx] = centroid_idx;
                    });
            };

            // update centroid
            centroid_pointidcs
                .par_iter_mut()
                .for_each(|v| v.lock().unwrap().clear());
            new_assign
                .par_iter()
                .enumerate()
                .for_each(|(pntidx, &centroid_idx)| {
                    centroid_pointidcs[centroid_idx].lock().unwrap().push(pntidx);
                });
            centroids
                .par_iter_mut()
                .enumerate()
                .for_each(|(centroid_idx, centroid)| {
                    *centroid = (self.average_fn)(centroid_pointidcs[centroid_idx].lock().unwrap().as_ref(), self.points.as_ref(), self.weights.as_ref());
                });
            if let Some(centroid_mindis_point) = &mut centroid_mindis_point {
                centroid_mindis_point
                    .iter_mut()
                    .enumerate()
                    .for_each(|(centroid_idx, pnt_idx)|{
                        let inner_centroid_distances = (self.distance_batch_fn)(centroid_pointidcs[centroid_idx].lock().unwrap().as_ref(), self.points.as_ref(), &centroids[centroid_idx]);
                        let innerpnt_min_idx = inner_centroid_distances
                            .par_iter()
                            .zip(centroid_pointidcs[centroid_idx].lock().unwrap().par_iter())
                            .min_by(|(dis_a, _), (dis_b, _)|{
                                dis_a.partial_cmp(dis_b).unwrap()
                            })
                            .map(|(_, &innerpoint_idx)| innerpoint_idx)
                            .unwrap();
                        *pnt_idx = innerpnt_min_idx;
                    });
            };

            // is_conv
            is_conv = new_assign.par_iter().zip(assign.par_iter()).all(|(a,b)| a == b);
            if is_conv {
                println!("update iter:{:?}--is_conv", itercnt);
                break;
            }
            else if itercnt%20 == 0 {
            // else {
                println!("have excute {} iteration", itercnt);
            }
            assign = new_assign.clone();
        }

        if !is_conv {
            println!("is not converge until {} iteration", max_iter);
        }

        let centroid_pointidcs = centroid_pointidcs
            .into_par_iter()
            .map(|arc_mutex_vec| {
                arc_mutex_vec.lock().unwrap().clone()
            })
            .collect::<Vec<Vec<usize>>>();

        (assign, centroid_pointidcs, centroids)
    }
}