use faer::sparse::SymbolicSparseColMatRef;
use faer::{unzipped, zipped, MatRef,mat::from_raw_parts};
// use faer::linalg::matmul::{inner_prod,matmul};
use ndarray::{Array1, ArrayView2, AssignElem};
use pyo3::prelude::*;
use numpy::ndarray::ArrayView1;
use numpy::{IntoPyArray, PyArray1, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::types::PyTuple;
use pyo3::{
    pymodule,
    types::PyModule,
    Bound, PyResult,
};

use faer::Mat;
use faer::sparse::SparseColMatRef;


use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::fs::File;
use std::panic;
// use std::error::Error;
use ndarray_npz::NpzReader;

mod non_incident_tree;
mod incident_tree;
mod common;
mod full;

mod fixed;

use fixed::common::Error as SamplingError;



/// Load a compressed sparse matrix and its companion arrays from an `.npz` archive.
///
/// The archive is expected to contain:
/// * `data.npy` (f64) – stored values,
/// * `indices.npy`, `indptr.npy`, `nnz_per_col.npy` (u64) – sparse index arrays,
/// * `W.npy` (f64) – per-point weights,
/// * `shape.npy` (u64, at least two entries) – `(nrows, ncols)`,
/// * `num_clusters.npy` (u64, at least one entry).
///
/// Returns `(data, indices, indptr, nnz_per_col, W, (nrows, ncols), num_clusters)`.
///
/// # Errors
/// Returns an error if the file cannot be opened or parsed, if any array is missing or
/// has an unexpected dtype, if `shape.npy` / `num_clusters.npy` are too short, or if a
/// u64 value does not fit in `usize` on the current target.
#[allow(non_snake_case)]
pub fn load_sparse_matrix(file_path: &str) -> Result<(Vec<f64>, Vec<usize>, Vec<usize>, Vec<usize>, Vec<f64>, (usize, usize), usize), Box<dyn Error>> {
    let file = File::open(file_path)?;
    let mut npz = NpzReader::new(file)?;

    // Every array is loaded with `?` so a missing or mistyped entry is reported to the
    // caller instead of aborting the process.
    let data: Array1<f64> = npz.by_name("data.npy")?;
    let indices: Array1<u64> = npz.by_name("indices.npy")?;
    let indptr: Array1<u64> = npz.by_name("indptr.npy")?;
    let nnz_per_col: Array1<u64> = npz.by_name("nnz_per_col.npy")?;
    let W: Array1<f64> = npz.by_name("W.npy")?;
    let shape: Array1<u64> = npz.by_name("shape.npy")?;
    let num_clusters: Array1<u64> = npz.by_name("num_clusters.npy")?;

    // Checked u64 -> usize conversion in a single pass (no intermediate array).
    let to_usize_vec = |a: &Array1<u64>| {
        a.iter()
            .map(|&x| usize::try_from(x))
            .collect::<Result<Vec<usize>, _>>()
    };
    let indices_vec = to_usize_vec(&indices)?;
    let indptr_vec = to_usize_vec(&indptr)?;
    let nnz_per_col_vec = to_usize_vec(&nnz_per_col)?;

    // Guard the fixed-position reads instead of panicking on short arrays.
    let nrows = usize::try_from(*shape.get(0).ok_or("shape.npy must contain at least two entries")?)?;
    let ncols = usize::try_from(*shape.get(1).ok_or("shape.npy must contain at least two entries")?)?;
    let num_clusters_val =
        usize::try_from(*num_clusters.get(0).ok_or("num_clusters.npy must not be empty")?)?;

    Ok((
        data.to_vec(),
        indices_vec,
        indptr_vec,
        nnz_per_col_vec,
        W.to_vec(),
        (nrows, ncols),
        num_clusters_val,
    ))
}

/// Wrap borrowed compressed-sparse-column arrays as a square `n x n` faer sparse view.
///
/// * `data` – stored values.
/// * `indices` – row index of each stored value.
/// * `indptr` – column pointers.
/// * `nnz_per_col_as_slice` – number of stored entries per column.
///
/// No data is copied; the returned view borrows all four slices.
///
/// # Panics
/// Panics (inside `SymbolicSparseColMatRef::new_checked`) if the index arrays do not
/// describe a valid `n x n` structure.
fn data_indices_indptr_to_sparse_mat_ref<'a, E>(
    n: usize,
    data: &'a [E],
    indices: &'a [usize],
    indptr: &'a [usize],
    nnz_per_col_as_slice: &'a [usize],
) -> SparseColMatRef<'a, usize, E>
where
    E: faer::Entity + faer::SimpleEntity,
{
    // The kernel matrix is square.
    let (nrows, ncols) = (n, n);
    let col_counts = Some(nnz_per_col_as_slice);

    let structure = SymbolicSparseColMatRef::new_checked(nrows, ncols, indptr, col_counts, indices);
    SparseColMatRef::new(structure, data)
}

