use crate::game::component::*;
use rocksdb::{ColumnFamilyDescriptor, DBCommon, IteratorMode, Options, SingleThreaded, DB};
use std::collections::HashSet;
use std::marker::PhantomData;
use num::traits::cast::ToPrimitive;
use rayon::iter::{IntoParallelRefIterator, IntoParallelIterator, IndexedParallelIterator, ParallelIterator};
use std::sync::{Arc, atomic::{AtomicUsize, Ordering}};
use super::*;

/// Per-street hand abstraction: maps every canonical hand of each street to a bucket id.
///
/// Streets configured as `AbstractAlgorithmStreet::Isomorphism` reuse an isomorphism
/// abstraction loaded from disk. Streets configured as `AbstractAlgorithmStreet::Ehs` are
/// clustered with weighted k-means++ on a scalar value derived from each hand's winning
/// distribution.
pub struct EhsCluster<T> {
    /// Indexed by street: the bucket id of every canonical hand index (`kriso`; presumably
    /// the isomorphic-hand index — confirm) on that street.
    kriso2bucket_street: Vec<Vec<usize>>,
    /// Indexed by street: the number of buckets used on that street.
    bucket_size_street: Vec<usize>,
    /// Indexed by street: the configuration the street's abstraction was built from.
    abstr_configs: Vec<AbstractAlgorithmStreet>,
    /// `T` only selects the game through the trait bounds on the impl; no `T` is stored.
    _marker: PhantomData<T>,
}

