use std::collections::HashSet;

use crate::faster::{DynamicCorest, TreeData};
use crate::faster::common::*;
use crate::CoresetError;
use faer::sparse::{
    csr_symbolic::generic::SymbolicSparseRowMat, SparseColMat, SparseRowMat, SymbolicSparseColMat};


use faer::traits::num_traits::real::Real;
use rand::{rngs::StdRng, Rng};
use rayon::prelude::*;

use itertools::izip;
use rustc_hash::{FxHashMap, FxHashSet};

// MARK: CoresetInfo struct
/// Parameters and snapshot values shared by the coreset sampling routines for one extraction.
///
/// All fields are plain `Copy` values, so the struct is cheap to copy and pass around.
#[derive(Clone, Copy)]
pub struct CoresetInfo {
    /// Number of nodes drawn (with replacement) when sampling the final coreset.
    pub coreset_size: usize,
    /// Number of seed points sampled (first point plus the seeding/repair loop) before the
    /// final coreset draw.
    pub sampling_seeds: usize,
    /// Diagonal shift added to the kernel; set from the owning structure's `shift`.
    pub sigma: Float,
    /// Location of the node `x*` as recorded in `node_location_map`.
    pub x_star_index: ShiftedIndex,
    /// Degree of the node `x*`.
    pub x_star_degree: NodeDegree,
    /// Precomputed `compute_x_star_term(sigma, x_star_degree)`.
    pub x_star_term: Float,
    /// Value of the node generation counter when this info was created.
    pub generation: usize,
}