/// Build a weighted coreset with the `full::SamplingTree` sampler.
///
/// * `clusters` / `coreset_size` – forwarded to `full::SamplingTree::initialize`.
/// * `n` – dimension of the square sparse kernel matrix described by
///   `data`/`indices`/`indptr`/`nnz_per_col`.
/// * `W` – per-point weights; only the first `n` entries are used.
///
/// Returns the sampled indices (the `.0` field of each sampler output) and their weights.
///
/// # Panics
/// Panics if any input view is not contiguous, if `W` has fewer than `n` entries, if the
/// sparse arrays are structurally invalid, or if the sampler returns an error.
#[allow(non_snake_case)]
pub fn coreset(
    clusters: usize,
    n: usize,
    coreset_size: usize,
    data: ArrayView1<'_, f64>,
    indices: ArrayView1<'_, usize>,
    indptr: ArrayView1<'_, usize>,
    nnz_per_col: ArrayView1<'_, usize>,
    W: ArrayView1<'_, f64>) -> (Array1<usize>, Array1<f64>) {

    let data_as_slice = data.as_slice().unwrap();
    let indices_as_slice = indices.as_slice().unwrap();
    let indptr_as_slice = indptr.as_slice().unwrap();
    let nnz_per_col_as_slice = nnz_per_col.as_slice().unwrap();

    let K: SparseColMatRef<'_, usize, f64> = data_indices_indptr_to_sparse_mat_ref(n, data_as_slice, indices_as_slice, indptr_as_slice, nnz_per_col_as_slice);
    let W_as_slice: &[f64] = W.as_slice().unwrap();
    // Without this check a short `W` coming from Python makes the raw view below read
    // past the end of the buffer.
    assert!(
        W_as_slice.len() >= n,
        "coreset: W has {} entries but n = {}",
        W_as_slice.len(),
        n
    );
    // SAFETY: `W_as_slice` is a contiguous slice of at least `n` f64 values (asserted above)
    // that outlives every use of `W_faer` in this function, so an `n x 1` view with unit
    // row stride stays in bounds.
    let W_faer = unsafe { from_raw_parts::<f64>(W_as_slice.as_ptr(), n, 1, 1, 1) };

    let mut coreset_sampler = full::SamplingTree::initialize(clusters, coreset_size, K, W_faer.as_ref());
    let (coreset, coreset_weights) = coreset_sampler
        .construct_coreset()
        .expect("coreset: full::SamplingTree::construct_coreset failed");

    let coreset_array = Array1::<usize>::from_iter(coreset.into_iter().map(|x| x.0));
    let coreset_weights_array = Array1::from_vec(coreset_weights);

    (coreset_array, coreset_weights_array)
}

#[allow(non_snake_case)]
pub fn improved_coreset(
    clusters: usize,
    n: usize, 
    coreset_size: usize,
    data: ArrayView1<'_, f64>, 
    indices: ArrayView1<'_, usize>, 
    indptr: ArrayView1<'_, usize>,
    nnz_per_col: ArrayView1<'_, usize>,
    W: ArrayView1<'_, f64>,
    K_W_in_graph_form: Option<bool>) -> Result<(Array1<usize>, Array1<f64>),SamplingError> {

    
    let data_as_slice = data.as_slice().unwrap();
    let indices_as_slice = indices.as_slice().unwrap();
    let indptr_as_slice = indptr.as_slice().unwrap();
    let nnz_per_col_as_slice = nnz_per_col.as_slice().unwrap();

    let K = data_indices_indptr_to_sparse_mat_ref(n, data_as_slice, indices_as_slice, indptr_as_slice, nnz_per_col_as_slice);
    let W_as_slice = W.as_slice().unwrap();
    // let W_faer = Mat::from_fn(n, 1, |i,_| W_as_slice[i]);
    let W_faer = unsafe{from_raw_parts(W_as_slice.as_ptr(),n, 1,1,1)};



    let mut coreset_sampler = fixed::full::FullSamplingTree::initialize(clusters, coreset_size, K, W_faer.as_ref(), K_W_in_graph_form.unwrap_or(false));
    

    let result = coreset_sampler.construct_coreset();
    match result{
        Ok((coreset, coreset_weights)) => {
            let coreset_array = Array1::<usize>::from_iter(
                coreset.into_iter().map(|x| x.0)
            );
            let coreset_weights_array = Array1::from_vec(coreset_weights);
            Ok((coreset_array, coreset_weights_array))
        },
        Err(e) => Err(e)
    }
}


