use std::fmt::Formatter;
use std::fmt::Debug;
use rand::rngs::StdRng;
use rand::{rngs::ThreadRng, Rng};
use bumpalo::Bump;
use std::ptr::NonNull;
use crate::fixed::common::{Index, Weight, SelfAffinity, Float, Error};

// MARK: - Node structs
#[derive(Debug)]
#[allow(dead_code)]
pub struct LeafNode{
    pub index: Index,
    pub weight: Weight,
    pub delta: Float,
    pub parent: Option<NonNull<TreeNode>>
}

impl LeafNode{
    pub fn contribution(&self) -> Float{
        self.weight.0*self.delta
    }
}

/// An internal node of the sampling tree.
///
/// `contribution` and `weight` cache the sums over the two children. They are set when the
/// node is built in `SamplingTree::insert_from_iterator`, and `contribution` is kept current
/// by `SamplingTree::update_delta`. This lets sampling descend without rescanning leaves.
pub struct InternalNode{
    contribution: Float,
    weight: Weight,
    left_child: Option<NonNull<TreeNode>>,
    right_child: Option<NonNull<TreeNode>>,
    parent: Option<NonNull<TreeNode>>
}

impl InternalNode{

    /// Cached total contribution of the subtree rooted at this node.
    #[allow(dead_code)]
    pub fn contribution(&self) -> Float{
        self.contribution
    }

    /// Subtree mass mixed with its weight share:
    /// `contribution / cost + weight / coreset_star_weight`.
    #[allow(dead_code)]
    pub fn smoothed_contribution(&self, cost: Float, coreset_star_weight: Float) -> Float{
        let cost_share = self.contribution / cost;
        let weight_share = self.weight.0 / coreset_star_weight;
        cost_share + weight_share
    }
}




impl Debug for InternalNode{
    /// Formats the node and recurses into both children.
    ///
    /// The parent is printed as a pointer value, not dereferenced, so formatting never
    /// walks back up the tree.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // SAFETY: child pointers are only ever set by `SamplingTree::insert_from_iterator` to
        // nodes allocated in the owning tree's bump arena, which outlives every node in it.
        let left = self.left_child.map(|pointer| unsafe{ pointer.as_ref() });
        let right = self.right_child.map(|pointer| unsafe{ pointer.as_ref() });

        let mut builder = f.debug_struct("InternalNode");
        builder.field("contribution", &self.contribution);
        builder.field("weight", &self.weight);
        builder.field("left_child", &left);
        builder.field("right_child", &right);
        builder.field("parent", &self.parent);
        builder.finish()
    }

}


#[derive(Debug)]
#[allow(dead_code)]
pub enum TreeNode{
    Leaf(LeafNode),
    Internal(InternalNode)
}

#[allow(dead_code)]
impl TreeNode{

    pub fn parent(&self) -> Option<NonNull<TreeNode>>{
        match self{
            TreeNode::Leaf(LeafNode{parent, ..}) => *parent,
            TreeNode::Internal(InternalNode{parent, ..}) => *parent
        }
    }

    pub fn contribution(&self) -> Float{
        match self{
            TreeNode::Leaf(leaf_node) => leaf_node.contribution(),
            TreeNode::Internal(internal_node) => internal_node.contribution()
        }
    }

    fn smoothed_contribution(&self, cost: Float, coreset_star_weight: Float) -> Float{
        match self{
            TreeNode::Leaf(LeafNode{weight, ..}) =>{
                let contribution = self.contribution();
                let smoothed_contribution = contribution/cost + weight.0/coreset_star_weight;
                smoothed_contribution
            },
            TreeNode::Internal(internal_node) =>{
                internal_node.smoothed_contribution(cost, coreset_star_weight)
            }
        }
    }

    pub fn weight(&self) -> Weight{
        match self{
            TreeNode::Leaf(LeafNode { weight, ..}) => *weight,
            TreeNode::Internal(InternalNode{weight, ..}) => *weight
        }
    }

    pub fn index(&self) -> Index{
        match self{
            TreeNode::Leaf(LeafNode { index, ..}) => *index,
            TreeNode::Internal(_) => unreachable!("Internal nodes don't have indices")
        }
    
    }


    pub fn set_parent(&mut self, parent: NonNull<TreeNode>){
        match self{
            TreeNode::Leaf(leaf) => leaf.parent = Some(parent),
            TreeNode::Internal(internal) => internal.parent = Some(parent)
        }
    }

