use std::f32::consts::E;
use std::ptr::NonNull;

use faer::sparse::*;
use faer::prelude::*;
use rand::{rngs::StdRng};

use crate::fixed::common::CoresetCrossTerm;
use crate::fixed::common::DatapointWithCoresetCrossTerm;
use crate::fixed::common::Error;
use crate::fixed::common::SelfAffinity;
use crate::fixed::common::Weight;
use crate::fixed::tree::SamplingTree;
use crate::fixed::common::{Index, Float};


use crate::fixed::tree;

use super::tree::TreeNode;





/// Sampling state used to seed centres and draw a kernel coreset.
///
/// One [`SamplingTree`] leaf exists per data point. The kernel matrix and the
/// weights the tree was built from are kept as borrowed views for the
/// neighbour updates performed while seeding.
// Some fields (e.g. `number_of_data_points`) are stored but not read by the
// methods in this file, hence the lint allowance.
#[allow(unused)]
pub struct FullSamplingTree<'a> {
    /// Tree over all data points, used for (smoothed) sampling.
    sampling_tree: SamplingTree,
    /// `node_pointers[i]` is the tree node belonging to data point `i`.
    node_pointers: Vec<NonNull<TreeNode>>,
    /// Self-affinity of the starting point `x_star` (the smallest one).
    smallest_coreset_self_affinity: Float,
    /// Number of seed centres picked before coreset sampling.
    number_of_clusters: usize,
    /// Accumulated weight of the points chosen while seeding.
    coreset_initialization_set_weight: Weight,
    /// Number of points drawn by `construct_coreset`.
    coreset_size: usize,
    /// Random source for all sampling.
    rng: StdRng,
    /// Dimension of the (square) kernel matrix.
    number_of_data_points: usize,
    /// Sparse kernel matrix, one row/column per data point.
    kernel_matrix: SparseColMatRef<'a, usize, Float>,
    /// Column vector of per-point weights.
    weights: MatRef<'a, Float>,
    /// Index of the first seed point.
    x_star_index: usize,
}

#[allow(unused)]
impl <'a> FullSamplingTree<'a>{
    pub fn initialize(number_of_clusters: usize, coreset_size: usize, kernel_matrix: SparseColMatRef<'a, usize, Float>,weights: MatRef::<'a, Float>, graph_form: bool)-> Self{
        // assert the kernel matrix is square and the weights are of the same size.
        // also assert that the weights are a vector.
        let dim = kernel_matrix.nrows();
        assert_eq!(dim, kernel_matrix.ncols());
        assert_eq!(dim, weights.nrows());
        assert_eq!(1, weights.ncols());
        
        assert!(coreset_size < dim);
        assert!(number_of_clusters <= coreset_size);

        let number_of_data_points = dim;
    
        // We need to extract the self-affinities of the datapoints and their weights.
        // get a view of the diagonal of the kernel matrix
        let mut sampling_tree = SamplingTree::new();
        // populate sampling tree
        let mut node_pointers = Vec::<NonNull<TreeNode>>::with_capacity(number_of_data_points);
        

        
        let (node_pointers,x_star_index,smallest_coreset_self_affinity) = match graph_form{
          false =>{

            // x_star_index is the index of the point with smallest diagonal entry in the kernel matrix.
            let (x_star_index,smallest_self_affinity) = (0..number_of_data_points).fold((0, kernel_matrix.get(0,0).unwrap()), |(i_min, min), i|{
                let self_affinity = kernel_matrix.get(i,i).unwrap();
                if self_affinity < min {(i, self_affinity)} else {(i_min, min)}
            });
            let smallest_self_affinity = *smallest_self_affinity;
            

            (sampling_tree.insert_from_iterator(
                (0..number_of_data_points).map(|i|{
                    let self_affinity = *kernel_matrix.get(i,i).unwrap();
                    let weight = *weights.get(i,0);
                    (Index(i), Weight(weight), SelfAffinity(self_affinity))
                })
                , number_of_data_points, smallest_self_affinity), x_star_index, smallest_self_affinity)
          },
          true =>{
            
            // x_star_index is the index of the point with smallest value of 1/w**2 (aka the largest weight)
            let (x_star_index,w) = weights.col(0).iter().enumerate().fold((0,0.0), |(i_max, max), (i, &x)| if x > max {(i, x)} else {(i_max, max)});
            let x_star_1_over_w = 1.0/w;
            let minimum_self_affinity = 1.0/(x_star_1_over_w*x_star_1_over_w);
            
            (sampling_tree.insert_from_iterator(
                (0..number_of_data_points).map(|i|{
                    let weight = *weights.get(i,0);
                    // If we know K is D^{-1}AD^{-1} and W=D then K[i,i] = 1/weight[i]^2 so we don't even need to look at the kernel matrix.
                    let self_affinity = 1.0/(weight*weight);
                    (Index(i), Weight(weight), SelfAffinity(self_affinity))
                })
                ,
                number_of_data_points,
                minimum_self_affinity),x_star_index, minimum_self_affinity)
          }
        };

        let number_of_clusters = number_of_clusters;
        
        FullSamplingTree{
            sampling_tree: sampling_tree,
            node_pointers,
            smallest_coreset_self_affinity,
            number_of_clusters,
            coreset_initialization_set_weight: Weight(0.0),
            coreset_size,
            rng: rand::SeedableRng::from_entropy(),
            number_of_data_points,
            kernel_matrix,
            weights,
            x_star_index,
        }
    }