/// Expand column `col` of the sparse matrix `K` into a dense `K.nrows() x 1` matrix.
///
/// Rows with no stored entry in that column are `0.0`.
#[allow(non_snake_case)]
pub fn dense_column(K: SparseColMatRef<usize,f64>, col: usize) -> Mat<f64>{
    let mut out = Mat::<f64>::zeros(K.nrows(), 1);

    let rows = K.row_indices_of_col(col);
    let vals = K.values_of_col(col);
    for (row, &value) in rows.zip(vals.iter()) {
        out[(row, 0)] = value;
    }

    out
}

/// Running (inclusive) prefix sum of column 0 of `probs`, returned as a `nrows x 1` matrix.
///
/// Entry `i` of the result is `probs[0] + … + probs[i]`. An empty input yields an empty
/// result instead of panicking. Only column 0 of `probs` is read.
pub fn cum_dist(probs: MatRef<f64>) -> Mat<f64>{
    let nrows = probs.nrows();
    let mut cdf = Mat::zeros(nrows, 1);

    let mut running = 0.0;
    for i in 0..nrows {
        running += probs[(i, 0)];
        cdf[(i, 0)] = running;
    }
    cdf
}

/// Draw one row index from the discrete distribution in column 0 of `probs`.
///
/// Uses inverse-CDF sampling. With `cdf = cum_dist(probs)`, index `i` is returned when
/// `cdf[i-1] <= rand_val < cdf[i]` (taking `cdf[-1] = 0`). So when `probs` sums to 1 and
/// `rand_val` is uniform on `[0, 1)`, index `i` is chosen with probability `probs[i]`, and
/// zero-probability rows are never chosen. A `rand_val` at or beyond the final cumulative
/// value (possible through floating-point rounding) maps to the last index.
///
/// # Panics
/// Panics if `probs` has no rows.
pub fn sample_from_dist(probs: MatRef<f64>, rand_val: f64) -> usize{
    let nrows = probs.nrows();
    assert!(nrows > 0, "sample_from_dist: cannot sample from an empty distribution");

    let cdf = cum_dist(probs);
    // First index whose cumulative value exceeds `rand_val`. The comparison is NaN-safe:
    // it never panics, unlike `partial_cmp().unwrap()`.
    let idx = cdf.col_as_slice(0).partition_point(|&c| c <= rand_val);
    idx.min(nrows - 1)
}


/// Draw `samples` independent row indices (with replacement) from the discrete
/// distribution in column 0 of `probs`, using `rand::random::<f64>()` for each draw.
///
/// Each draw uses the same inverse-CDF rule as `sample_from_dist`: index `i` is returned
/// when `cdf[i-1] <= r < cdf[i]`, and out-of-range `r` (from rounding) maps to the last
/// index. The cumulative distribution is computed once and shared by all draws.
///
/// # Panics
/// Panics if `samples > 0` and `probs` has no rows.
pub fn sample_from_dist_with_replacement(probs: MatRef<f64>, samples: usize) -> Vec<usize>{
    if samples == 0 {
        return Vec::new();
    }
    let nrows = probs.nrows();
    assert!(
        nrows > 0,
        "sample_from_dist_with_replacement: cannot sample from an empty distribution"
    );

    let cdf_mat = cum_dist(probs);
    let cdf = cdf_mat.col_as_slice(0);
    (0..samples)
        .map(|_| {
            let rand_val = rand::random::<f64>();
            cdf.partition_point(|&c| c <= rand_val).min(nrows - 1)
        })
        .collect()
}