    pub fn delta(&self) -> Float{
        match self{
            TreeNode::Leaf(LeafNode{delta, ..}) => *delta,
            TreeNode::Internal(_) => panic!("Internal nodes don't have deltas")
        }
    }

    pub fn children(&self) -> (Option<NonNull<TreeNode>>, Option<NonNull<TreeNode>>){
        match self{
            TreeNode::Leaf(_) => panic!("Leaf nodes don't have children"),
            TreeNode::Internal(InternalNode{left_child, right_child, ..}) => (*left_child, *right_child)
        }
    }


}


/// Balanced binary tree over weighted leaves that samples a leaf with probability
/// proportional to its (optionally smoothed) contribution.
///
/// Every node lives in `bump_allocator`. The raw `NonNull` links between nodes therefore
/// remain valid for as long as the tree is alive. Nodes are never freed individually.
#[allow(dead_code)]
pub struct SamplingTree{
    root: Option<NonNull<TreeNode>>,
    bump_allocator: Bump
}


impl Debug for SamplingTree{
    /// Formats the whole tree recursively from the root.
    ///
    /// An empty tree shows `root: None`; a non-empty tree shows the root node itself.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SamplingTree");
        if let Some(root_pointer) = self.root{
            // SAFETY: `root` always points into `self.bump_allocator`, which lives as long as `self`.
            let root_node: &TreeNode = unsafe{ root_pointer.as_ref() };
            builder.field("root", root_node);
        }else{
            builder.field("root", &Option::<()>::None);
        }
        builder.finish()
    }

}

#[allow(dead_code)]
impl SamplingTree{
    pub fn new() -> Self{
        SamplingTree{
            root: None,
            bump_allocator: Bump::new()}
    }

    pub fn contribution(&self) -> Float{
        match self.root{
            None => 0.0,
            Some(root_pointer) => unsafe{root_pointer.as_ref().contribution()}
        }
    }

    pub fn smoothed_contribution(&self, cost: Float, coreset_star_weight: Float) -> Float{
        match self.root{
            None => 0.0,
            Some(root_pointer) => unsafe{root_pointer.as_ref().smoothed_contribution(cost, coreset_star_weight)}
        }
    }

    pub fn _sample_node(&mut self, cost: Float, coreset_star_weight: Float, rng: &mut StdRng, smoothed:bool)-> Result<(NonNull<TreeNode>,Float), Error>{
        match self.root{
            None => Err(Error::EmptyTree),
            Some(root_pointer) =>{
                let mut maybe_node_pointer = Some(root_pointer);
                let mut prob = 1.0;
                while let Some(node_pointer) = maybe_node_pointer{
                    let node = unsafe{node_pointer.as_ref()};
                    match node{
                        TreeNode::Internal(internal_node)  =>{
                            let (lhs_contribution,rhs_contribution) =  match smoothed{
                                true =>{
                                    (internal_node.left_child.as_ref()
                                    .map_or(0.0, |pointer| {
                                        unsafe{
                                            (*pointer).as_ref()
                                            .smoothed_contribution(cost, coreset_star_weight)
                                    }}),
                                    internal_node.right_child.as_ref()
                                    .map_or(0.0, |pointer| {
                                        unsafe{
                                            (pointer.as_ref())
                                            .smoothed_contribution(cost, coreset_star_weight)}}))
                                },
                                false =>{
                                    (internal_node.left_child.as_ref()
                                    .map_or(0.0, |pointer| {
                                        unsafe{
                                            (pointer.as_ref())
                                            .contribution()
                                    }}),
                                    internal_node.right_child.as_ref()
                                    .map_or(0.0, |pointer| {
                                        unsafe{
                                            (pointer.as_ref())
                                            .contribution()}}))
                                }
                            };
                            let total_contribution = lhs_contribution + rhs_contribution;
                            if total_contribution <= 0.0{
                                println!("A numerical error has probably occured. Might need to retry with higher float precision");
                                return Err(Error::NumericalError);
                            }
                            let sample = rng.gen_range(0.0..total_contribution);
                            if sample < lhs_contribution{
                                maybe_node_pointer = internal_node.left_child.as_ref().map(|pointer| *pointer);
                                prob *= lhs_contribution/total_contribution;
                            }else{
                                maybe_node_pointer = internal_node.right_child.as_ref().map(|pointer| *pointer);
                                prob *= rhs_contribution/total_contribution;
                            }
                        },
                        TreeNode::Leaf(_) => return Ok((node_pointer, prob))
                    }
                }
                unreachable!("We should never reach here");
            }
        }
    }

