use crate::game::component::{self, *};
use rand::distributions::weighted;
use rocksdb::{ColumnFamilyDescriptor, DBCommon, IteratorMode, Options, SingleThreaded, DB};
use rust_optimal_transport::exact::EarthMovers;
use std::{borrow::BorrowMut, collections::HashSet};
use std::marker::PhantomData;
use num::traits::cast::ToPrimitive;
use rayon::iter::{IntoParallelRefIterator, IntoParallelIterator, IndexedParallelIterator, ParallelIterator, IntoParallelRefMutIterator};
use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
use super::kmeanspp::*;
use ndarray::prelude::*;
use rust_optimal_transport as ot;
use ot::prelude::*;
use super::*;

/// Per-street card abstraction built by [`KrwEmdCluster::new`].
///
/// Every per-street vector is indexed by street number (`0..T::NUM_STREET`).
pub struct KrwEmdCluster<T> {
    /// For each street: the bucket index assigned to every kriso index of that street.
    kriso2bucket_street: Vec<Vec<usize>>,
    /// For each street: the number of buckets used by that street's abstraction.
    bucket_size_street: Vec<usize>,
    /// For each street: the configuration the buckets were produced from.
    abstr_configs: Vec<AbstractAlgorithmStreet>,
    /// Ties the cluster to its game type `T` without storing a value of it.
    _marker: PhantomData<T>,
}