/// Kernel-space D²-style seeding over a sparse kernel matrix `K` (`n x n`).
///
/// The first index is chosen uniformly at random. Each subsequent index is drawn with
/// probability proportional to `dist_squared[i] * W[i]`, where
/// `dist_squared[i] = min over chosen c of (diag[i] + diag[c] - 2 * K[i, c])`.
/// `diag` is `K[(i, i)]` when `graph_form` is false, and `1 / W[i]^2` when it is true.
///
/// Returns the `k` chosen indices and the final `n x 1` `dist_squared` vector.
///
/// # Panics
/// Panics if `K` has no rows or `k == 0`. When `graph_form` is false, it also panics if a
/// diagonal entry is not stored in `K`.
#[allow(non_snake_case)]
pub fn d_z_sampling(K: SparseColMatRef<usize,f64>, W: MatRef<f64>, k: usize, graph_form: bool) -> (Vec<usize>, Mat<f64>){
    let n = K.nrows();
    assert!(n > 0, "d_z_sampling: kernel matrix has no rows");
    assert!(k > 0, "d_z_sampling: k must be at least 1");

    let mut coreset_indices: Vec<usize> = vec![0; k];
    let K_diagonal = match graph_form {
        false => Mat::from_fn(n, 1, |row, _| K[(row, row)]),
        true => zipped!(&W).map(|unzipped!(w)| 1.0 / (*w * *w)),
    };

    let rand_index = rand::random::<usize>() % n;
    coreset_indices[0] = rand_index;
    // Dense column K[:, center] for the most recently chosen center. The previous code
    // used `col_mut(0).assign_elem(...)`, which (via ndarray's `impl AssignElem for &mut T`)
    // only overwrote a temporary view, so this vector silently stayed all-zero.
    let mut K_to_current_coreset_dense: Mat<f64> = dense_column(K, rand_index);

    let diag_first = K_diagonal[(rand_index, 0)];
    let mut dist_squared_vector = Mat::from_fn(n, 1, |row, _| {
        K_diagonal[(row, 0)] + diag_first - 2.0 * K_to_current_coreset_dense[(row, 0)]
    });

    for slot in 1..k {
        let mut probs = Mat::from_fn(n, 1, |row, _| dist_squared_vector[(row, 0)] * W[(row, 0)]);
        let cost = probs.sum();
        probs = zipped!(&mut probs).map(|unzipped!(p)| *p / cost);

        // Sample the next index from the distribution.
        let rand_val = rand::random::<f64>();
        let next_index = sample_from_dist(probs.as_ref(), rand_val);
        coreset_indices[slot] = next_index;
        K_to_current_coreset_dense = dense_column(K, next_index);

        // Keep, for every point, the smaller of its old distance and its distance to the new center.
        let diag_next = K_diagonal[(next_index, 0)];
        zipped!(dist_squared_vector.col_mut(0), &K_to_current_coreset_dense.col(0), &K_diagonal.col(0))
            .for_each(|unzipped!(mut d, k_to_c, k_diag)| d.write(d.min(diag_next + *k_diag - 2.0 * *k_to_c)));
    }
    (coreset_indices, dist_squared_vector)
}

/// Build a weighted coreset by D² seeding followed by importance sampling with replacement.
///
/// 1. `d_z_sampling` picks `clusters` seed indices and returns each point's squared kernel
///    distance to its nearest seed.
/// 2. `sigma[i] ∝ (dist[i] * W[i]) / Σ(dist * W) + W[i] / Σ_{seeds} W`, normalised to sum 1.
/// 3. `coreset_size` indices are drawn from `sigma` with replacement. Each index gets
///    weight `W[i] / (coreset_size * sigma[i])`.
///
/// `graph_from` is forwarded to `d_z_sampling` as its `graph_form` flag.
///
/// # Panics
/// Panics if any input view is not contiguous, if `W` has fewer than `n` entries, or if the
/// sparse arrays are structurally invalid.
#[allow(non_snake_case)]
pub fn old_coreset(
    clusters: usize,
    n: usize,
    coreset_size: usize,
    data: ArrayView1<'_, f64>,
    indices: ArrayView1<'_, usize>,
    indptr: ArrayView1<'_, usize>,
    nnz_per_col: ArrayView1<'_, usize>,
    W: ArrayView1<'_, f64>,
    graph_from: bool) -> (Array1<usize>, Array1<f64>) {

    let data_as_slice = data.as_slice().unwrap();
    let indices_as_slice = indices.as_slice().unwrap();
    let indptr_as_slice = indptr.as_slice().unwrap();
    let nnz_per_col_as_slice = nnz_per_col.as_slice().unwrap();

    let K = data_indices_indptr_to_sparse_mat_ref(n, data_as_slice, indices_as_slice, indptr_as_slice, nnz_per_col_as_slice);
    let W_as_slice = W.as_slice().unwrap();
    // Without this check a short `W` coming from Python makes the raw view below read
    // past the end of the buffer.
    assert!(
        W_as_slice.len() >= n,
        "old_coreset: W has {} entries but n = {}",
        W_as_slice.len(),
        n
    );
    // SAFETY: `W_as_slice` is a contiguous slice of at least `n` f64 values (asserted above)
    // that outlives every use of `W_faer` in this function, so an `n x 1` view with unit
    // row stride stays in bounds.
    let W_faer: MatRef<f64> = unsafe { from_raw_parts(W_as_slice.as_ptr(), n, 1, 1, 1) };

    let (coreset_init, dist_squared_vector) = d_z_sampling(K, W_faer, clusters, graph_from);

    let mut probs = Mat::from_fn(n, 1, |i, _| dist_squared_vector[(i, 0)] * W_faer[(i, 0)]);
    let cost = probs.sum();
    probs = zipped!(&mut probs).map(|unzipped!(p)| *p / cost);
    let weight_c: f64 = coreset_init.iter().map(|i| W_faer[(*i, 0)]).sum();

    let mut sigmas = zipped!(&probs, &W_faer).map(|unzipped!(p, w)| *p + *w / weight_c);
    let sigma_sum = sigmas.sum();
    sigmas = zipped!(&mut sigmas).map(|unzipped!(s)| *s / sigma_sum);

    let coreset = sample_from_dist_with_replacement(sigmas.as_ref(), coreset_size);
    let sample_count = coreset_size as f64;
    let coreset_weights: Vec<f64> = coreset
        .iter()
        .map(|&index| W_faer[(index, 0)] / (sample_count * sigmas[(index, 0)]))
        .collect();

    (Array1::from_vec(coreset), Array1::from_vec(coreset_weights))
}