    pub fn sample_node_smoothed(&mut self, coreset_star_weight: Float, rng: &mut StdRng) -> Result<(NonNull<TreeNode>,Float), Error>{
        self._sample_node( self.contribution(), coreset_star_weight, rng, true)
    }
    pub fn sample_node(&mut self, rng: &mut StdRng) -> Result<(NonNull<TreeNode>,Float), Error>{
        self._sample_node( 0.0, 0.0, rng, false)
    }

    pub fn _compute_sampling_probability(&self, node_pointer: *mut TreeNode, cost: Float, coreset_star_weight: Float, smoothed:bool) -> Float{
        let node_to_sample = unsafe{& *node_pointer};

        let mut child_contribution = match smoothed{
            true => node_to_sample.smoothed_contribution( cost, coreset_star_weight),
            false => node_to_sample.contribution()
        };

        let mut maybe_parent_pointer = node_to_sample.parent();
        
        let mut prob = 1.0;
        while let Some(parent_pointer) = maybe_parent_pointer{
            if let TreeNode::Internal(ref internal_parent) = unsafe{parent_pointer.as_ref()}{
                let node_contribution = match smoothed{
                    true => internal_parent.smoothed_contribution(cost, coreset_star_weight),
                    false => internal_parent.contribution()
                };
                prob *= child_contribution/node_contribution;
                child_contribution = node_contribution;
                maybe_parent_pointer = unsafe{parent_pointer.as_ref()}.parent();
            }else{
                unreachable!("We should never reach here since leaf's don't have children");
            }
        }
        prob
    }

    pub fn compute_sampling_probability(&self, node_pointer: *mut TreeNode) -> Float{
        self._compute_sampling_probability(node_pointer, 0.0, 0.0, false)
    }
    pub fn compute_smoothed_sampling_probability(&self, node_pointer: *mut TreeNode, cost: Float, coreset_star_weight: Float) -> Float{
        self._compute_sampling_probability(node_pointer, cost, coreset_star_weight, true)
    }

    pub fn insert_from_iterator<I>(&mut self, iterator: I,num_leaves:usize, min_self_affinity:Float) -> Vec<NonNull<TreeNode>>
    where I: Iterator<Item = (Index,Weight,SelfAffinity)> + std::iter::ExactSizeIterator
    {
        // Given an iterator of leaf node data, we create a balanced binary tree in a bottom up fashion.
        // This is to help with cache locality and branch prediction.

        // first we allocate all the leaf nodes using a vec and get an array of pointers to them.
        // The vec will be allocated from the bump allocator so we don't need to worry about freeing memory.
        let _total_num_nodes = 2*num_leaves - 1;

        let leaf_vec = self.bump_allocator.alloc_slice_fill_iter(
            iterator.map(|(index, weight, self_affinity)|{
                let delta = self_affinity.0 + min_self_affinity;
                TreeNode::Leaf(LeafNode{
                    index,
                    weight,
                    delta,
                    parent: None
                })
            })
        );
        let leaf_pointers = leaf_vec.
            iter_mut().map(
            |leaf| NonNull::from(leaf))
            .collect::<Vec<NonNull<TreeNode>>>();

        // Now we create the internal nodes in a bottom up fashion.
        let mut current_level = leaf_pointers.clone();
        while current_level.len() > 1{
            let mut next_level = Vec::with_capacity((current_level.len() + 1) / 2);

            for chunk in current_level.chunks_exact(2){
                let internal_node_pointer = {
                    let left_child_pointer = chunk[0];
                    let right_child_pointer = chunk[1];
                    

                    let internal_node = TreeNode::Internal(InternalNode{
                        weight: Weight(unsafe{left_child_pointer.as_ref().weight().0 + right_child_pointer.as_ref().weight().0}),
                        contribution: unsafe{left_child_pointer.as_ref().contribution() + right_child_pointer.as_ref().contribution()},
                        left_child: Some(left_child_pointer),
                        right_child: Some(right_child_pointer),
                        parent: None
                    });
                    let internal_node_pointer = self.bump_allocator.alloc(internal_node) as *mut TreeNode;
                    let internal_node_pointer = unsafe{NonNull::new_unchecked(internal_node_pointer)};
                    internal_node_pointer
                };
                let mut left_child_pointer = chunk[0];
                let mut right_child_pointer = chunk[1];
                unsafe{
                    left_child_pointer.as_mut().set_parent(internal_node_pointer);
                    right_child_pointer.as_mut().set_parent(internal_node_pointer);
                }

                next_level.push(internal_node_pointer);
            }

            if current_level.len()%2 == 1{
                next_level.push(*current_level.last().unwrap());
            }
            current_level = next_level;
        }

        self.root = Some(current_level[0]);
        leaf_pointers

    }