impl<const ARITY: usize> DynamicCorest<ARITY>
where
    ConstPow2<ARITY>: PowerOfTwo,
{
    /// Builds the weighted, kernel-transformed adjacency matrix restricted to the coreset.
    ///
    /// Row/column `k` of the returned `n x n` CSR matrix (`n = coreset_indices.len()`)
    /// corresponds to `coreset_indices[k]`. Only adjacency edges whose both endpoints are in
    /// the coreset produce entries, and the column indices of each row are sorted ascending.
    /// The matrix computed is
    ///
    /// ```text
    /// A_C = W_C D_C^{-1} A_C D_C^{-1} W_C + shift * W_C D_C^{-1} W_C
    /// ```
    ///
    /// where `A_C` is the adjacency restricted to the coreset, `W_C` the diagonal matrix of
    /// `coreset_weights`, `D_C` the diagonal matrix of the coreset degrees and
    /// `shift = coreset_info.sigma`.
    ///
    /// NOTE(review): the `shift` term is only written for nodes that have a self-loop in
    /// `self.adjacency`; a node without one gets no diagonal entry — confirm self-loops are
    /// always present.
    ///
    /// # Errors
    /// Currently always returns `Ok`.
    ///
    /// # Panics
    /// Panics if a coreset node is missing from `self.degrees` or `self.adjacency`, or if
    /// `coreset_weights` is shorter than `coreset_indices`.
    pub fn build_coreset_graph(
        &self,
        coreset_indices: &[NodeIdentity],
        coreset_weights: &[Float],
        coreset_info: &CoresetInfo,
    ) -> Result<SparseRowMat<usize, Float>, DynamicCoresetError> {
        let n = coreset_indices.len();
        let shift = coreset_info.sigma;

        // Position of every coreset node inside the output matrix.
        let node_name_to_index = coreset_indices
            .iter()
            .enumerate()
            .map(|(idx, name)| (*name, idx))
            .collect::<FxHashMap<NodeIdentity, usize>>();

        // Diagonal of W_C D_C^{-1}: each coreset weight divided by that node's degree.
        let weight_over_degree = coreset_indices
            .iter()
            .zip(coreset_weights)
            .map(|(name, &weight)| {
                let degree = self.degrees.get(name).unwrap().1;
                weight / degree.0
            })
            .collect::<Vec<Float>>();

        // The number of coreset edges is unknown up front; reserve a generous guess.
        let mut data = Vec::<Float>::with_capacity(n * 200);
        let mut indices = Vec::<usize>::with_capacity(n * 200);
        let mut indptr = Vec::<usize>::with_capacity(n + 1);
        let mut nnz_per_row = Vec::<usize>::with_capacity(n);
        // Scratch buffer for one row, reused across rows to avoid reallocating.
        let mut row: Vec<(usize, Float)> = Vec::new();

        indptr.push(0);
        for (i, &node_name) in coreset_indices.iter().enumerate() {
            let neighbours = self.adjacency.get(&node_name).unwrap();
            let w_i = weight_over_degree[i];

            row.clear();
            row.extend(neighbours.iter().filter_map(|(&neighbour_name, &edge)| {
                // Neighbours outside the coreset do not appear in A_C.
                let &coreset_j = node_name_to_index.get(&neighbour_name)?;
                let value = if node_name == neighbour_name {
                    // Diagonal: transformed self-loop plus the shift term.
                    edge.0 * w_i * weight_over_degree[i]
                        + shift * (coreset_weights[i].0) * w_i
                } else {
                    edge.0 * w_i * weight_over_degree[coreset_j]
                };
                Some((coreset_j, value))
            }));

            // CSR requires ascending column indices within each row.
            row.sort_unstable_by_key(|&(col, _)| col);

            data.extend(row.iter().map(|&(_, value)| value));
            indices.extend(row.iter().map(|&(col, _)| col));
            nnz_per_row.push(row.len());
            // Row pointer = running number of stored entries.
            indptr.push(indices.len());
        }

        Ok(SparseRowMat::new(
            SymbolicSparseRowMat::<faer::sparse::csr_symbolic::Own<_, _, _>>::new_checked(
                n,
                n,
                indptr,
                Some(nnz_per_row),
                indices,
            ),
            data,
        ))
    }

    //MARK: Coreset Extraction
    /// Samples a coreset from the current graph and builds its weighted coreset graph.
    ///
    /// Steps:
    /// 1. Takes `x*` as the node at the top of `self.degrees` and samples it as the first seed.
    /// 2. Samples `sampling_seeds - 1` further seeds with `sample`, calling `repair` after each
    ///    one and accumulating the returned [`SamplingStats`].
    /// 3. Draws `coreset_size` nodes (with replacement) via `sample_smoothed`. Each draw gets
    ///    weight `degree / (prob * coreset_size)`.
    /// 4. Merges repeated draws of the same node by summing their weights, keeping first-seen
    ///    order.
    /// 5. Builds the coreset graph and feeds the stats to `update_filtered_average_dist_error`.
    ///
    /// A `sampling_seeds` of `0` behaves like `1`, because the first point is always sampled.
    ///
    /// Returns `(node names, weights, coreset graph, sampling stats)`; names and weights are
    /// aligned with the rows of the graph.
    ///
    /// # Errors
    /// Returns `CoresetError::NoData` if `self.degrees` is empty, if `x*` has no entry in
    /// `node_location_map`, or if a seed `sample` call fails.
    ///
    /// # Panics
    /// Panics if `sample_smoothed` fails or if a sampled node is missing from `self.degrees`
    /// or `self.string_map_reverse`.
    pub fn extract_coreset_graph(
        &mut self,
        coreset_size: usize,
        sampling_seeds: usize,
    ) -> Result<(Vec<String>, Vec<Float>, SparseRowMat<usize, Float>, SamplingStats), CoresetError>
    {
        let shift = self.shift;

        let (&x_star_name, x_star_degree) = self.degrees.peek().ok_or(CoresetError::NoData)?;
        let x_star_index = self
            .node_location_map
            .get(&x_star_name)
            .ok_or(CoresetError::NoData)?;
        let x_star_term = Self::compute_x_star_term(shift, *x_star_degree);

        let coreset_info = CoresetInfo {
            coreset_size,
            sampling_seeds,
            sigma: shift,
            x_star_index: *x_star_index,
            x_star_degree: *x_star_degree,
            x_star_term,
            generation: self.node_generation_counter,
        };

        let mut total_init_weight = *x_star_degree;

        // Seed with x*.
        let mut full_stats = self.sample_first_point(x_star_name, &coreset_info);

        let mut rng = rand::SeedableRng::from_os_rng();

        // `1..sampling_seeds` (not `0..sampling_seeds - 1`) so that `sampling_seeds == 0`
        // cannot underflow.
        for _ in 1..sampling_seeds {
            let (node_id, _) = self
                .sample(&coreset_info, Float::from(0.0).into(), Float::from(0.0).into(), &mut rng)
                .map_err(|_| CoresetError::NoData)?;

            total_init_weight += convert(*self.degrees.get(&node_id).unwrap().1);

            full_stats += self.repair(node_id, &coreset_info);
        }

        // Sample the coreset according to the smoothed contributions.
        let total_contribution = self.contribution(ShiftedIndex(0), &coreset_info);
        let total_weight = self.tree_data.volumes[0].clone();

        let (coreset_indices, weights): (Vec<NodeIdentity>, Vec<Float>) = (0..coreset_size)
            .map(|_| {
                let (node_id, prob) = self
                    .sample_smoothed(
                        &coreset_info,
                        total_init_weight.into(),
                        convert(total_weight),
                        total_contribution,
                        &mut rng,
                    )
                    .unwrap();
                let node_degree = self.degrees.get(&node_id).unwrap().1.into_float();
                let weight = node_degree / (prob * Float::from(coreset_size as Float_Dtype));
                (node_id, weight)
            })
            .unzip();

        // Merge repeated draws in O(n): a hash index replaces the former linear `position` scan.
        let mut position_of: FxHashMap<NodeIdentity, usize> = FxHashMap::default();
        let mut unique_indices: Vec<NodeIdentity> = Vec::new();
        let mut unique_weights: Vec<Float> = Vec::new();

        for (index, weight) in izip!(coreset_indices, weights) {
            if let Some(&pos) = position_of.get(&index) {
                unique_weights[pos] += weight;
            } else {
                position_of.insert(index, unique_indices.len());
                unique_indices.push(index);
                unique_weights.push(weight);
            }
        }

        let coreset_graph = self
            .build_coreset_graph(&unique_indices, &unique_weights, &coreset_info)
            .unwrap();

        // Update the pid based on the sampling stats.
        self.update_filtered_average_dist_error(&full_stats);

        Ok((
            unique_indices
                .into_iter()
                .map(|idx| self.string_map_reverse.get(&idx).unwrap().clone())
                .collect(),
            unique_weights,
            coreset_graph,
            full_stats,
        ))
    }

    // MARK: Labeling
    /// Assigns every node in `node_location_map` to one of `num_clusters` coreset clusters.
    ///
    /// The kernel used is `K_ij = A_ij / (d_i d_j) + shift * [i == j] / d_i`. Terms are only
    /// evaluated for edges present in `self.adjacency`.
    ///
    /// For each cluster `C` with weights `w`:
    /// - `denom_C = Σ w_i`;
    /// - `||c||² = (1 / denom_C²) Σ_{i,j ∈ C} w_i w_j K_ij`.
    ///
    /// For each node `x` the inner product with each centre is
    /// `⟨x, c⟩ = (1 / denom_C) Σ_{j ∈ C} w_j K_xj`, computed over the coreset neighbours of `x`.
    /// The chosen label minimises `||c||² - 2⟨x, c⟩`. If `x` has no coreset neighbour, the
    /// label is the default cluster: the non-empty cluster with the smallest norm.
    ///
    /// # Parameters
    /// - `coreset_names`: coreset node names (looked up in `string_map`).
    /// - `coreset_weights`, `coreset_labels`: aligned with `coreset_names`. Every label must
    ///   be `< num_clusters`.
    /// - `num_clusters`: number of clusters.
    ///
    /// Returns `(node ids, labels, values)`. `values[k]` is
    /// `||c||² - 2⟨x, c⟩ + 1/d_x² + shift/d_x` for the chosen cluster.
    /// NOTE(review): the `1/d_x²` term assumes every vertex has a unit-weight self-loop —
    /// confirm against the graph construction.
    ///
    /// # Panics
    /// Panics if:
    /// - a name is missing from `string_map`;
    /// - a label is `>= num_clusters`;
    /// - `num_clusters == 0`;
    /// - a norm is NaN;
    /// - a node is missing from `self.degrees` / `self.adjacency`.
    pub fn rust_label_full_graph(
        &self,
        coreset_names: &[String],
        coreset_weights: &[Float_Dtype],
        coreset_labels: &[usize],
        num_clusters: usize,
    ) -> (Vec<NodeIdentity>, Vec<usize>, Vec<Float_Dtype>) {
        let coreset_indices = coreset_names
            .iter()
            .map(|name| self.string_map.get(name.as_str()).unwrap().clone())
            .collect::<Vec<_>>();

        let coreset_size = coreset_indices.len();
        let shift = self.shift;
        debug_assert!(
            coreset_size == coreset_weights.len() && coreset_size == coreset_labels.len()
        );

        // Group the coreset members and their weights by cluster label.
        let coreset_grouped = coreset_indices.iter().zip(coreset_labels).zip(coreset_weights).fold(
            vec![(vec![], vec![]); num_clusters],
            |mut acc, ((&i, &label), &weight)| {
                acc[label].0.push(i);
                acc[label].1.push(weight);
                acc
            },
        );

        // Clusters with no coreset member have no centre and must never be chosen as default.
        let cluster_has_members: Vec<bool> = coreset_grouped
            .iter()
            .map(|(members, _)| !members.is_empty())
            .collect();

        // Per-cluster squared centre norm and weight denominator, computed in parallel.
        let (center_norms, center_denoms): (Vec<Float>, Vec<Float>) = coreset_grouped
            .into_par_iter()
            .map(|(indices, weights)| {
                // Empty cluster: placeholder (0, 0). Its denominator is never divided by,
                // because no coreset neighbour carries this label.
                if indices.is_empty() {
                    return (Float::from(0.0), From::from(0.0));
                }

                let indices_set: HashSet<&NodeIdentity> = indices.iter().collect();
                let index_to_weight: FxHashMap<_, _> =
                    indices.iter().zip(weights.iter()).collect();

                let denom: Float = weights.iter().sum::<Float_Dtype>().into();

                let mut center_norm_sum: Float = Float::from(0.0);
                indices.iter().for_each(|i| {
                    let weight = index_to_weight[i];
                    let degree_i = self.degrees.get(i).unwrap().1.into_float();
                    center_norm_sum += self.adjacency[i].iter().fold(
                        Float::from(0.0),
                        |acc, (j, v)| {
                            // Only pairs inside this cluster contribute.
                            if !indices_set.contains(&j) {
                                return acc;
                            }
                            let kernel = v.into_float()
                                / (degree_i * self.degrees.get(j).unwrap().1.into_float());
                            let value = if i != j { kernel } else { kernel + shift / degree_i };
                            acc + value * weight * index_to_weight[&j]
                        },
                    );
                });
                center_norm_sum /= denom * denom;
                (center_norm_sum, denom)
            })
            .collect::<Vec<(Float, Float)>>()
            .into_iter()
            .unzip();

        // Default cluster = smallest-norm non-empty cluster. Fall back to all clusters only if
        // every cluster is empty.
        let smallest_center_by_norm = center_norms
            .iter()
            .enumerate()
            .filter(|&(k, _)| cluster_has_members[k])
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .or_else(|| {
                center_norms
                    .iter()
                    .enumerate()
                    .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            })
            .unwrap()
            .0;
        let smallest_center_by_norm_value = center_norms[smallest_center_by_norm];

        // Lookup tables for labelling every node in parallel.
        let coreset_set = coreset_indices.iter().collect::<FxHashSet<_>>();
        let label_map = coreset_indices.iter().zip(coreset_labels).collect::<FxHashMap<_, _>>();
        let weight_map = coreset_indices.iter().zip(coreset_weights).collect::<FxHashMap<_, _>>();

        let node_names = self.node_location_map.keys().cloned().collect::<Vec<_>>();

        let (labels, distances): (Vec<usize>, Vec<Float_Dtype>) = node_names
            .par_iter()
            .map(|i| {
                let vertex_degree = self.degrees.get(i).unwrap().1.into_float();
                // Weighted inner product with every centre reachable through a coreset neighbour.
                let mut x_to_c_is = FxHashMap::default();

                self.adjacency[i].iter().for_each(|(indx, weight)| {
                    if coreset_set.contains(&indx) {
                        let label = label_map[&indx];
                        let neighbour_weight = weight_map[&indx];
                        // Same kernel as the centre-norm computation: a single division by
                        // d_x * d_j, plus the shift term only on the diagonal.
                        let kernel = weight.into_float()
                            / (vertex_degree * self.degrees.get(indx).unwrap().1.into_float());
                        let inner_prod_with_vertex = if i != indx {
                            kernel
                        } else {
                            kernel + shift / vertex_degree
                        };
                        x_to_c_is
                            .entry(label)
                            .and_modify(|e| {
                                *e += Float::from(*neighbour_weight) * inner_prod_with_vertex;
                            })
                            .or_insert(Float::from(*neighbour_weight) * inner_prod_with_vertex);
                    }
                });

                // Normalise each inner product by its centre's denominator.
                x_to_c_is.iter_mut().for_each(|(k, v)| *v /= center_denoms[**k]);

                let mut best_center = smallest_center_by_norm;
                let mut best_center_value = smallest_center_by_norm_value;

                x_to_c_is.iter().for_each(|(center, v)| {
                    // ||c||² - 2⟨x, c⟩: the ||x||² term is identical for every centre, so it is
                    // left out of the comparison.
                    let distance = center_norms[**center] - Float::from(2.0) * v;
                    if distance < best_center_value {
                        best_center = **center;
                        best_center_value = distance;
                    }
                });
                // Add ||x||² back. NOTE(review): assumes a unit-weight self-loop on x.
                (
                    best_center,
                    (best_center_value
                        + Float::from(1.0) / (vertex_degree * vertex_degree)
                        + shift / (vertex_degree))
                        .0,
                )
            })
            .unzip();

        (node_names, labels, distances)
    }
}