/// Gaussian (RBF) kernel value `exp(-gamma * ||x - y||^2)` for two vectors.
fn gaussian_kernel_single(x: ArrayView1<'_, f64>, y: ArrayView1<'_, f64>, gamma: f64) -> f64 {
    let delta = &x - &y;
    let squared_distance = delta.dot(&delta);
    f64::exp(-gamma * squared_distance)
}

fn gaussian_kernel_n_to_1(xs: ArrayView2<'_, f64>, y: ArrayView1<'_, f64>, gamma: f64) -> Array1<f64> {
    xs.outer_iter().map(|x| gaussian_kernel_single(x, y, gamma)).collect()
}

#[allow(non_snake_case)]
pub fn lazy_d_z_sampling(
    data: ArrayView2<'_, f64>, // 2D array (n by d)
    W: &[f64], 
    k: usize, 
    gamma: f64
) -> (Vec<usize>, Vec<f64>) {
    let n = data.nrows();
    let d = data.ncols();
    
    let mut coreset_indices: Vec<usize> = vec![0; k];
    let mut K_to_current_coreset_dense: Vec<f64> = vec![0.0; n];

    // Kernel is unnormalized, so we need to divide by the weights K(x,y) -> K(x,y)/(W_x * W_y)

    let K_diagonal: Vec<f64> = (0..n).map(|i| gaussian_kernel_single(data.row(i), data.row(i), gamma)/(W[i]*W[i])).collect();

    // Initialize with a random index
    let rand_index: usize = rand::random::<usize>() % n;
    coreset_indices[0] = rand_index;

    // Initialize the first column of K_to_current_coreset_dense
    K_to_current_coreset_dense = gaussian_kernel_n_to_1(data, data.row(rand_index), gamma).to_vec();
    // Kernel is unnormalized, so we need to divide by the weights K(x,y) -> K(x,y)/(W_x * W_y)
    K_to_current_coreset_dense = K_to_current_coreset_dense.iter().zip(W).map(|(k, w)| k / (w*W[rand_index])).collect();


    // Initialize the distance squared vector
    let mut dist_squared_vector: Vec<f64> = (0..n)
        .map(|i| {
            K_diagonal[i] + K_diagonal[rand_index] - 2.0 * K_to_current_coreset_dense[i]
        })
        .collect();

    // Iterate to build the coreset
    for i in 1..k {
        // Compute probability distribution based on distances and weights
        let mut probs: Vec<f64> = (0..n)
            .map(|i| dist_squared_vector[i] * W[i])
            .collect();
        let cost: f64 = probs.iter().sum();
        probs.iter_mut().for_each(|p| *p /= cost);

        // Sample the next index from the distribution
        let rand_val = rand::random::<f64>();
        let next_index = sample_from_dist(unsafe{from_raw_parts(probs.as_ptr(), probs.len(), 1, 1, 1)}, rand_val);
        coreset_indices[i] = next_index;

        // Update the K_to_current_coreset_dense for the new coreset point
        K_to_current_coreset_dense = gaussian_kernel_n_to_1(data, data.row(next_index), gamma).to_vec();
        K_to_current_coreset_dense = K_to_current_coreset_dense.iter().zip(W).map(|(k, w)| k / (w*W[next_index])).collect();

        // Update the distance squared vector in place
        for i in 0..n {
            dist_squared_vector[i] = dist_squared_vector[i].min(
                K_diagonal[next_index] + K_diagonal[i] - 2.0 * K_to_current_coreset_dense[i],
            );
        }
    }

    (coreset_indices, dist_squared_vector)
}