    #[allow(unused)]
    pub fn with_rng(mut self, rng: StdRng)-> Self{
        self.rng = rng;
        self
    }


    pub fn contribution(&self)-> Float{
        self.sampling_tree.contribution()
    }
    pub fn smoothed_contribution(&self, cost: Float, coreset_initialization_set_weight: Float)-> Float{
        self.sampling_tree.smoothed_contribution(cost, coreset_initialization_set_weight)
    }


    pub fn sample(&mut self) -> Result<NonNull<TreeNode>,Error>{
        let x: Result<NonNull<TreeNode>,Error> = self.sampling_tree.sample_node(&mut self.rng).map(|node| node.0);
        x
    }

    pub fn sample_smooth_with_probs(&mut self) -> Result<(NonNull<TreeNode>,Float),Error>{
        // Sample a point according to the smooth probabilities.
        self.sampling_tree.sample_node_smoothed(self.coreset_initialization_set_weight.0, &mut self.rng)
    }

    pub fn update_tree(&mut self, mut node_pointer: NonNull<TreeNode>, delta: Float) -> (){
        self.sampling_tree.update_delta(node_pointer, delta);
    }


    

    pub fn sample_first_point_and_update(&mut self){
        // first point is x_star
        let first_point: Index = Index(self.x_star_index);
        let mut first_point_pointer = self.node_pointers[self.x_star_index];

        // add x_star to the coreset initialization set.
        let x_star_weight = *self.weights.get(self.x_star_index,0);
        assert!(x_star_weight == unsafe{first_point_pointer.as_ref()}.weight().0);
        let x_star_self_affinity = self.smallest_coreset_self_affinity;
        self.coreset_initialization_set_weight = Weight(x_star_weight);
        

        // set the contribution of x_star to zero.
        unsafe{
            self.sampling_tree.update_delta(first_point_pointer, 0.0);
        }
        
        // check neighbours:
        // get the neighbours of x_star
        self.kernel_matrix.row_indices_of_col(first_point.0).map(|i|Index(i)).for_each(|neighbour_index|{

            // if the neighbour is x_star, skip it since we have already updated it.
            if neighbour_index.0 == first_point.0{
                return;
            }


            // compute the distance between x_star and the neighbour.
            let cross_term = self.kernel_matrix.get(first_point.0, neighbour_index.0).unwrap();
            let neighbour_self_affinity = *self.kernel_matrix.get(neighbour_index.0, neighbour_index.0).unwrap();
            let distance2 = x_star_self_affinity + neighbour_self_affinity - 2.0*cross_term;
            
            // update the delta of the neighbour.
            let mut neighbour_pointer = self.node_pointers[neighbour_index.0];
            unsafe{
                self.sampling_tree.update_delta(neighbour_pointer, distance2);
            }
        });
    }