impl<T> EhsCluster<T>
    where T: Singleton + Hand + WaughTrait + ShowdownRanker + 'static,
{
    /// Builds the abstraction for every street from one configuration per street.
    ///
    /// `alg_configs` must contain exactly `T::NUM_STREET` entries shaped as
    /// `Isomorphism* Ehs+`. That means zero or more leading `Isomorphism` streets, then at
    /// least one `Ehs` street, and no other variants. `Isomorphism` streets are loaded with
    /// `load_isomorphism_abstr`. `Ehs` streets are clustered by [`Self::get_street_cluster`]
    /// into `centroid_size` buckets.
    ///
    /// # Panics
    /// Panics if `alg_configs` has the wrong length or shape, or if loading or clustering
    /// any street panics.
    pub fn new(alg_configs: &[AbstractAlgorithmStreet]) -> Self {
        assert_eq!(alg_configs.len(), T::NUM_STREET as usize);

        // True iff `alg_configs` is a (possibly empty) run of `Isomorphism` followed by a
        // non-empty run of `Ehs`, with no other variant anywhere.
        fn check_alg_configs(alg_configs: &[AbstractAlgorithmStreet]) -> bool {
            let n_isomorphism = alg_configs
                .iter()
                .take_while(|config| matches!(config, AbstractAlgorithmStreet::Isomorphism { .. }))
                .count();
            let ehs_tail = &alg_configs[n_isomorphism..];
            !ehs_tail.is_empty()
                && ehs_tail
                    .iter()
                    .all(|config| matches!(config, AbstractAlgorithmStreet::Ehs { .. }))
        }
        assert!(check_alg_configs(alg_configs));

        let mut kriso2bucket_street = Vec::with_capacity(alg_configs.len());
        let mut bucket_size_street = Vec::with_capacity(alg_configs.len());
        for (street, config) in alg_configs.iter().enumerate() {
            let (kriso_assign, bucket_size) = match *config {
                AbstractAlgorithmStreet::Isomorphism { recall_from } => {
                    load_isomorphism_abstr::<T>(street, recall_from)
                }
                AbstractAlgorithmStreet::Ehs { centroid_size, train_iteration } => (
                    Self::get_street_cluster(street, centroid_size, train_iteration),
                    centroid_size,
                ),
                // `check_alg_configs` has already rejected every other variant.
                _ => unreachable!("不能到这里"),
            };
            kriso2bucket_street.push(kriso_assign);
            bucket_size_street.push(bucket_size);
        }

        Self {
            kriso2bucket_street,
            bucket_size_street,
            abstr_configs: alg_configs.to_vec(),
            _marker: PhantomData,
        }
    }

    /// Clusters the canonical hands (nriso) of `street` into `centroids_size` buckets.
    ///
    /// The steps are:
    /// 1. Reduce each distinct winning distribution `[d0, d1, d2]` to the scalar
    ///    `(d1 + 2*d2) / (d0 + d1 + d2)`. The index meanings (presumably lose/tie/win
    ///    counts) are not established here.
    /// 2. Weight each distribution by the summed `hand_index_volumn` of the hands that
    ///    share it.
    /// 3. Cluster the distributions with weighted k-means++. `max_iter` is forwarded to
    ///    `kmeanspp_process`.
    ///
    /// Returns the bucket id of every nriso of `street`, in nriso order.
    ///
    /// # Panics
    /// Panics if the data cannot be loaded or fails validation, or if `centroids_size`
    /// exceeds the number of distinct `d1 + 2*d2` values.
    pub fn get_street_cluster(street: usize, centroids_size: usize, max_iter: usize) -> Vec<usize> {
        // 1. Load nriso -> distribution id, and distribution id -> winning distribution.
        let (nriso2distid, distsize) = Self::load_distid_with_nriso(street);
        let nrid2dist = Self::load_dist_with_id(street);
        assert_eq!(distsize, nrid2dist.len());

        // k-means++ cannot place more centroids than there are distinct point values.
        {
            let unique: HashSet<i64> = nrid2dist.iter().map(|dist| dist[1] + 2 * dist[2]).collect();
            assert!(
                centroids_size <= unique.len(),
                "centroids_size {} exceeds the {} distinct values on street {}",
                centroids_size,
                unique.len(),
                street
            );
        }

        // Scalar point per distribution: (d1 + 2*d2) / (d0 + d1 + d2).
        let nrid2equity: Vec<f64> = nrid2dist
            .into_par_iter()
            .map(|dist| {
                let total: i64 = dist.iter().sum();
                (dist[1] as f64 + 2.0 * (dist[2] as f64)) / total as f64
            })
            .collect();

        // 2. Weight of a distribution = summed `hand_index_volumn` of the nriso mapping to it.
        // rayon only borrows the counters, so no `Arc` is needed. `Relaxed` suffices because
        // `for_each` joins every worker before the counters are read back.
        let nrid2weight: Vec<AtomicUsize> =
            (0..nrid2equity.len()).map(|_| AtomicUsize::new(0)).collect();
        nriso2distid
            .par_iter()
            .enumerate()
            .for_each(|(nriso, &distid)| {
                // recall_from == street, the same key space `load_distid_with_nriso` uses.
                let volume = T::instance().hand_index_volumn(nriso, street, street);
                nrid2weight[distid].fetch_add(volume, Ordering::Relaxed);
            });
        let nrid2weight: Vec<usize> = nrid2weight.into_iter().map(AtomicUsize::into_inner).collect();

        // 3. Weighted mean of a set of points, and |point - centroid| for a batch of points.
        let average_closure = move |pnt_idx: &[usize], pnts: &[f64], w8s: &[usize]| -> f64 {
            let mut sum = 0.0;
            let mut weight_sum = 0usize;
            for &idx in pnt_idx {
                sum += pnts[idx] * w8s[idx] as f64;
                weight_sum += w8s[idx];
            }
            sum / weight_sum as f64
        };
        let distance_batch_closure = move |pntidcs: &[usize], pnts: &[f64], ctrd: &f64| -> Vec<f64> {
            pntidcs
                .par_iter()
                .map(|&pntidx| (pnts[pntidx] - ctrd).abs())
                .collect()
        };

        // 4. Cluster the distinct distributions with weighted k-means++.
        let kmeanspp = WeightedKmeansPP::<f64, f64> {
            points: nrid2equity,
            weights: nrid2weight,
            average_fn: Box::new(average_closure),
            distance_batch_fn: Box::new(distance_batch_closure),
        };
        let (dist_assign, _, _) = kmeanspp.kmeanspp_process(centroids_size, max_iter, false);

        // 5. Every nriso inherits the bucket of its distribution.
        let nriso_size = T::instance().hand_isomorphism_size_street(street, street);
        assert_eq!(nriso2distid.len(), nriso_size);
        nriso2distid
            .par_iter()
            .map(|&nrid| dist_assign[nrid])
            .collect()
    }

    /// Loads the distribution id of every canonical hand (nriso) of `street`.
    ///
    /// Reads column family `{GAME_NAME}_nriso_{street+1}_to_winning_distribution_id` of the
    /// RocksDB at `data/{GAME_NAME}`, opened read-only. Keys and values are big-endian
    /// `u32`s. Key `isosize` (one past the last hand index) stores the number of distinct
    /// distributions.
    ///
    /// Returns `(nriso2distid, distsize)`. `nriso2distid.len() == isosize`, and its values
    /// are exactly the ids `0..distsize`.
    ///
    /// # Panics
    /// Panics if the database cannot be opened or read, or if the data fails validation.
    fn load_distid_with_nriso(street: usize) -> (Vec<usize>, usize) {
        let path = std::format!("data/{}", T::GAME_NAME);
        let mut options = Options::default();
        options.create_if_missing(false);
        let nriso2distid_cf = format!(
            "{}_nriso_{}_to_winning_distribution_id",
            T::GAME_NAME,
            street + 1
        );
        let cf_descriptors: Vec<_> = ["default", nriso2distid_cf.as_str()]
            .iter()
            .map(|&cf_name| {
                let mut cf_opt = Options::default();
                cf_opt.create_if_missing(false);
                ColumnFamilyDescriptor::new(cf_name, cf_opt)
            })
            .collect();
        let db = DBCommon::<SingleThreaded, _>::open_cf_descriptors_read_only(
            &options,
            &path,
            cf_descriptors,
            false,
        )
        .unwrap_or_else(|e| panic!("打不开这个数据库{}的列族{}: {:?}", path, nriso2distid_cf, e));
        let cf_handle = db
            .cf_handle(&nriso2distid_cf)
            .unwrap_or_else(|| panic!("没有这个列族:{}", nriso2distid_cf));

        // Key `isosize` holds distsize. The read error and the missing key get separate messages.
        let isosize = T::instance().hand_isomorphism_size_street(street, street);
        let distsize_bytes = db
            .get_cf(&cf_handle, (isosize as u32).to_be_bytes())
            .unwrap_or_else(|e| {
                panic!("failed to read isosize/key:{} on street {}: {:?}", isosize, street, e)
            })
            .unwrap_or_else(|| panic!("isosize/key:{}在{}街的值为None", isosize, street));
        let distsize = u32::from_be_bytes((*distsize_bytes).try_into().unwrap()) as usize;

        // Slot `isosize` receives the stored distsize. Every other slot starts at an
        // out-of-range sentinel, so hands missing from the column family are caught below.
        let mut nriso2distid = vec![distsize + 1; isosize + 1];
        for item in db.iterator_cf(cf_handle, IteratorMode::Start) {
            let (keybytes, valuebytes) = item.unwrap();
            assert!(keybytes.len() == 4 && valuebytes.len() == 4);
            let iso = u32::from_be_bytes((*keybytes).try_into().unwrap()) as usize;
            let distid = u32::from_be_bytes((*valuebytes).try_into().unwrap()) as usize;
            assert!(
                iso <= isosize,
                "key {} out of range 0..={} on street {}",
                iso,
                isosize,
                street
            );
            nriso2distid[iso] = distid;
        }
        assert_eq!(nriso2distid.pop().unwrap(), distsize);
        // Without this check, one missing hand plus one unused id would pass the uniqueness
        // check and only fail later as an out-of-bounds index while weighting.
        if let Some(nriso) = nriso2distid.iter().position(|&distid| distid >= distsize) {
            panic!(
                "nriso {} on street {} has no valid distribution id (expected < {})",
                nriso, street, distsize
            );
        }
        let unique: HashSet<_> = nriso2distid.iter().copied().collect();
        assert_eq!(unique.len(), distsize);

        (nriso2distid, distsize)
    }

    /// Loads every winning distribution of `street`, indexed by distribution id.
    ///
    /// Reads column family `{GAME_NAME}_nrid_{street+1}_to_winning_distribution` of the
    /// RocksDB at `data/{GAME_NAME}`, opened read-only. Keys are big-endian `u32` ids, and
    /// values are bincode-encoded `[i64; 3]`.
    ///
    /// # Panics
    /// Panics if the database cannot be opened, a record is malformed, or the stored ids
    /// are not exactly `0, 1, 2, …` in iteration order.
    fn load_dist_with_id(street: usize) -> Vec<[i64; 3]> {
        let path = std::format!("data/{}", T::GAME_NAME);
        let mut options = Options::default();
        options.create_if_missing(false);
        let nrid2dist_cf = format!(
            "{}_nrid_{}_to_winning_distribution",
            T::GAME_NAME,
            street + 1
        );
        let cf_descriptors: Vec<_> = ["default", nrid2dist_cf.as_str()]
            .iter()
            .map(|&cf_name| {
                let mut cf_opt = Options::default();
                cf_opt.create_if_missing(false);
                ColumnFamilyDescriptor::new(cf_name, cf_opt)
            })
            .collect();
        let db = DBCommon::<SingleThreaded, _>::open_cf_descriptors_read_only(
            &options,
            &path,
            cf_descriptors,
            false,
        )
        .unwrap_or_else(|e| panic!("打不开这个数据库{}的列族{}: {:?}", path, nrid2dist_cf, e));
        let cf_handle = db
            .cf_handle(&nrid2dist_cf)
            .unwrap_or_else(|| panic!("没有这个列族:{}", nrid2dist_cf));

        // The default bytewise comparator orders big-endian u32 keys numerically, so the
        // ids must arrive as 0, 1, 2, … with no gaps. This makes the Vec index the id.
        let mut nrid2dist: Vec<[i64; 3]> = Vec::new();
        for item in db.iterator_cf(cf_handle, IteratorMode::Start) {
            let (keybytes, valuebytes) = item.unwrap();
            assert!(keybytes.len() == 4);
            let distid = u32::from_be_bytes((*keybytes).try_into().unwrap()) as usize;
            let dist: [i64; 3] = bincode::deserialize(&valuebytes).unwrap();
            nrid2dist.push(dist);
            assert_eq!(distid, nrid2dist.len() - 1);
        }

        nrid2dist
    }

    /// Writes the abstraction to `data/CustomCluster/{GAME_NAME}/Ehs/{custom_name}`.
    ///
    /// The target directory is deleted and recreated first. Each street's buckets are
    /// written with `save_cluster_bucket_street` into a fresh RocksDB, and the configs
    /// with `save_cluster_configs_yaml`.
    ///
    /// # Panics
    /// Panics if `custom_name` has no normal path component or contains `..`. Otherwise
    /// the wipe would hit the shared `Ehs` directory or escape it. Also panics on any I/O
    /// or database error.
    pub fn save(&self, custom_name: &str) {
        let name_components: Vec<_> = Path::new(custom_name).components().collect();
        let names_subdir = name_components
            .iter()
            .any(|c| matches!(c, std::path::Component::Normal(_)));
        let escapes = name_components
            .iter()
            .any(|c| matches!(c, std::path::Component::ParentDir));
        assert!(
            names_subdir && !escapes,
            "custom_name must name a sub-directory, got {:?}",
            custom_name
        );

        let path_str = std::format!("data/CustomCluster/{}/Ehs/{}", T::GAME_NAME, custom_name);
        let path = Path::new(path_str.as_str());
        // Start from an empty directory so nothing from a previous save survives.
        if path.exists() {
            fs::remove_dir_all(path)
                .unwrap_or_else(|e| panic!("failed to remove {}: {}", path.display(), e));
        }
        fs::create_dir_all(path)
            .unwrap_or_else(|e| panic!("failed to create {}: {}", path.display(), e));

        let mut opts = Options::default();
        opts.create_if_missing(true);
        let mut db = DB::open(&opts, path).unwrap();

        for (street, (kriso2bucket, bucket_size)) in self
            .kriso2bucket_street
            .iter()
            .zip(self.bucket_size_street.iter())
            .enumerate()
        {
            // Ehs streets are indexed on their own street. Isomorphism streets use the
            // recall_from they were loaded with.
            let recall_from = match self.abstr_configs[street] {
                AbstractAlgorithmStreet::Ehs { .. } => street,
                AbstractAlgorithmStreet::Isomorphism { recall_from } => recall_from,
                _ => unreachable!("不能到这里"),
            };
            assert!(street >= recall_from, "street: {}, recall_from: {}", street, recall_from);
            save_cluster_bucket_street::<T>(
                &mut db,
                street,
                recall_from,
                kriso2bucket.as_ref(),
                *bucket_size,
            );
        }
        save_cluster_configs_yaml(path, self.abstr_configs.as_ref());
    }
}