/// Same coreset construction as `old_coreset`, but seeded with `lazy_d_z_sampling`.
/// That function computes Gaussian-kernel columns from the raw `data` rows on demand.
///
/// * `clusters` – number of D² seeds.
/// * `n` – number of points used from `data` / `W`.
/// * `coreset_size` – number of samples drawn with replacement.
/// * `gamma` – Gaussian kernel parameter.
///
/// Each sampled index `i` gets weight `W[i] / (coreset_size * sigma[i])`.
///
/// # Panics
/// Panics if `W` is not contiguous, if `W` has fewer than `n` entries, or if `data` has
/// fewer than `n` rows.
#[allow(non_snake_case)]
pub fn lazy_old_coreset(
    clusters: usize,
    n: usize,
    coreset_size: usize,
    data: ArrayView2<'_, f64>, // 2D Array for data (n by d)
    W: ArrayView1<'_, f64>, // Weight vector
    gamma: f64) -> (Array1<usize>, Array1<f64>) {

    let W_as_slice = W.as_slice().unwrap();
    // Without this check a short `W` coming from Python makes the raw view below read
    // past the end of the buffer.
    assert!(
        W_as_slice.len() >= n,
        "lazy_old_coreset: W has {} entries but n = {}",
        W_as_slice.len(),
        n
    );
    assert!(
        data.nrows() >= n,
        "lazy_old_coreset: data has {} rows but n = {}",
        data.nrows(),
        n
    );
    // SAFETY: `W_as_slice` is a contiguous slice of at least `n` f64 values (asserted above)
    // that outlives every use of `W_faer` in this function, so an `n x 1` view with unit
    // row stride stays in bounds.
    let W_faer: MatRef<f64> = unsafe { from_raw_parts(W_as_slice.as_ptr(), n, 1, 1, 1) };

    let (coreset_init, dist_squared_vector) = lazy_d_z_sampling(data, W_as_slice, clusters, gamma);

    let mut probs = Mat::from_fn(n, 1, |i, _| dist_squared_vector[i] * W_faer[(i, 0)]);
    let cost = probs.sum();
    probs = zipped!(&mut probs).map(|unzipped!(p)| *p / cost);
    let weight_c: f64 = coreset_init.iter().map(|i| W_faer[(*i, 0)]).sum();

    let mut sigmas = zipped!(&probs, &W_faer).map(|unzipped!(p, w)| *p + *w / weight_c);
    let sigma_sum = sigmas.sum();
    sigmas = zipped!(&mut sigmas).map(|unzipped!(s)| *s / sigma_sum);

    let coreset = sample_from_dist_with_replacement(sigmas.as_ref(), coreset_size);
    let sample_count = coreset_size as f64;
    let coreset_weights: Vec<f64> = coreset
        .iter()
        .map(|&index| W_faer[(index, 0)] / (sample_count * sigmas[(index, 0)]))
        .collect();

    (Array1::from_vec(coreset), Array1::from_vec(coreset_weights))
}