    pub fn seed_sample_first_k(&mut self) -> Result<(), Error>{

        for j in (0..self.number_of_clusters){
            self.sample_first_point_and_update();

            let mut node_pointer = match self.sample(){
                Ok(node_pointer) => node_pointer,
                Err(e) => return Err(e),
            };
            
            let node_index = unsafe{(node_pointer.as_mut()).index().0};



            let node_self_affinity = *self.kernel_matrix.get(node_index, node_index).unwrap();

            // add the node to the coreset initialization set.
            let node_weight = *self.weights.get(node_index,0);
            self.coreset_initialization_set_weight = Weight(self.coreset_initialization_set_weight.0 + node_weight);

            // set the contribution of the node to zero.
            unsafe{
                self.sampling_tree.update_delta(node_pointer, 0.0);
            }
            

            // check neighbours:
            // get the neighbours of the node
            self.kernel_matrix.row_indices_of_col(node_index).map(|i|Index(i)).for_each(|neighbour_index|{
                // compute the distance between the node and the neighbour.
                let cross_term = self.kernel_matrix.get(node_index, neighbour_index.0).unwrap();
                let neighbour_self_affinity = *self.kernel_matrix.get(neighbour_index.0, neighbour_index.0).unwrap();
                let distance2 = node_self_affinity + neighbour_self_affinity - 2.0*cross_term;
                
                // update the delta of the neighbour.
                let mut neighbour_pointer = self.node_pointers[neighbour_index.0];

                unsafe{
                    self.sampling_tree.update_delta(neighbour_pointer, distance2);
                }
        })
        }
        return Ok(());
    }
    
    pub fn sample_coreset(&mut self, n: usize) -> Result<(Vec<Index>,Vec<Float>), Error>{

        let mut coreset_indices = Vec::with_capacity(n);
        let mut coreset_weights = Vec::with_capacity(n);

        for _ in (0..n){
            let (leaf,prob) = match self.sample_smooth_with_probs(){
                Ok((leaf,prob)) => (leaf,prob),
                Err(e) => return Err(e),
            };


            let leaf = unsafe{leaf.as_ref()};
            let coreset_weight = leaf.weight().0/(prob*n as Float); 
            coreset_indices.push(leaf.index());
            coreset_weights.push(coreset_weight);
        }

        Ok((coreset_indices,coreset_weights))
    }

    pub fn construct_coreset(&mut self,) -> Result<(Vec<Index>,Vec<Float>), Error>{
        let first_k_success = self.seed_sample_first_k()?;
        self.sample_coreset(self.coreset_size)
    }

}





#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: seed on a small symmetric sparse kernel, draw a coreset and
    /// check the basic shape of the result.
    #[test]
    fn basic() {
        // create a sparse kernel matrix corresponding to the following dense matrix:
        // 1.0      0       0.4     0.8
        // 0        1.0     0.5     0
        // 0.4      0.5     1.0     0
        // 0.8      0       0       1.0
        let kernel_matrix = SparseColMat::<usize, Float>::try_new_from_triplets(
            4,
            4,
            &[
                (0, 0, 1.0),
                (0, 2, 0.4),
                (0, 3, 0.8),
                (1, 1, 1.0),
                (1, 2, 0.5),
                (2, 2, 1.0),
                (2, 0, 0.4),
                (2, 1, 0.5),
                (3, 0, 0.8),
                (3, 3, 1.0),
            ],
        )
        .unwrap();

        let weights = mat![[1.0], [3.0], [1.0], [1.0]];
        // display the kernel matrix nicely
        println!("{:?}", &kernel_matrix.to_dense());

        let mut sampling_tree =
            FullSamplingTree::initialize(2, 3, kernel_matrix.as_ref(), weights.as_ref(), false);
        // Do not silently ignore a seeding failure.
        sampling_tree
            .seed_sample_first_k()
            .expect("seeding should succeed on this kernel");

        let (coreset_indices, coreset_weights) = sampling_tree.sample_coreset(3).unwrap();
        println!("{:?}", coreset_indices);
        println!("{:?}", coreset_weights);

        assert_eq!(coreset_indices.len(), 3);
        assert_eq!(coreset_weights.len(), 3);
        assert!(coreset_indices.iter().all(|i| i.0 < 4));
        assert!(coreset_weights.iter().all(|w| w.is_finite() && *w > 0.0));
    }
}
