use std::fmt::Formatter;
use std::fmt::Debug;
use rand::{rngs::ThreadRng, Rng};
use bumpalo::Bump;
use std::ptr::NonNull;
use crate::improved::common::{Index, Weight, SelfAffinity, Float, Error};

// MARK: - Node structs
/// Leaf of the tree: the per-item payload plus a link back to its parent.
#[derive(Debug)]
#[allow(dead_code)]
pub struct NonIncidentLeafNode{
    /// Identifier handed back by `TreeNode::index`.
    pub index: Index,
    /// Weight of this item; it enters the leaf's contribution as a factor.
    pub weight: Weight,
    /// Self-affinity of this item; `weight * self_affinity` is what ancestors aggregate.
    pub self_affinity: SelfAffinity,
    /// Parent internal node. `None` when the leaf has not been linked under a parent
    /// (e.g. it is the root of a single-leaf tree).
    pub parent: Option<NonNull<TreeNode>>
}

/// Internal node holding aggregates over the subtree below it.
pub struct InternalNode{
    // Aggregate weight: initialised from the children's weights and adjusted on every
    // insert/delete that passes through this node.
    weight: Float,
    // Aggregate of `weight * self_affinity`, maintained the same way as `weight`.
    weighted_self_affinity: Float,
    left_child: Option<(NonNull<TreeNode>,usize)>, // (child pointer, size of the subtree rooted at child; `TreeNode::size` counts a leaf as 1 and an internal node as the sum of its children, i.e. this is a leaf count)
    // Same layout as `left_child`.
    right_child: Option<(NonNull<TreeNode>,usize)>,
    // Parent internal node, or `None` for the root.
    parent: Option<NonNull<TreeNode>>
}

impl InternalNode {
    /// Aggregate contribution of this subtree:
    /// `weighted_self_affinity + weight * smallest_coreset_self_affinity`.
    #[allow(dead_code)]
    pub fn contribution(&self, smallest_coreset_self_affinity: Float) -> Float {
        let affinity_term = self.weight * smallest_coreset_self_affinity;
        self.weighted_self_affinity + affinity_term
    }

    /// Smoothed contribution of this subtree:
    /// `contribution / cost + weight / coreset_star_weight`.
    ///
    /// No guard is applied against a zero `cost` or `coreset_star_weight`; the
    /// divisions are performed as-is.
    #[allow(dead_code)]
    pub fn smoothed_contribution(
        &self,
        smallest_coreset_self_affinity: Float,
        cost: Float,
        coreset_star_weight: Float,
    ) -> Float {
        let cost_term = self.contribution(smallest_coreset_self_affinity) / cost;
        let weight_term = self.weight / coreset_star_weight;
        cost_term + weight_term
    }
}

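// No `Drop` impl is needed for `InternalNode`: nodes are allocated from a bump allocator,
// so they are never freed one by one.
//
// If nodes were individually heap-allocated (e.g. with `Box`), `InternalNode` would need a
// `Drop` impl like the one sketched below to recursively drop its children.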
// impl Drop for InternalNode{
//     fn drop(&mut self) {
//         if let Some((left_child_pointer, _)) = &self.left_child{
//             drop(unsafe{Box::from_raw(*left_child_pointer)});
//         }
//         if let Some((right_child_pointer, _)) = &self.right_child{
//             drop(unsafe{Box::from_raw(*right_child_pointer)});
//         }
//     }
// }


impl Debug for InternalNode {
    /// Formats the node recursively: each present child is rendered as
    /// `(size, child_node)` instead of as a raw pointer.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The child pointers are dereferenced here so the whole subtree is printed;
        // this relies on every stored child pointer still referring to a live node.
        let left_child = self
            .left_child
            .as_ref()
            .map(|(pointer, size)| (size, unsafe { pointer.as_ref() }));
        let right_child = self
            .right_child
            .as_ref()
            .map(|(pointer, size)| (size, unsafe { pointer.as_ref() }));

        let mut out = f.debug_struct("InternalNode");
        out.field("weight", &self.weight);
        out.field("weighted_self_affinity", &self.weighted_self_affinity);
        out.field("left_child", &left_child);
        out.field("right_child", &right_child);
        out.field("parent", &self.parent);
        out.finish()
    }
}


/// A node of the tree: either a leaf carrying one item or an internal node carrying
/// aggregates and up to two children.
#[derive(Debug)]
#[allow(dead_code)]
pub enum TreeNode{
    /// A single item (index, weight, self-affinity).
    Leaf(NonIncidentLeafNode),
    /// Aggregated weight / weighted self-affinity plus child links.
    Internal(InternalNode)
}

#[allow(dead_code)]
impl TreeNode{

    fn parent(&self) -> Option<NonNull<TreeNode>>{
        match self{
            TreeNode::Leaf(NonIncidentLeafNode{parent, ..}) => *parent,
            TreeNode::Internal(InternalNode{parent, ..}) => *parent
        }
    }

    fn contribution(&self, smallest_coreset_self_affinity: Float) -> Float{
        match self{
            TreeNode::Leaf(NonIncidentLeafNode { weight, self_affinity, ..}) => weight.0*(self_affinity.0 + smallest_coreset_self_affinity),
            TreeNode::Internal(internal_node) => internal_node.contribution(smallest_coreset_self_affinity)
        }
    }

    fn smoothed_contribution(&self, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float) -> Float{
        match self{
            TreeNode::Leaf(NonIncidentLeafNode{weight, ..}) =>{
                let contribution = self.contribution(smallest_coreset_self_affinity);
                let smoothed_contribution = contribution/cost + weight.0/coreset_star_weight;
                smoothed_contribution
            },
            TreeNode::Internal(internal_node) =>{
                internal_node.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight)
            }
        }
    }

    fn weighted_self_affinities(&self) -> Float{
        match self{
            TreeNode::Leaf(NonIncidentLeafNode { weight, self_affinity, ..}) => weight.0*self_affinity.0,
            TreeNode::Internal(InternalNode{weighted_self_affinity, ..}) => *weighted_self_affinity
        }
    }
    pub fn weight(&self) -> Weight{
        match self{
            TreeNode::Leaf(NonIncidentLeafNode { weight, ..}) => *weight,
            TreeNode::Internal(InternalNode{weight, ..}) => Weight(*weight)
        }
    }

    pub fn index(&self) -> Index{
        match self{
            TreeNode::Leaf(NonIncidentLeafNode { index, ..}) => *index,
            TreeNode::Internal(_) => unreachable!("Internal nodes don't have indices")
        }
    
    }

    pub fn size(&self) -> usize{
        match self{
            TreeNode::Leaf(_) => 1,
            TreeNode::Internal(internal)=>{
                internal.left_child.as_ref().map_or(0, |(_, size)| *size) + internal.right_child.as_ref().map_or(0, |(_, size)| *size)
            }
        }
    }

    pub fn set_parent(&mut self, parent: NonNull<TreeNode>){
        match self{
            TreeNode::Leaf(leaf) => leaf.parent = Some(parent),
            TreeNode::Internal(internal) => internal.parent = Some(parent)
        }
    }

}


/// Binary tree over weighted items whose nodes are allocated from a bump allocator
/// and linked with raw `NonNull` pointers.
#[allow(dead_code)]
pub struct NonIncidentTree{
    // (root pointer, size of the tree as tracked by insert/delete); `None` when empty.
    root: Option<(NonNull<TreeNode>,usize)>,
    // Allocator from which every node of this tree is allocated.
    bump_allocator: Bump
}


impl Debug for NonIncidentTree {
    /// Prints `root: None` for an empty tree, otherwise the tracked size followed by
    /// the recursively formatted root node.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("NonIncidentTree");
        if let Some((root_pointer, size)) = self.root.as_ref() {
            // The root pointer is dereferenced so the node itself (not its address)
            // is printed; this relies on it referring to a live node.
            let tree_node = unsafe { root_pointer.as_ref() };
            out.field("size", &size);
            out.field("root", tree_node);
        } else {
            out.field("root", &Option::<()>::None);
        }
        out.finish()
    }
}

#[allow(dead_code)]
impl NonIncidentTree{
    pub fn new() -> Self{
        NonIncidentTree{
            root: None,
            bump_allocator: Bump::new()}
    }

    pub fn contribution(&self, smallest_coreset_self_affinity: Float) -> Float{
        match self.root{
            None => 0.0,
            Some((root_pointer, _)) => unsafe{root_pointer.as_ref().contribution(smallest_coreset_self_affinity)}
        }
    }

    pub fn smoothed_contribution(&self, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float) -> Float{
        match self.root{
            None => 0.0,
            Some((root_pointer, _)) => unsafe{root_pointer.as_ref().smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight)}
        }
    }

    pub fn _sample_node(&self, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float, rng: &mut ThreadRng, smoothed:bool)-> Result<(&TreeNode,Float), Error>{
        match self.root{
            None => Err(Error::EmptyTree),
            Some((root_pointer, _)) =>{
                let mut maybe_node_pointer = Some(root_pointer);
                let mut prob = 1.0;
                while let Some(node_pointer) = maybe_node_pointer{
                    let node = unsafe{node_pointer.as_ref()};
                    match node{
                        TreeNode::Internal(internal_node)  =>{
                            let (lhs_contribution,rhs_contribution) =  match smoothed{
                                true =>{
                                    (internal_node.left_child.as_ref()
                                    .map_or(0.0, |(pointer, _)| {
                                        unsafe{
                                            (*pointer).as_ref()
                                            .smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight)
                                    }}),
                                    internal_node.right_child.as_ref()
                                    .map_or(0.0, |(pointer, _)| {
                                        unsafe{
                                            (pointer.as_ref())
                                            .smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight)}}))
                                },
                                false =>{
                                    (internal_node.left_child.as_ref()
                                    .map_or(0.0, |(pointer, _)| {
                                        unsafe{
                                            (pointer.as_ref())
                                            .contribution(smallest_coreset_self_affinity)
                                    }}),
                                    internal_node.right_child.as_ref()
                                    .map_or(0.0, |(pointer, _)| {
                                        unsafe{
                                            (pointer.as_ref())
                                            .contribution(smallest_coreset_self_affinity)}}))
                                }
                            };
                            let total_contribution = lhs_contribution + rhs_contribution;
                            let sample = rng.gen_range(0.0..total_contribution);
                            if sample < lhs_contribution{
                                maybe_node_pointer = internal_node.left_child.as_ref().map(|(pointer, _)| *pointer);
                                prob *= lhs_contribution/total_contribution;
                            }else{
                                maybe_node_pointer = internal_node.right_child.as_ref().map(|(pointer, _)| *pointer);
                                prob *= rhs_contribution/total_contribution;
                            }
                        },
                        TreeNode::Leaf(_) => return Ok((node,prob))
                    }
                }
                unreachable!("We should never reach here");
            }
        }
    }

    pub fn sample_node_smoothed(&self, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float, rng: &mut ThreadRng) -> Result<(&TreeNode,Float), Error>{
        self._sample_node(smallest_coreset_self_affinity, cost, coreset_star_weight, rng, true)
    }
    pub fn sample_node(&self, smallest_coreset_self_affinity: Float, rng: &mut ThreadRng) -> Result<(&TreeNode,Float), Error>{
        self._sample_node(smallest_coreset_self_affinity, 0.0, 0.0, rng, false)
    }

    pub fn _compute_sampling_probability(&self, node_pointer: *mut TreeNode, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float, smoothed:bool) -> Float{
        let node_to_sample = unsafe{& *node_pointer};

        let mut child_contribution = match smoothed{
            true => node_to_sample.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight),
            false => node_to_sample.contribution(smallest_coreset_self_affinity)
        };

        let mut maybe_parent_pointer = node_to_sample.parent();
        
        let mut prob = 1.0;
        while let Some(parent_pointer) = maybe_parent_pointer{
            if let TreeNode::Internal(ref internal_parent) = unsafe{parent_pointer.as_ref()}{
                let node_contribution = match smoothed{
                    true => internal_parent.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight),
                    false => internal_parent.contribution(smallest_coreset_self_affinity)
                };
                prob *= child_contribution/node_contribution;
                child_contribution = node_contribution;
                maybe_parent_pointer = unsafe{parent_pointer.as_ref()}.parent();
            }else{
                unreachable!("We should never reach here since leaf's don't have children");
            }
        }
        prob
    }

    pub fn compute_sampling_probability(&self, node_pointer: *mut TreeNode, smallest_coreset_self_affinity: Float) -> Float{
        self._compute_sampling_probability(node_pointer, smallest_coreset_self_affinity, 0.0, 0.0, false)
    }
    pub fn compute_smoothed_sampling_probability(&self, node_pointer: *mut TreeNode, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float) -> Float{
        self._compute_sampling_probability(node_pointer, smallest_coreset_self_affinity, cost, coreset_star_weight, true)
    }
    #[allow(unused_assignments)]
    pub fn delete_node(&mut self, node_pointer: *mut TreeNode) -> (Weight, SelfAffinity){
        // Delete node by updating it's anscestors all the way to the root. 
        // Since we use a bump allocator, we don't need to worry about freeing memory.
        // We just need to update the parent pointers of the node's ancestors.
        
        let node_to_delete = unsafe{&*node_pointer};
        let weight = node_to_delete.weight().0;
        let weighted_self_affinity = node_to_delete.weighted_self_affinities();
        let mut self_affinity = 0.0;

        if let TreeNode::Leaf(NonIncidentLeafNode{self_affinity: leaf_self_affinity, ..}) = node_to_delete{
            self_affinity = leaf_self_affinity.0;
        }else{
            unreachable!("Internal nodes don't have self affinities");
        }


        let mut child_pointer = unsafe{NonNull::new_unchecked(node_pointer)};
        let mut maybe_parent_pointer = node_to_delete.parent().clone();
        let mut delete_child = true;

        while let Some(mut parent_pointer) = maybe_parent_pointer{
            if let TreeNode::Internal(ref mut parent_internal_node) = unsafe{parent_pointer.as_mut()}{
                // update weights and weighted self affinities
                parent_internal_node.weight -= weight;
                parent_internal_node.weighted_self_affinity -= weighted_self_affinity;

                // find the child pointer to update
                let target_child_pointer = match &mut parent_internal_node.left_child{
                    None => &mut parent_internal_node.right_child,
                    Some((left_child_pointer,_)) =>{
                        if *left_child_pointer == child_pointer{
                            &mut parent_internal_node.left_child
                        }else{
                            &mut parent_internal_node.right_child                        
                        }
                    }
                };
                match delete_child{
                    true => *target_child_pointer = None,
                    false => {
                        *target_child_pointer = target_child_pointer.map(|(pointer, size)| (pointer, size-1));
                    }
                }
                // Now check both children are None. If so mark the parent for deletion too.
                if parent_internal_node.left_child.is_none() && parent_internal_node.right_child.is_none(){
                    delete_child = true;
                }else{
                    delete_child = false;
                }

                maybe_parent_pointer = parent_internal_node.parent.clone();
                child_pointer = parent_pointer;
            }else{
                unreachable!("We should never reach here since leaf's don't have children");
            }
            
        }
        // Now we are at the root. We need to see if we have to delete the root.
        if delete_child{
            self.root = None;
        }else{
            self.root = self.root.map(|(pointer, size)| (pointer, size-1));
        }
        
        return (Weight(weight), SelfAffinity(self_affinity));
    }

    pub fn insert_from_iterator<I>(&mut self, iterator: I,num_leaves:usize) -> Vec<NonNull<TreeNode>>
    where I: Iterator<Item = (Index,Weight,SelfAffinity)> + std::iter::ExactSizeIterator
    {
        // Given an iterator of leaf node data, we create a balanced binary tree in a bottom up fashion.
        // This is to help with cache locality and branch prediction.

        // first we allocate all the leaf nodes using a vec and get an array of pointers to them.
        // The vec will be allocated from the bump allocator so we don't need to worry about freeing memory.
        let _total_num_nodes = 2*num_leaves - 1;

        let leaf_vec = self.bump_allocator.alloc_slice_fill_iter(
            iterator.map(|(index, weight, self_affinity)|{
                TreeNode::Leaf(NonIncidentLeafNode{
                    index,
                    weight,
                    self_affinity,
                    parent: None
                })
            })
        );
        let leaf_pointers = leaf_vec.into_iter().map(
            |leaf| NonNull::from(leaf))
            .collect::<Vec<NonNull<TreeNode>>>();

        // Now we create the internal nodes in a bottom up fashion.
        let mut current_level = leaf_pointers.clone();
        while current_level.len() > 1{
            let mut next_level = Vec::with_capacity((current_level.len() + 1) / 2);

            for chunk in current_level.chunks_exact(2){
                let mut left_child_pointer = chunk[0];
                let mut right_child_pointer = chunk[1];
                

                let internal_node = TreeNode::Internal(InternalNode{
                    weight: unsafe{left_child_pointer.as_ref().weight().0 + right_child_pointer.as_ref().weight().0},
                    weighted_self_affinity: unsafe{left_child_pointer.as_ref().weighted_self_affinities() + right_child_pointer.as_ref().weighted_self_affinities()},
                    left_child: Some((left_child_pointer, unsafe{left_child_pointer.as_ref().size()})),
                    right_child: Some((right_child_pointer, unsafe{right_child_pointer.as_ref().size()})),
                    parent: None
                });
                let internal_node_pointer = self.bump_allocator.alloc(internal_node) as *mut TreeNode;
                let internal_node_pointer = unsafe{NonNull::new_unchecked(internal_node_pointer)};

                unsafe{
                    left_child_pointer.as_mut().set_parent(internal_node_pointer);
                    right_child_pointer.as_mut().set_parent(internal_node_pointer);
                }

                next_level.push(internal_node_pointer);
            }

            if current_level.len()%2 == 1{
                next_level.push(*current_level.last().unwrap());
            }
            current_level = next_level;
        }

        self.root = Some((current_level[0], unsafe{current_level[0].as_ref().size()}));
        leaf_pointers

    }

    pub fn insert_node(&mut self, index: Index, weight: Weight, self_affinity: SelfAffinity) -> *mut TreeNode{
        // We must maintain that leaf nodes are not moved in memory. 
        // Moving internal nodes is fine
        // println!("inserting");
        // println!("root pointer: {:?}", self.root);
        if let None = self.root{
            // println!("base case");
            let new_leaf_pointer = self.bump_allocator.alloc(TreeNode::Leaf(NonIncidentLeafNode{
                index,
                weight,
                self_affinity,
                parent: None
            })) as *mut TreeNode;

            let new_leaf_pointer = unsafe{NonNull::new_unchecked(new_leaf_pointer)};

            self.root = Some((new_leaf_pointer, 1));
            return new_leaf_pointer.as_ptr();
        } else{
            self.root = self.root.map(|(root_pointer, size)| (root_pointer, size+1));
            return self._insert_node(index, weight, self_affinity);
        }
    }

    fn _insert_node(&mut self, index: Index, weight: Weight, self_affinity: SelfAffinity) -> *mut TreeNode{
        let root_tree_node_pointer = self.root.unwrap().0;
        let mut tree_node_pointer = root_tree_node_pointer;
        while let TreeNode::Internal(internal_node) = unsafe{tree_node_pointer.as_mut()}{
            // first update the internal node's weight and weighted self affinity.
            internal_node.weight += weight.0;
            internal_node.weighted_self_affinity += weight.0*self_affinity.0;
            // Now we need to decide whether to go left or right.
            let target_subtree = match (internal_node.left_child,internal_node.right_child){
                (None,None) => unreachable!(),
                (Some(_), None) => &mut internal_node.right_child,
                (None, Some(_)) => unreachable!("We should always have a left child if we have a right child"),
                (Some((_,left_size)),Some((_,right_size))) =>{
                    if left_size <= right_size{
                        &mut internal_node.left_child
                    }else{
                        &mut internal_node.right_child
                    }
                }
            };
            match target_subtree.is_none(){
                true =>{
                    // We need to insert the new node here. We use the bump allocator to allocate the new leaf node.
                    let new_leaf_mut_ref = self.bump_allocator.alloc(
                        TreeNode::Leaf(NonIncidentLeafNode{
                            index,
                            weight,
                            self_affinity,
                            parent: Some(tree_node_pointer)
                        })
                    ) as &mut TreeNode;
                    let new_leaf_pointer = unsafe{NonNull::new_unchecked(new_leaf_mut_ref)};
                    // Update the internal node to point to the new leaf.
                    *target_subtree = target_subtree.map(|(_, size)| (new_leaf_pointer, size+1));
                    return new_leaf_pointer.as_ptr();
                },
                false =>{
                    // we need to increment the size of the subtree rooted at the target subtree and then traverse down the tree.
                    *target_subtree = target_subtree.map(|(pointer, size)| (pointer, size+1));
                    tree_node_pointer = target_subtree.unwrap().0;
                }
            }
        }
        // Now we are at a leaf node. We need to replace the leaf node with an internal node which has the leaf node and the new node as children.
        // Taking care to only allocate the new node. Otherwise we are just performing pointer surgery and updating nodes.
        match unsafe{tree_node_pointer.as_mut()}{
            TreeNode::Leaf(leaf) =>{
                // We need to replace the leaf node with an internal node which has the leaf node and the new node as children.
                // First get the old parent pointer
                let old_parent_pointer = leaf.parent;
                let new_internal_node_pointer = self.bump_allocator.alloc(TreeNode::Internal(InternalNode{
                    weight: leaf.weight.0 + weight.0,
                    weighted_self_affinity: leaf.weight.0*leaf.self_affinity.0 + weight.0*self_affinity.0,
                    left_child: Some((tree_node_pointer, 1)),
                    right_child: None,
                    parent: leaf.parent
                })) as *mut TreeNode;
                let mut new_internal_node_pointer = unsafe{NonNull::new_unchecked(new_internal_node_pointer)};
                // update the current leaf node to point to the new internal node.
                leaf.parent = Some(new_internal_node_pointer);
                // Create the new leaf node
                let new_leaf_pointer = self.bump_allocator.alloc(TreeNode::Leaf(NonIncidentLeafNode{
                    index,
                    weight,
                    self_affinity,
                    parent: Some(new_internal_node_pointer)
                })) as *mut TreeNode;
                let new_leaf_pointer = unsafe{NonNull::new_unchecked(new_leaf_pointer)};
                // Update the new internal node to point to the new leaf node.
                if let TreeNode::Internal(internal_node) = unsafe{new_internal_node_pointer.as_mut()}{
                    internal_node.right_child = Some((new_leaf_pointer, 1));
                }
                // update the parent of the new internal node to point to the new internal node if it exists.
                if let Some(mut parent_pointer) = old_parent_pointer{
                    let parent = unsafe{parent_pointer.as_mut()};
                    match parent{
                        TreeNode::Leaf(_) => unreachable!(),
                        TreeNode::Internal(parent_internal_node) =>{
                            if parent_internal_node.left_child.unwrap().0 == tree_node_pointer{
                                parent_internal_node.left_child = parent_internal_node.left_child.map(|(_, size)| (new_internal_node_pointer, size));
                            } else{
                                parent_internal_node.right_child = parent_internal_node.right_child.map(|(_, size)| (new_internal_node_pointer, size));
                            }
                        }
                    }
                }else{
                    // The leaf node was the root so we have to update the root pointer to point to the new internal node.
                    self.root = self.root.map(|(_, size)| (new_internal_node_pointer, size));
                }
                return new_leaf_pointer.as_ptr();
            },
            TreeNode::Internal(_) =>{
                // We need to insert the new node in the subtree rooted at the left or right child.
                unreachable!();
            }
        }
    }
}


// Tests
#[cfg(test)]
mod tests {
    use super::*;
    use crate::improved::common::Datapoint;
    use std::collections::HashSet;

    /// Eight datapoints whose weight and self-affinity both equal 1.0, 2.0, ..., 8.0.
    fn sample_datapoints() -> Vec<Datapoint> {
        (1..=8)
            .map(|k| {
                let value = k as Float;
                Datapoint { weight: Weight(value), self_affinity: SelfAffinity(value) }
            })
            .collect()
    }

    /// Reference total: sum of the contributions of all non-deleted datapoints.
    #[allow(dead_code)]
    fn compute_actual_total_contribution(data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float) -> Float {
        data_points
            .iter()
            .enumerate()
            .map(|(i, point)| {
                if deleted_indices.contains(&i.into()) {
                    0.0
                } else {
                    point.contribution(smallest_coreset_self_affinity)
                }
            })
            .sum::<Float>()
    }

    /// Reference total: sum of the smoothed contributions of all non-deleted datapoints.
    #[allow(dead_code)]
    fn compute_actual_total_smoothed_contribution(data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float) -> Float {
        data_points
            .iter()
            .enumerate()
            .map(|(i, point)| {
                if deleted_indices.contains(&i.into()) {
                    0.0
                } else {
                    point.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight.into())
                }
            })
            .sum::<Float>()
    }

    /// Checks that the tree's total contribution equals the reference total.
    #[allow(dead_code)]
    fn test_total_contribution(tree: &NonIncidentTree, data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float) {
        let reference_total = compute_actual_total_contribution(data_points, deleted_indices, smallest_coreset_self_affinity);
        let tree_total = tree.contribution(smallest_coreset_self_affinity);
        println!("Actual vs tree total contribution: {:?} vs {:?}", reference_total, tree_total);
        assert_eq!(reference_total, tree_total);
    }

    /// Checks that each live leaf's sampling probability matches
    /// `contribution / total contribution` to within 1e-6.
    #[allow(dead_code)]
    fn test_sampling_probabilities(tree: &NonIncidentTree, data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float, pointers: &[Option<*mut TreeNode>]) {
        let reference_total = compute_actual_total_contribution(data_points, deleted_indices, smallest_coreset_self_affinity);

        let target_probs: Vec<Float> = data_points
            .iter()
            .enumerate()
            .map(|(i, point)| {
                if deleted_indices.contains(&i.into()) {
                    0.0
                } else {
                    point.contribution(smallest_coreset_self_affinity) / reference_total
                }
            })
            .collect();

        let tree_probs: Vec<Float> = (0..data_points.len())
            .map(|i| {
                pointers[i].map_or(0.0, |pointer| {
                    tree.compute_sampling_probability(pointer, smallest_coreset_self_affinity)
                })
            })
            .collect();

        println!("Target probs: {:?}", target_probs);
        println!("Tree probs: {:?}", tree_probs);

        target_probs.iter().zip(tree_probs.iter()).for_each(|(target_prob, tree_prob)| {
            assert!((target_prob - tree_prob).abs() < 1e-6);
        });
    }

    /// Smoothed counterpart of `test_sampling_probabilities`.
    #[allow(dead_code)]
    fn test_smoothed_sampling_probabilities(tree: &NonIncidentTree, data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float, pointers: &[Option<*mut TreeNode>]) {
        let reference_total = compute_actual_total_smoothed_contribution(data_points, deleted_indices, smallest_coreset_self_affinity, cost, coreset_star_weight);

        let target_probs: Vec<Float> = data_points
            .iter()
            .enumerate()
            .map(|(i, point)| {
                if deleted_indices.contains(&i.into()) {
                    0.0
                } else {
                    point.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight.into()) / reference_total
                }
            })
            .collect();

        let tree_probs: Vec<Float> = (0..data_points.len())
            .map(|i| {
                pointers[i].map_or(0.0, |pointer| {
                    tree.compute_smoothed_sampling_probability(pointer, smallest_coreset_self_affinity, cost, coreset_star_weight)
                })
            })
            .collect();

        println!("Target probs: {:?}", target_probs);
        println!("Tree probs: {:?}", tree_probs);

        target_probs.iter().zip(tree_probs.iter()).for_each(|(target_prob, tree_prob)| {
            assert!((target_prob - tree_prob).abs() < 1e-6);
        });
    }

    /// Inserts 20 leaves one at a time, prints them, then deletes them in reverse order.
    #[test]
    pub fn test_insertion_and_deletion() {
        let mut tree = NonIncidentTree::new();
        let pointers: Vec<*mut TreeNode> = (0..20)
            .map(|i| tree.insert_node(Index(i), Weight(i as Float), SelfAffinity(i as Float * 10.0)))
            .collect();

        for &pointer in pointers.iter() {
            let node = unsafe { pointer.as_ref().unwrap() };
            println!("{:?}", node);
        }

        for &pointer in pointers.iter().rev() {
            tree.delete_node(pointer);
        }
    }

    /// Panics if any of the still-present pointers refers to an internal node.
    pub fn assert_nodes_are_leaves(pointers: &[Option<*mut TreeNode>]) {
        for pointer in pointers.iter().flatten() {
            let node = unsafe { pointer.as_ref().unwrap() };
            if let TreeNode::Internal(_) = node {
                panic!("Node is not a leaf");
            }
        }
    }

    /// Bulk-builds a tree, then deletes the leaves one by one, re-checking totals and
    /// sampling probabilities after every deletion.
    #[test]
    fn test_non_incident_tree_sample() {
        let data_points = sample_datapoints();

        let mut tree = NonIncidentTree::new();
        let smallest_coreset_self_affinity = 2.0;

        // The tree is populated through the bulk path; the same per-step checks used
        // below could equally be run after each incremental `insert_node` call.
        let mut pointers = tree
            .insert_from_iterator(
                data_points
                    .iter()
                    .enumerate()
                    .map(|(i, datapoint)| (i.into(), datapoint.weight, datapoint.self_affinity)),
                data_points.len(),
            )
            .into_iter()
            .map(Some)
            .collect::<Vec<Option<NonNull<TreeNode>>>>();

        let delete_order: Vec<usize> = vec![0, 4, 2, 1, 3, 7, 5, 6];
        println!("deletion");
        for (step, &victim) in delete_order.iter().enumerate() {
            println!("{:?}", "#".repeat(40));
            let _ = tree.delete_node(pointers[victim].unwrap().as_ptr());
            pointers[victim] = None;

            let pointers_as_is: Vec<Option<*mut TreeNode>> = pointers
                .iter()
                .map(|maybe_pointer| maybe_pointer.map(|pointer| pointer.as_ptr()))
                .collect();
            assert_nodes_are_leaves(&pointers_as_is);

            let deleted_indices: HashSet<Index> =
                delete_order[..step + 1].iter().map(|&x| x.into()).collect();
            println!("{:?}", &deleted_indices);

            test_total_contribution(&tree, &data_points, &deleted_indices, smallest_coreset_self_affinity);
            test_sampling_probabilities(&tree, &data_points, &deleted_indices, smallest_coreset_self_affinity, &pointers_as_is);
            test_smoothed_sampling_probabilities(&tree, &data_points, &deleted_indices, smallest_coreset_self_affinity, 5.0, 100.0, &pointers_as_is);
        }
    }

    /// Smoke test: bulk insertion of eight datapoints must not panic.
    #[test]
    fn test_bulk_insert() {
        let data_points = sample_datapoints();
        let mut tree = NonIncidentTree::new();

        let _pointers = tree.insert_from_iterator(
            data_points
                .iter()
                .enumerate()
                .map(|(i, datapoint)| (i.into(), datapoint.weight, datapoint.self_affinity)),
            data_points.len(),
        );
    }
}