/// A Python module implemented in Rust.
#[allow(non_snake_case)]
#[pymodule]
fn fast_kernel_coreset_sampling(m: &Bound<'_, PyModule>) -> PyResult<()> {


    #[pyfn(m)]
    #[pyo3(name = "old_coreset")]
    fn old_coreset_py<'py>(
        py: Python<'py>,
        clusters: usize,
        n: usize,
        coreset_size: usize,
        data: PyReadonlyArray1<'py, f64>,
        indices: PyReadonlyArray1<'py, usize>,
        indptr: PyReadonlyArray1<'py, usize>,
        nnz_per_col: PyReadonlyArray1<'py, usize>,
        W: PyReadonlyArray1<'py, f64>,
        graph_from: Option<bool>
    ) -> pyo3::Bound<'py, PyTuple>{
        let (coreset, coreset_weights) = old_coreset(
            clusters, n, coreset_size, data.as_array(), indices.as_array(), indptr.as_array(), nnz_per_col.as_array(), W.as_array(),
            graph_from.unwrap_or(false));
        let coreset_py = coreset.into_pyarray_bound(py);
        let coreset_weights_py = coreset_weights.into_pyarray_bound(py);
        let tuple = PyTuple::new_bound(py, &[coreset_py.to_object(py), coreset_weights_py.to_object(py)]);
        tuple
    }




    #[pyfn(m)]
    #[pyo3(name = "coreset")]
    fn coreset_py<'py>(
        py: Python<'py>,
        clusters: usize,
        n: usize,
        coreset_size: usize,
        data: PyReadonlyArray1<'py, f64>,
        indices: PyReadonlyArray1<'py, usize>,
        indptr: PyReadonlyArray1<'py, usize>,
        nnz_per_col: PyReadonlyArray1<'py, usize>,
        W: PyReadonlyArray1<'py, f64>
    ) -> pyo3::Bound<'py, PyTuple>{
        let (coreset, coreset_weights) = coreset(clusters, n, coreset_size, data.as_array(), indices.as_array(), indptr.as_array(), nnz_per_col.as_array(), W.as_array());
        let coreset_py = coreset.into_pyarray_bound(py);
        let coreset_weights_py = coreset_weights.into_pyarray_bound(py);
        let tuple = PyTuple::new_bound(py, &[coreset_py.to_object(py), coreset_weights_py.to_object(py)]);
        tuple
    }

    #[pyfn(m)]
    #[pyo3(name = "improved_coreset")]
    fn improved_coreset_py<'py>(
        py: Python<'py>,
        clusters: usize,
        n: usize,
        coreset_size: usize,
        data: PyReadonlyArray1<'py, f64>,
        indices: PyReadonlyArray1<'py, usize>,
        indptr: PyReadonlyArray1<'py, usize>,
        nnz_per_col: PyReadonlyArray1<'py, usize>,
        W: PyReadonlyArray1<'py, f64>,
        K_W_in_graph_form: Option<bool>
    ) -> pyo3::Bound<'py, PyTuple>{

        let result = improved_coreset(clusters, n, coreset_size, data.as_array(), indices.as_array(), indptr.as_array(), nnz_per_col.as_array(), W.as_array(), K_W_in_graph_form);

        match result{
            Ok((coreset, coreset_weights)) => {
                let coreset_py = coreset.into_pyarray_bound(py);
                let coreset_weights_py = coreset_weights.into_pyarray_bound(py);
                let tuple = PyTuple::new_bound(py, &[coreset_py.to_object(py), coreset_weights_py.to_object(py)]);
                tuple
            },
            Err(e) => {
                let tuple = PyTuple::new_bound(py, &["Failure".to_string().to_object(py), e.to_string().to_object(py)]);
                tuple
            }
        }

    }


    #[pyfn(m)]
    #[pyo3(name = "lazy_old_coreset")]
    fn lazy_old_coreset_py<'py>(
        py: Python<'py>,
        clusters: usize,
        n: usize,
        coreset_size: usize,
        data: PyReadonlyArray2<'py, f64>,
        W: PyReadonlyArray1<'py, f64>,
        gamma: f64
    ) -> pyo3::Bound<'py, PyTuple>{
        let (coreset, coreset_weights) = lazy_old_coreset(
            clusters, n, coreset_size, data.as_array(), W.as_array(), gamma);
        let coreset_py = coreset.into_pyarray_bound(py);
        let coreset_weights_py = coreset_weights.into_pyarray_bound(py);
        let tuple = PyTuple::new_bound(py, &[coreset_py.to_object(py), coreset_weights_py.to_object(py)]);
        tuple
    }


    #[pyfn(m)]
    #[pyo3(name = "fast_distances_to_centers")]
    fn fast_distances_to_centers_py<'py>(
        py: Python<'py>,
        data: PyReadonlyArray1<'py, f64>,
        indices: PyReadonlyArray1<'py, usize>,
        indptr: PyReadonlyArray1<'py, usize>,
        nnz_per_col: PyReadonlyArray1<'py, usize>,
        coreset: PyReadonlyArray1<usize>,
        coreset_labels: PyReadonlyArray1<usize>,
        coreset_weights: PyReadonlyArray1<f64>,
        k: usize,
        n: usize
    ) -> pyo3::Bound<'py, PyArray1<usize>>{

        // First we construct the sparse matrix
        let data_as_slice = data.as_slice().unwrap();
        let indices_as_slice = indices.as_slice().unwrap();
        let indptr_as_slice = indptr.as_slice().unwrap();
        let nnz_per_col_as_slice = nnz_per_col.as_slice().unwrap();
    
        let K = data_indices_indptr_to_sparse_mat_ref(n, data_as_slice, indices_as_slice, indptr_as_slice, nnz_per_col_as_slice);

        let coreset_as_slice = coreset.as_slice().unwrap();
        let coreset_labels_as_slice = coreset_labels.as_slice().unwrap();
        let coreset_weights_as_slice = coreset_weights.as_slice().unwrap();


        // group the coreset and coreset weights by label and process each group

        let coreset_grouped = coreset_as_slice.iter().zip(coreset_labels_as_slice).zip(coreset_weights_as_slice).fold(
            vec![(vec![],vec![]);k], |mut acc, ((&i, &label), &weight)|{
                acc[label].0.push(i);
                acc[label].1.push(weight);
                acc
            }
        );


        // Now we compute the center norms and center denoms

        let result = coreset_grouped.iter().enumerate().map(|(label, (indices, weights))|{
            // return zero if the cluster is empty
            if indices.is_empty(){
                return (0.0,0.0);
            }
            let indices_set: HashSet<_> = indices.iter().collect();
            let index_to_weight: HashMap<_,_> = indices.iter().zip(weights.iter()).collect();

            // compute the denominator:
            let denom = weights.iter().sum::<f64>();

            let mut center_norm_sum = 0.0;
            
            for i in indices.iter(){
                let weight = index_to_weight[&i];
                let neighbour_indices = K.row_indices_of_col(*i);
                let neighbour_values = K.values_of_col(*i);
                center_norm_sum += neighbour_indices.zip(neighbour_values).fold(0.0f64, |acc, (j, &value)|{
                    match indices_set.contains(&j){
                        true => acc + value*weight*index_to_weight[&j],
                        false => acc
                    }
                });
            }
            center_norm_sum /= denom*denom;
            (center_norm_sum, denom)
        }).collect::<Vec<(f64,f64)>>();

        let (center_norms,denoms):(Vec<_>, Vec<_>) = result.into_iter().unzip();

        let smallest_center_by_norm = center_norms.iter().enumerate().min_by(|(_,a),(_,b)| a.partial_cmp(b).unwrap()).unwrap().0;
        let smallest_center_by_norm_value = center_norms[smallest_center_by_norm];


        let coreset_set = coreset_as_slice.iter().collect::<HashSet<_>>();
        let label_map = coreset_as_slice.iter().zip(coreset_labels_as_slice).collect::<HashMap<_,_>>();
        let weight_map = coreset_as_slice.iter().zip(coreset_weights_as_slice).collect::<HashMap<_,_>>();

        let labels: Vec<usize> = (0..n).map(|i|{
            let mut x_to_c_is = HashMap::new();

            let neighbour_indices = K.row_indices_of_col(i);
            let neighbour_values = K.values_of_col(i);

            neighbour_indices.zip(neighbour_values).for_each(|(indx,value)|{
                match coreset_set.contains(&indx){
                    true =>{
                        let label = label_map[&indx];
                        let weight = weight_map[&indx];
                        x_to_c_is.entry(label).and_modify(|e| *e += value*weight).or_insert(value*weight);
                    },
                    false => ()
                    }
                });
            x_to_c_is.iter_mut().for_each(|(k,v)| *v /= denoms[**k]);

            let mut best_center = smallest_center_by_norm;
            let mut best_center_value = smallest_center_by_norm_value;

            for (center, inner_p) in x_to_c_is.iter(){
                let value = center_norms[**center] - 2.0*inner_p;
                if value < best_center_value{
                    best_center = **center;
                    best_center_value = value;
                }
            }

            best_center
        }).collect();
        

        // // return dummy data for now:
        // let labels_array = Array1::<usize>::from_elem(n, 0);

        labels.into_pyarray_bound(py)


    }


    Ok(())
}