impl<T> KrwEmdCluster<T>
    where T: Singleton + Hand + WaughTrait + ShowdownRanker + 'static,
{
    /// Builds the card abstraction for every street of game `T`.
    ///
    /// `alg_configs[street]` chooses how `street` is abstracted:
    /// - `KrwEmd`: buckets are computed here with weighted k-means++
    ///   (see [`Self::get_street_cluster`]); the bucket count is `centroid_size`;
    /// - `Isomorphism`: buckets come from `load_isomorphism_abstr`;
    /// - `Krwi` / `Kroi`: buckets come from `load_krxi_abstr` with `true` / `false`.
    ///
    /// # Panics
    /// - if `alg_configs.len() != T::NUM_STREET`;
    /// - if a config's `recall_from` is greater than its street;
    /// - if a `KrwEmd` config has `st_weights.len() != street - recall_from + 1`;
    /// - if a config uses any other variant.
    pub fn new(alg_configs: &[AbstractAlgorithmStreet]) -> Self {
        // Keep this self-check around (ideally turn it into a real test) when time permits.
        // It verifies that every loaded abstraction uses exactly `bucket_size` distinct buckets.
        // for street in 0..T::NUM_STREET as usize {
        //     for recall_from in 0..=street {
        //         println!("street: {}, recall_from: {}", street, recall_from);
        //
        //         let (iso2bucket, bucket_size) = load_isomorphism_abstr::<T>(street, recall_from);
        //         let unique: HashSet<_> = iso2bucket.iter().cloned().collect();
        //         assert_eq!(unique.len(), bucket_size);
        //         println!("isomorphism: {}", bucket_size);
        //
        //         let (iso2bucket, bucket_size) = load_krxi_abstr::<T>(street, recall_from, true);
        //         let unique: HashSet<_> = iso2bucket.iter().cloned().collect();
        //         assert_eq!(unique.len(), bucket_size);
        //         println!("winning: {}", bucket_size);
        //
        //         if (street != recall_from) || (street+1 != T::NUM_STREET as usize){
        //             let (iso2bucket, bucket_size) = load_krxi_abstr::<T>(street, recall_from, false);
        //             let unique: HashSet<_> = iso2bucket.iter().cloned().collect();
        //             assert_eq!(unique.len(), bucket_size);
        //             println!("potential: {}", bucket_size);
        //         }
        //     }
        // }

        assert_eq!(alg_configs.len(), T::NUM_STREET as usize);
        let mut kriso2bucket_street: Vec<Vec<usize>> = Vec::with_capacity(alg_configs.len());
        let mut bucket_size_street: Vec<usize> = Vec::with_capacity(alg_configs.len());

        // Streets are processed from the last one to the first; both vectors are
        // reversed afterwards so that index `street` holds the data of `street`.
        for street in (0..alg_configs.len()).rev() {
            let config = &alg_configs[street];
            let recall_from = Self::config_recall_from(config);
            assert!(recall_from <= street);

            let (kriso_assign, bucket_size) = match config {
                AbstractAlgorithmStreet::KrwEmd { st_weights, centroid_size, train_iteration, .. } => {
                    assert_eq!(st_weights.len(), street - recall_from + 1);
                    let kriso_assign = Self::get_street_cluster(
                        street,
                        recall_from,
                        st_weights.clone(),
                        *centroid_size,
                        *train_iteration,
                    );
                    (kriso_assign, *centroid_size)
                }
                AbstractAlgorithmStreet::Isomorphism { .. } => {
                    load_isomorphism_abstr::<T>(street, recall_from)
                }
                AbstractAlgorithmStreet::Krwi { .. } => load_krxi_abstr::<T>(street, recall_from, true),
                AbstractAlgorithmStreet::Kroi { .. } => load_krxi_abstr::<T>(street, recall_from, false),
                // `config_recall_from` has already panicked for every other variant;
                // the allow covers the case where the enum has no other variants.
                #[allow(unreachable_patterns)]
                _ => unreachable!("不能到这里"),
            };
            kriso2bucket_street.push(kriso_assign);
            bucket_size_street.push(bucket_size);
        }
        kriso2bucket_street.reverse();
        bucket_size_street.reverse();

        Self {
            kriso2bucket_street,
            bucket_size_street,
            abstr_configs: alg_configs.to_vec(),
            _marker: PhantomData,
        }
    }

    /// Returns the `recall_from` street of a supported per-street config.
    ///
    /// # Panics
    /// Panics for any variant other than `KrwEmd`, `Isomorphism`, `Krwi` or `Kroi`.
    fn config_recall_from(config: &AbstractAlgorithmStreet) -> usize {
        match config {
            AbstractAlgorithmStreet::KrwEmd { recall_from, .. }
            | AbstractAlgorithmStreet::Isomorphism { recall_from }
            | AbstractAlgorithmStreet::Krwi { recall_from }
            | AbstractAlgorithmStreet::Kroi { recall_from } => *recall_from,
            // Other variants (if the enum has any) are not handled by this cluster type.
            #[allow(unreachable_patterns)]
            _ => panic!("不能到这里"),
        }
    }

    /// Decodes a 4-byte big-endian `u32` (the key/value encoding used by these tables).
    ///
    /// # Panics
    /// Panics if `bytes` is not exactly 4 bytes long.
    fn decode_be_u32(bytes: &[u8]) -> usize {
        let array: [u8; 4] = bytes
            .try_into()
            .expect("expected a 4-byte big-endian u32");
        u32::from_be_bytes(array) as usize
    }

    /// Clusters the kriso indices of `street` (recall starting at `recall_from`) into
    /// `centroids_size` buckets.
    ///
    /// Each distribution id ("krid") is a point made of one normalised 3-bin distribution
    /// per street in `recall_from..=street`, loaded by `load_wsdistnorm_with_id`. The points
    /// are clustered by `WeightedKmeansPP` with:
    /// - weights from `load_weight_with_wtdistid`;
    /// - centroids computed as the weight-normalised mean of their points;
    /// - distance `sum_i st_weight[i] * EMD(point[i], centroid[i])`, where the EMD ground
    ///   cost between the 3 bins is `[[0,1,2],[1,0,1],[2,1,0]]`.
    ///
    /// `max_iter` is forwarded to `kmeanspp_process`.
    ///
    /// Returns a vector indexed by kriso (length
    /// `hand_isomorphism_size_street(street, recall_from)`) holding the bucket assigned to
    /// that kriso's distribution.
    ///
    /// # Panics
    /// Panics if `st_weight.len() != street - recall_from + 1`, if
    /// `centroids_size > distsize`, or if the stored tables are inconsistent.
    pub fn get_street_cluster(street:usize, recall_from:usize, st_weight: Vec<f64>, centroids_size:usize, max_iter:usize) -> Vec<usize> {
        assert_eq!(street - recall_from + 1, st_weight.len());

        // 1. Load data: kriso -> krid, and krid -> per-street normalised distributions.
        let (kriso2distid, distsize) = Self::load_wtdistid_with_kriso(street, recall_from);
        let krid2wsdistnorm = Self::load_wsdistnorm_with_id(street, recall_from);
        assert_eq!(distsize, krid2wsdistnorm.len());
        assert!(centroids_size <= distsize);

        // 2. Load the (precomputed) weight of every krid; it must align with the points.
        let krid2weight = Self::load_weight_with_wtdistid(street, recall_from);
        assert_eq!(krid2weight.len(), distsize);

        // 3. The two callbacks required by the weighted k-means++.
        // Weighted mean of the points in `pnt_idx`, computed street by street.
        let average_closure = move |pnt_idx: &[usize], pnts: &[Vec<Array1<f64>>], w8s: &[usize]| -> Vec<Array1<f64>> {
            let weight_sum = pnt_idx
                .par_iter()
                .map(|&pntidx| w8s[pntidx])
                .sum::<usize>()
                .to_f64()
                .unwrap();
            assert!(
                weight_sum > 0.0,
                "cannot average a cluster with zero total weight ({} points)",
                pnt_idx.len()
            );
            let street_count = pnts[0].len();

            let norm_centroid = pnt_idx
                .par_iter()
                .map(|&pntidx| {
                    let share = w8s[pntidx].to_f64().unwrap() / weight_sum;
                    pnts[pntidx]
                        .iter()
                        .map(|st_wdist| st_wdist * share)
                        .collect::<Vec<_>>()
                })
                .reduce(
                    || vec![Array1::<f64>::zeros(3); street_count],
                    |mut acc, other| {
                        acc.iter_mut()
                            .zip(other.iter())
                            .for_each(|(acc_dist, other_dist)| *acc_dist += other_dist);
                        acc
                    },
                );

            // Every input distribution is normalised, so each centroid component must be too.
            for (offset, ndarr) in norm_centroid.iter().enumerate() {
                assert!(
                    (ndarr.sum() - 1_f64).abs() < 1e-10,
                    "centroid component {} is not normalised (sum = {})",
                    offset,
                    ndarr.sum()
                );
            }

            norm_centroid
        };

        // Ground cost between the 3 bins of a distribution.
        let ldw_cost: Array2<f64> = ndarray::array![[0., 1., 2.], [1., 0., 1.], [2., 1., 0.]];
        // `ctrd` stays `&Vec<..>` to match the boxed `distance_batch_fn` type
        // (presumably declared that way in kmeanspp — TODO confirm before changing).
        let distance_batch_closure = move |pntidcs: &[usize], pnts: &[Vec<Array1<f64>>], ctrd: &Vec<Array1<f64>>| -> Vec<f64> {
            pntidcs
                .par_iter()
                .map(|&pntidx| {
                    pnts[pntidx]
                        .iter()
                        .zip(ctrd.iter())
                        .zip(st_weight.iter())
                        .map(|((point, centroid), weight)| {
                            // `EarthMovers::new` takes mutable references, so work on copies.
                            let mut point = point.clone();
                            let mut centroid = centroid.clone();
                            let mut cost = ldw_cost.clone();
                            let transport = EarthMovers::new(&mut point, &mut centroid, &mut cost)
                                .solve()
                                .unwrap();
                            (&transport * &ldw_cost).sum() * weight
                        })
                        .sum::<f64>()
                })
                .collect::<Vec<f64>>()
        };

        // 4. Run weighted k-means++ over the krid points.
        let kmeanspp: WeightedKmeansPP<Vec<Array1<f64>>, Vec<Array1<f64>>> = WeightedKmeansPP::<Vec<Array1<f64>>, Vec<Array1<f64>>> {
            points: krid2wsdistnorm,
            weights: krid2weight,
            average_fn: Box::new(average_closure),
            distance_batch_fn: Box::new(distance_batch_closure),
        };
        let (dist_assign, _, _) = kmeanspp.kmeanspp_process(centroids_size, max_iter, false);

        // 5. Map each kriso to the bucket of its krid. `kriso2distid` already has one entry
        //    per kriso (length `hand_isomorphism_size_street(street, recall_from)`).
        let kriso_assign: Vec<usize> = kriso2distid
            .par_iter()
            .map(|&krid| dist_assign[krid])
            .collect();

        kriso_assign
    }

    /// Loads the weight of every winning-(trace-)distribution id of `street`.
    ///
    /// Reads column family `<GAME>_nrid_<street+1>_to_winning_distribution_weight` when
    /// `street == recall_from`, otherwise
    /// `<GAME>_prid_<street+1>_from_<recall_from+1>_to_winning_trace_distribution_weight`,
    /// from the read-only database at `data/<GAME>`. Keys and values are big-endian `u32`s;
    /// the keys must be exactly `0, 1, 2, …` in iteration order.
    ///
    /// # Panics
    /// Panics if the database/column family cannot be opened or the data is malformed.
    fn load_weight_with_wtdistid(street: usize, recall_from: usize) -> Vec<usize>{
        // Open the database read-only with the required column family.
        let path = std::format!("data/{}", T::GAME_NAME);
        let options = {
            let mut options = Options::default();
            options.create_if_missing(false);
            options
        };

        let wtdistid2weight_cf = if street == recall_from {
            format!(
                "{}_nrid_{}_to_winning_distribution_weight",
                T::GAME_NAME,
                street + 1
            )
        } else {
            format!(
                "{}_prid_{}_from_{}_to_winning_trace_distribution_weight",
                T::GAME_NAME,
                street + 1,
                recall_from + 1
            )
        };

        let cf_names: Vec<&str> = vec!["default", wtdistid2weight_cf.as_str()];
        let cf_descriptors: Vec<_> = cf_names
            .iter()
            .map(|&cf_name| {
                let mut cf_opt = Options::default();
                cf_opt.create_if_missing(false);
                ColumnFamilyDescriptor::new(cf_name, cf_opt)
            })
            .collect();
        let db = DBCommon::<SingleThreaded,_>::open_cf_descriptors_read_only(&options, &path, cf_descriptors, false)
            .unwrap_or_else(|e| panic!("打不开这个数据库{}的列族{}: {:?}", path, wtdistid2weight_cf, e));
        let cf_handle = db
            .cf_handle(&wtdistid2weight_cf)
            .unwrap_or_else(|| panic!("没有这个列族:{}", wtdistid2weight_cf));

        // Read every (id, weight) pair; ids must be contiguous starting from 0.
        let mut wtdistid2weight: Vec<usize> = vec![];
        for item in db.iterator_cf(cf_handle, IteratorMode::Start) {
            let (keybytes, valuebytes) = item.unwrap();
            assert!(keybytes.len() == 4 && valuebytes.len() == 4);
            let wtdistid = Self::decode_be_u32(&keybytes);
            assert_eq!(wtdistid, wtdistid2weight.len());
            wtdistid2weight.push(Self::decode_be_u32(&valuebytes));
        }

        wtdistid2weight
    }

    /// Loads, for every distribution id (krid) of `street`, one normalised 3-bin winning
    /// distribution per street in `recall_from..=street`, ordered from `street` down to
    /// `recall_from`.
    ///
    /// - When `street == recall_from`, the per-street table
    ///   `<GAME>_nrid_<street+1>_to_winning_distribution` is used directly and krid `k`
    ///   maps to the single entry `k` of that table.
    /// - Otherwise `<GAME>_prid_<street+1>_from_<recall_from+1>_to_winning_trace_distribution`
    ///   maps each krid to a bincode `Vec<u32>` of per-street nrids (entry `i` indexes the
    ///   table of street `street - i`).
    ///
    /// Per-street table values are bincode `[i64; 3]` counts, normalised by their sum.
    ///
    /// # Panics
    /// Panics if the database/column families cannot be opened or the data is malformed.
    fn load_wsdistnorm_with_id(street: usize, recall_from: usize) -> Vec<Vec<Array1<f64>>>{
        // Open the database read-only with every required column family.
        let path = std::format!("data/{}", T::GAME_NAME);
        let options = {
            let mut options = Options::default();
            options.create_if_missing(false);
            options
        };

        let krid2wtdist_cf = if street == recall_from {
            format!(
                "{}_nrid_{}_to_winning_distribution",
                T::GAME_NAME,
                street + 1
            )
        } else {
            format!(
                "{}_prid_{}_from_{}_to_winning_trace_distribution",
                T::GAME_NAME,
                street + 1,
                recall_from + 1
            )
        };

        // Per-street winning distribution tables, from `street` down to `recall_from`.
        let nrid2wdist_cfs: Vec<String> = (recall_from..=street)
            .rev()
            .map(|st| format!("{}_nrid_{}_to_winning_distribution", T::GAME_NAME, st + 1))
            .collect();

        // When street == recall_from, `krid2wtdist_cf` is the same name as
        // `nrid2wdist_cfs[0]`, so it must not be listed twice.
        let mut cf_names: Vec<&str> = vec!["default", krid2wtdist_cf.as_str()];
        if street > recall_from {
            cf_names.extend(nrid2wdist_cfs.iter().map(String::as_str));
        }

        let cf_descriptors: Vec<_> = cf_names
            .iter()
            .map(|&cf_name| {
                let mut cf_opt = Options::default();
                cf_opt.create_if_missing(false);
                ColumnFamilyDescriptor::new(cf_name, cf_opt)
            })
            .collect();
        let db = DBCommon::<SingleThreaded,_>::open_cf_descriptors_read_only(&options, &path, cf_descriptors, false)
            .unwrap_or_else(|e| panic!("打不开这个数据库{}的列族{:?}: {:?}", path, cf_names, e));

        // Read the krid -> per-street nrid trace; ids must be contiguous starting from 0.
        let cf_handle = db
            .cf_handle(&krid2wtdist_cf)
            .unwrap_or_else(|| panic!("没有这个列族:{}", krid2wtdist_cf));
        let mut krid2wtdist: Vec<Vec<u32>> = vec![];
        for item in db.iterator_cf(cf_handle, IteratorMode::Start) {
            let (keybytes, valuebytes) = item.unwrap();
            assert!(keybytes.len() == 4);
            let distid = Self::decode_be_u32(&keybytes);
            assert_eq!(distid, krid2wtdist.len());
            let wtdist: Vec<u32> = if street == recall_from {
                // No recall: this column family *is* the per-street table (its values are
                // `[i64; 3]`, decoded below), so the trace of krid `k` is just `[k]`.
                vec![distid as u32]
            } else {
                bincode::deserialize(&valuebytes).unwrap()
            };
            krid2wtdist.push(wtdist);
        }

        // Read each per-street table; ids must be contiguous starting from 0.
        let nrid2wdists: Vec<Vec<[i64; 3]>> = nrid2wdist_cfs
            .iter()
            .map(|cf_name| {
                let cf_handle = db
                    .cf_handle(cf_name)
                    .unwrap_or_else(|| panic!("没有这个列族:{}", cf_name));
                let mut nrid2wdist: Vec<[i64; 3]> = vec![];
                for item in db.iterator_cf(cf_handle, IteratorMode::Start) {
                    let (keybytes, valuebytes) = item.unwrap();
                    assert!(keybytes.len() == 4);
                    let distid = Self::decode_be_u32(&keybytes);
                    assert_eq!(distid, nrid2wdist.len());
                    let dist: [i64; 3] = bincode::deserialize(&valuebytes).unwrap();
                    nrid2wdist.push(dist);
                }
                nrid2wdist
            })
            .collect();

        // Turn every trace into its concatenated, per-street normalised distributions.
        let krid2wsdistnorm: Vec<Vec<Array1<f64>>> = krid2wtdist
            .par_iter()
            .map(|wtdist| {
                wtdist
                    .iter()
                    .enumerate()
                    .map(|(vecidx, &nrid)| {
                        let wdist = &nrid2wdists[vecidx][nrid as usize];
                        let total = wdist.iter().sum::<i64>().to_f64().unwrap();
                        let wdist_norm = wdist
                            .iter()
                            .map(|&component| component.to_f64().unwrap() / total)
                            .collect::<Vec<_>>();
                        Array1::<f64>::from(wdist_norm)
                    })
                    .collect::<Vec<Array1<f64>>>()
            })
            .collect();

        krid2wsdistnorm
    }

    /// Loads the kriso -> distribution id (krid) table of `street` and the number of
    /// distinct ids (`distsize`).
    ///
    /// Reads `<GAME>_nriso_<street+1>_to_winning_distribution_id` when `street == recall_from`,
    /// otherwise `<GAME>_priso_<street+1>_from_<recall_from+1>_to_winning_trace_distribution_id`.
    /// Keys `0..isosize` (with `isosize = hand_isomorphism_size_street(street, recall_from)`)
    /// map to krids. The extra key `isosize` stores `distsize`. All keys and values are
    /// big-endian `u32`s.
    ///
    /// Returns a vector of length `isosize` whose values are exactly the ids `0..distsize`.
    ///
    /// # Panics
    /// Panics if the database/column family cannot be opened or the data is inconsistent.
    fn load_wtdistid_with_kriso(street: usize, recall_from: usize) -> (Vec<usize>, usize){
        // Open the database read-only with the required column family.
        let path = std::format!("data/{}", T::GAME_NAME);
        let options = {
            let mut options = Options::default();
            options.create_if_missing(false);
            options
        };

        let kriso2wtdistid_cf = if street == recall_from {
            format!(
                "{}_nriso_{}_to_winning_distribution_id",
                T::GAME_NAME,
                street + 1
            )
        } else {
            format!(
                "{}_priso_{}_from_{}_to_winning_trace_distribution_id",
                T::GAME_NAME,
                street + 1,
                recall_from + 1
            )
        };

        let cf_names: Vec<&str> = vec!["default", kriso2wtdistid_cf.as_str()];
        let cf_descriptors: Vec<_> = cf_names
            .iter()
            .map(|&cf_name| {
                let mut cf_opt = Options::default();
                cf_opt.create_if_missing(false);
                ColumnFamilyDescriptor::new(cf_name, cf_opt)
            })
            .collect();
        let db = DBCommon::<SingleThreaded,_>::open_cf_descriptors_read_only(&options, &path, cf_descriptors, false)
            .unwrap_or_else(|e| panic!("打不开这个数据库{}的列族{}: {:?}", path, kriso2wtdistid_cf, e));
        let cf_handle = db
            .cf_handle(&kriso2wtdistid_cf)
            .unwrap_or_else(|| panic!("没有这个列族:{}", kriso2wtdistid_cf));

        // The entry stored under key `isosize` (one past the last kriso) is `distsize`.
        let isosize = T::instance().hand_isomorphism_size_street(street, recall_from);
        let isosize_key = u32::try_from(isosize).expect("isosize does not fit in a 4-byte key");
        let distsize_bytes = db
            .get_cf(&cf_handle, isosize_key.to_be_bytes())
            .unwrap_or_else(|e| panic!("读取isosize/key:{}在{}街的值失败: {:?}", isosize, street, e))
            .unwrap_or_else(|| panic!("isosize/key:{}在{}街的值为None", isosize, street));
        let distsize = Self::decode_be_u32(&distsize_bytes);

        // Read every kriso -> krid pair. `distsize + 1` marks "no entry seen yet".
        let mut kriso2wtdistid = vec![distsize + 1; isosize + 1];
        for item in db.iterator_cf(cf_handle, IteratorMode::Start) {
            let (keybytes, valuebytes) = item.unwrap();
            assert!(keybytes.len() == 4 && valuebytes.len() == 4);
            let iso = Self::decode_be_u32(&keybytes);
            assert!(iso <= isosize, "kriso key {} exceeds isosize {}", iso, isosize);
            kriso2wtdistid[iso] = Self::decode_be_u32(&valuebytes);
        }

        // Validate: the trailing entry is `distsize`, every kriso got a real id, and the
        // ids are exactly `0..distsize`.
        assert_eq!(kriso2wtdistid.pop().unwrap(), distsize);
        assert!(
            kriso2wtdistid.iter().all(|&distid| distid < distsize),
            "some kriso of street {} has no valid distribution id (distsize = {})",
            street,
            distsize
        );
        let unique: HashSet<_> = kriso2wtdistid.iter().copied().collect();
        assert_eq!(unique.len(), distsize);

        (kriso2wtdistid, distsize)
    }

    /// Writes the abstraction to `data/CustomCluster/<GAME>/KrwEmd/<custom_name>`.
    ///
    /// An existing directory at that path is removed first. A RocksDB database is created
    /// there. Each street's buckets are written with `save_cluster_bucket_street`, and the
    /// per-street configs with `save_cluster_configs_yaml`.
    ///
    /// # Panics
    /// Panics on any filesystem/database error or on an unsupported config variant.
    pub fn save(&self, custom_name: &str) {
        let path_str = std::format!("data/CustomCluster/{}/KrwEmd/{}", T::GAME_NAME, custom_name);
        let path = Path::new(path_str.as_str());
        if path.exists() {
            fs::remove_dir_all(path).unwrap();
        }
        fs::create_dir_all(path).unwrap();

        let mut opts = Options::default();
        opts.create_if_missing(true);
        let mut db = DB::open(&opts, path).unwrap();

        for (street, (kriso2bucket, &bucket_size)) in self
            .kriso2bucket_street
            .iter()
            .zip(self.bucket_size_street.iter())
            .enumerate()
        {
            let recall_from = Self::config_recall_from(&self.abstr_configs[street]);
            assert!(street >= recall_from, "street:{}, recall_from: {}", street, recall_from);
            save_cluster_bucket_street::<T>(&mut db, street, recall_from, kriso2bucket.as_ref(), bucket_size);
        }

        save_cluster_configs_yaml(path, self.abstr_configs.as_ref());
    }
}