    pub fn update_delta(&mut self,  mut node_pointer: NonNull<TreeNode>, new_delta: Float){
        match unsafe{node_pointer.as_mut()}{
            TreeNode::Internal(_) => panic!("should be a leaf node"),
            TreeNode::Leaf(leaf_node) => {
                if leaf_node.delta <= new_delta{
                    return;
                }
                // update the delta and propagate the change up the tree.
                let delta_diff = leaf_node.delta - new_delta;
                let contribution_diff = delta_diff*leaf_node.weight.0;
                leaf_node.delta = new_delta;

                let mut maybe_parent_pointer = leaf_node.parent;

                while let Some(parent_pointer) = maybe_parent_pointer{
                    let parent = unsafe{parent_pointer.as_ptr().as_mut().unwrap()};
                    match parent{
                        TreeNode::Internal(ref mut internal_parent) => {
                            internal_parent.contribution -= contribution_diff;
                            maybe_parent_pointer = internal_parent.parent;
                        },
                        TreeNode::Leaf(_) => unreachable!("Leaf nodes don't have children")
                    }
                }
            }
        }
    }
}


// Tests
#[cfg(test)]
mod tests{
    use super::*;
    use crate::fixed::common::Datapoint;
    use rand::SeedableRng;
    use std::collections::HashSet;

    /// Reference total contribution: sum over all non-deleted data points.
    #[allow(dead_code)]
    fn compute_actual_total_contribution(data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float) -> Float{
        (0..data_points.len()).map(|i|{
            match deleted_indices.contains(&i.into()){
                true => 0.0,
                false => data_points[i].contribution(smallest_coreset_self_affinity)
            }
        }).sum::<Float>()
    }

