use ndarray::prelude::*;
use ndarray_stats::QuantileExt;
use ot::prelude::*;
use rust_optimal_transport as ot;
use rand::Rng;

use std::sync::{Arc, Mutex};

use crate::database::db_type::{WinDistType, RWDist};
use super::cluster_type::{KmeansPP, WinDistNode, WinDistNodePtr, WinDistCenter, WinDistCenterPtr};
use super::cluster_trait::{NodeTrait, CenterTrait};
use super::simple_thread_pool::{ThreadPool};

/// Placeholder for the EMD distance; always returns `-1.0`.
///
/// NOTE(review): `-1.0` appears to be a "not implemented" sentinel rather than
/// a real distance — callers should not treat it as one.
pub fn cul_emd_distance() -> f64 {
    const PLACEHOLDER: f64 = -1.0;
    PLACEHOLDER
}

/// Computes the Earth Mover's Distance between the window distributions of
/// `x` and `y`.
///
/// Each `win_dist` entry is converted to `f64` and used as the mass of one
/// bin. Moving one unit of mass from bin `i` to bin `j` costs `|i - j|`, so for
/// three-bin distributions the ground-cost matrix is
/// `[[0, 1, 2], [1, 0, 1], [2, 1, 0]]`. The result is the total cost of the
/// optimal transport plan returned by the solver, i.e. `sum(plan * cost)`.
///
/// # Panics
///
/// Panics if the EMD solver returns an error (presumably e.g. when it rejects
/// the two histograms; TODO confirm the exact conditions against
/// `rust_optimal_transport`).
pub fn cul_emd_distance_test(x: &RWDist, y: &RWDist) -> f64 {
    // TODO: weights!!!
    let source: Vec<f64> = x.win_dist.iter().map(|&v| v as f64).collect();
    let target: Vec<f64> = y.win_dist.iter().map(|&v| v as f64).collect();

    // Build the |i - j| ground cost from the actual histogram lengths rather
    // than a hard-coded 3x3 matrix, so the cost shape always matches the
    // weight vectors. For three bins this reproduces the original matrix.
    let mut cost = Array2::from_shape_fn((source.len(), target.len()), |(i, j)| {
        (i as f64 - j as f64).abs()
    });
    let mut source = Array1::from(source);
    let mut target = Array1::from(target);

    // `solve` yields the optimal transport plan; weighting it by the cost
    // matrix and summing gives the EMD value.
    let plan = EarthMovers::new(&mut source, &mut target, &mut cost)
        .solve()
        .expect("EMD solver failed");
    (plan * cost).sum()
}

/// Two-phase driver for k-means++ style clustering: repeated initialization
/// steps followed by the main clustering loop.
pub trait KMeansPPTrait {
    /// Performs one initialization step.
    ///
    /// Call repeatedly: initialization is complete once this returns
    /// `Ok(false)`.
    fn choose_init(&mut self) -> Result<bool, String>;

    /// Runs the main clustering phase.
    fn work(&mut self) -> Result<bool, String>;
}

impl KMeansPPTrait for KmeansPP{
    fn choose_init(& mut self) -> Result<bool, String>{
        if self.cluster_set.lock().unwrap().is_empty(){
            //random choose
            let mut rng = rand::thread_rng();
            let idx = rng.gen_range(0..=(self.point_set.len()-1));
            self.cluster_set.lock().unwrap().push( // 为什么要获取两次锁 phil
                Arc::new(
                    Mutex::new(
                        WinDistCenter::new(self.point_set.get(idx).unwrap().clone())
                    )
                )
            );
            println!("choose first one:{:?}", self.cluster_set);
            return Ok(true);
        }
        if self.cluster_set.lock().unwrap().len() as i32 >= self.cluster_num {
            return Ok(false);
        }

        //choose a point which is faster than other centroid points
        let mut next_centroid: Option<WinDistNodePtr> = None;
        let mut min_dist: f64 = 0.0;
        for (_, point) in self.point_set.iter().enumerate() {
            let mut point_dist: f64 = 0.0;
            for (_, centroid) in self.cluster_set.lock().unwrap().iter().enumerate(){
                if(point.win_dist.win_dist_id 
                    == 
                    centroid.lock().unwrap().centroid.win_dist.win_dist_id){
                    point_dist = 0.0;
                    break;
                }
                let dist = point.cul_distance(centroid.clone()).unwrap();
                point_dist += dist;
            }
            if point_dist>min_dist {
                next_centroid = Some(point.clone());
                min_dist = point_dist;
            }
        }
        if let Some(next_centroid) = next_centroid{
            println!("choose one:{:?}", next_centroid);
            self.cluster_set.lock().unwrap().push(
                Arc::new(Mutex::new(WinDistCenter::new(next_centroid.clone())))
            );
            return Ok(true);
        }
        
        Ok(false)
    }

    fn work(& mut self) -> Result<bool, String>{
        let mut is_conv = false;
        let mut update_rount = 0;
        let num_threads = num_cpus::get();
        //let mut point_set = self.point_set.clone();
        while is_conv==false {
            is_conv = true;
            let mut pool = ThreadPool::new(num_threads);
            //for (_, point) in self.point_set.iter().enumerate(){
            let mut point_set: &Vec<Arc<WinDistNode>> = self.point_set.as_ref();
            
            for (_, point) in point_set.iter().enumerate(){// 这个地方分一下组，不用每个数据都开一个线程 phil
                let mut cluster_set = self.cluster_set.clone();
                let mut aim_point = point.clone();
                pool.execute(move ||{
                    let mut min_centorid: WinDistCenterPtr = 
                    Arc::new(Mutex::new(WinDistCenter::new(
                        Arc::new(WinDistNode::new(-1, [0, 0, 0], 0))
                    )));
                    let mut min_dist = std::f64::MAX;
                    
                    //for (_, centroid) in self.cluster_set.lock().unwrap().iter().enumerate(){
                    for (_, centroid) in cluster_set.lock().unwrap().iter().enumerate(){
                        
                        if(aim_point.win_dist.win_dist_id 
                            == 
                            centroid.lock().unwrap().centroid.win_dist.win_dist_id){
                            return ();
                        }
                        
                        let dist = aim_point.cul_distance(centroid.clone()).unwrap();
                        //let dist = 0.0;
                        println!("dist:{:?} -- min_dist:{:?}", dist, min_dist);
                        if dist < min_dist{
                            min_dist = dist;
                            min_centorid = centroid.clone();
                        }
                    }
                    min_centorid.lock().unwrap().point_set.push(aim_point.clone());
                    
                });
            }
            pool.shutdown();
            for (_, centroid) in self.cluster_set.lock().unwrap().iter().enumerate(){
                centroid.lock().unwrap().update_centorid();
                is_conv &= centroid.lock().unwrap().has_diff().unwrap(); 
            }
            update_rount+=1;
            println!("update round:{:?}--is_conv:{:?}", update_rount, is_conv);
        }
        Ok(true)
    }
}