#[cfg(test)]
mod tests{
    use super::load_sparse_matrix;
    use crate::improved_coreset;
    use ndarray::Array1;

    /// Smoke test: reads `sparse_matrix.npz` from the working directory, builds an improved
    /// coreset with `K_W_in_graph_form = Some(true)`, and prints the result (or the error)
    /// along with the elapsed time. It asserts nothing beyond the file loading successfully.
    #[allow(non_snake_case)]
    #[test]
    fn test_coreset_from_file(){
        let (data, indices, indptr, nnz_per_col, W, shape, _) =
            load_sparse_matrix("sparse_matrix.npz").unwrap();
        println!("loaded data");

        let n = shape.0;
        let (clusters, coreset_size) = (500, 2000);

        let data = Array1::from(data);
        let indices = Array1::from(indices);
        let indptr = Array1::from(indptr);
        let nnz_per_col = Array1::from(nnz_per_col);
        let W = Array1::from(W);

        let started = std::time::Instant::now();
        let outcome = improved_coreset(
            clusters,
            n,
            coreset_size,
            data.view(),
            indices.view(),
            indptr.view(),
            nnz_per_col.view(),
            W.view(),
            Some(true),
        );

        match outcome {
            Ok((coreset, coreset_weights)) => {
                println!("Coreset: {:?}", coreset);
                println!("Coreset Weights: {:?}", coreset_weights);
                println!("Time: {:?}", std::time::Instant::now().duration_since(started));
            }
            Err(err) => println!("Error: {:?}", err),
        }
    }
}