    /// Reference total smoothed contribution: sum over all non-deleted data points.
    #[allow(dead_code)]
    fn compute_actual_total_smoothed_contribution(data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float) -> Float{
        (0..data_points.len()).map(|i|{
            match deleted_indices.contains(&i.into()){
                true => 0.0,
                false => data_points[i].smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight.into())
            }
        }).sum::<Float>()
    }

    /// Asserts the tree's root contribution matches the reference sum.
    ///
    /// The tree sums pairwise while the reference sums sequentially, so results can differ
    /// by rounding. The comparison therefore uses a relative tolerance, not exact equality.
    #[allow(dead_code)]
    fn test_total_contribution(tree: &SamplingTree, data_points: &[Datapoint], deleted_indices: &HashSet<Index>, smallest_coreset_self_affinity: Float){
        let total_contribution_actual = compute_actual_total_contribution(data_points, deleted_indices, smallest_coreset_self_affinity);
        let total_contribution_expected = tree.contribution();
        println!("Actual vs tree total contribution: {:?} vs {:?}", total_contribution_actual, total_contribution_expected);
        let tolerance = 1e-6 * total_contribution_actual.abs().max(1.0);
        assert!((total_contribution_actual - total_contribution_expected).abs() <= tolerance);
    }

    /// Asserts each leaf's tree-computed probability matches `contribution / total`.
    #[allow(dead_code)]
    fn test_sampling_probabilities(tree: &SamplingTree, data_points: &[Datapoint],deleted_indices: &HashSet<Index>,
        smallest_coreset_self_affinity: Float, pointers: &[Option<NonNull<TreeNode>>]){
        let total_contribution_of_indices = compute_actual_total_contribution(data_points, deleted_indices, smallest_coreset_self_affinity);

        let target_probs = (0..data_points.len()).map(|i|{
            match deleted_indices.contains(&i.into()){
                true => 0.0,
                false => data_points[i].contribution(smallest_coreset_self_affinity) / total_contribution_of_indices
            }
        }).collect::<Vec<Float>>();

        let tree_probs = (0..data_points.len()).map(|i|{
            match pointers[i]{
                None => 0.0,
                Some(pointer) => tree.compute_sampling_probability(pointer.as_ptr())
            }
        }).collect::<Vec<Float>>();

        println!("Target probs: {:?}", target_probs);
        println!("Tree probs: {:?}", tree_probs);

        for (target_prob, tree_prob) in target_probs.iter().zip(tree_probs.iter()){
            assert!((target_prob - tree_prob).abs() < 1e-6);
        }
    }

    /// Asserts each leaf's tree-computed smoothed probability matches the reference.
    #[allow(dead_code)]
    fn test_smoothed_sampling_probabilities(tree: &SamplingTree, data_points: &[Datapoint],deleted_indices: &HashSet<Index> ,smallest_coreset_self_affinity: Float, cost: Float,
         coreset_star_weight: Float, pointers: &[Option<NonNull<TreeNode>>]){
        let total_contribution_of_indices = compute_actual_total_smoothed_contribution(data_points, deleted_indices, smallest_coreset_self_affinity, cost, coreset_star_weight);

        let target_probs = (0..data_points.len()).map(|i|{
            match deleted_indices.contains(&i.into()){
                true => 0.0,
                false => data_points[i].smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight.into()) / total_contribution_of_indices
            }
        }).collect::<Vec<Float>>();

        let tree_probs = (0..data_points.len()).map(|i|{
            match pointers[i]{
                None => 0.0,
                Some(pointer) => tree.compute_smoothed_sampling_probability(pointer.as_ptr(), cost, coreset_star_weight)
            }
        }).collect::<Vec<Float>>();

        println!("Target probs: {:?}", target_probs);
        println!("Tree probs: {:?}", tree_probs);

        for (target_prob, tree_prob) in target_probs.iter().zip(tree_probs.iter()){
            assert!((target_prob - tree_prob).abs() < 1e-6);
        }
    }

    /// Eight data points with weight == self_affinity == 1..=8.
    fn sample_data_points() -> Vec<Datapoint>{
        vec![
            Datapoint{weight: Weight(1.0), self_affinity: SelfAffinity(1.0)},
            Datapoint{weight: Weight(2.0), self_affinity: SelfAffinity(2.0)},
            Datapoint{weight: Weight(3.0), self_affinity: SelfAffinity(3.0)},
            Datapoint{weight: Weight(4.0), self_affinity: SelfAffinity(4.0)},
            Datapoint{weight: Weight(5.0), self_affinity: SelfAffinity(5.0)},
            Datapoint{weight: Weight(6.0), self_affinity: SelfAffinity(6.0)},
            Datapoint{weight: Weight(7.0), self_affinity: SelfAffinity(7.0)},
            Datapoint{weight: Weight(8.0), self_affinity: SelfAffinity(8.0)},
        ]
    }

    #[test]
    fn test_non_incident_tree_sample(){
        let data_points = sample_data_points();

        let mut tree = SamplingTree::new();
        let smallest_coreset_self_affinity = 1.0;

        let pointers = tree.insert_from_iterator(
            data_points.iter().enumerate().map(|(i,datapoint)|{
                (i.into(), datapoint.weight, datapoint.self_affinity)
            }),
            data_points.len(),
            smallest_coreset_self_affinity
        ).into_iter().map(Some).collect::<Vec<Option<NonNull<TreeNode>>>>();

        // test the probabilities
        test_sampling_probabilities(&tree, &data_points, &HashSet::new(), smallest_coreset_self_affinity, &pointers);
        test_smoothed_sampling_probabilities(&tree, &data_points, &HashSet::new(), smallest_coreset_self_affinity, 10.0, 5.0, &pointers);
    }

    /// The path probability reported by top-down sampling must agree with the bottom-up
    /// probability computation for the same leaf, in both plain and smoothed modes.
    #[test]
    fn test_sampled_probability_matches_computed_probability(){
        let data_points = sample_data_points();

        let mut tree = SamplingTree::new();
        let smallest_coreset_self_affinity = 1.0;
        let leaf_pointers = tree.insert_from_iterator(
            data_points.iter().enumerate().map(|(i,datapoint)|{
                (i.into(), datapoint.weight, datapoint.self_affinity)
            }),
            data_points.len(),
            smallest_coreset_self_affinity
        );

        let mut rng = StdRng::seed_from_u64(42);
        let coreset_star_weight = 5.0;
        for _ in 0..100{
            let (leaf, prob) = tree.sample_node(&mut rng).ok().expect("sampling a non-empty tree should succeed");
            assert!(leaf_pointers.contains(&leaf));
            let expected = tree.compute_sampling_probability(leaf.as_ptr());
            assert!((prob - expected).abs() < 1e-6);

            let (leaf, prob) = tree.sample_node_smoothed(coreset_star_weight, &mut rng).ok().expect("sampling a non-empty tree should succeed");
            assert!(leaf_pointers.contains(&leaf));
            let expected = tree.compute_smoothed_sampling_probability(leaf.as_ptr(), tree.contribution(), coreset_star_weight);
            assert!((prob - expected).abs() < 1e-6);
        }
    }
}



