
use std::collections::HashSet;
use faer::{sparse::SparseColMatRef, MatRef};
use rand::{rngs::ThreadRng, Rng};





// import common data types from the sibling module
use crate::common::{Index, Weight, SelfAffinity, Datapoint, Float, Error, FxHashSet};


// MARK: -Node structs

/// Leaf of a [`NonIncidentTree`]: a single data point together with its weight and
/// self-affinity.
#[derive(Debug)]
pub struct NonIncidentLeafNode {
    /// Position of the data point in the input the tree was built from.
    index: Index,
    /// Weight of the data point.
    weight: Weight,
    /// Self-affinity of the data point (its own diagonal kernel entry).
    self_affinity: SelfAffinity,
}

impl NonIncidentLeafNode {
    /// Index of the data point stored in this leaf.
    pub fn index(&self) -> Index {
        self.index
    }

    /// Weight of the data point stored in this leaf.
    pub fn weight(&self) -> Weight {
        self.weight
    }

    /// Self-affinity of the data point stored in this leaf.
    pub fn self_affinity(&self) -> SelfAffinity {
        self.self_affinity
    }
}

/// Interior node of a `NonIncidentTree`, caching aggregates over the points below it so that
/// sampling and deletion only need to walk a single root-to-leaf path.
#[derive(Debug)]
struct InternalNode{
    /// Indices of the data points still present in `left_subtree` (deleted ones are removed).
    left_indices: FxHashSet<Index>,
    /// Indices of the data points still present in `right_subtree` (deleted ones are removed).
    right_indices: FxHashSet<Index>,
    /// Left child; set to `None` once every point in it has been deleted.
    left_subtree: Option<Box<TreeNode>>,
    /// Right child; set to `None` once every point in it has been deleted.
    right_subtree: Option<Box<TreeNode>>,
    /// Sum of the weights of all points still below this node.
    weight: Float,
    /// Sum of `weight * self_affinity` over all points still below this node.
    weighted_self_affinity: Float,
}


// MARK: -Node enum
/// A node of a `NonIncidentTree`: either a single data point or an aggregate over subtrees.
#[derive(Debug)]
enum TreeNode{
    /// A single data point.
    Leaf(NonIncidentLeafNode),
    /// An interior node caching weight / weighted self-affinity sums of its children.
    Internal(InternalNode)
}

impl TreeNode {
    /// Unsmoothed sampling mass of this subtree.
    ///
    /// For a leaf this is `weight * (self_affinity + smallest_coreset_self_affinity)`; an
    /// internal node evaluates the same sum from its cached aggregates as
    /// `weighted_self_affinity + weight * smallest_coreset_self_affinity`.
    fn contribution(&self, smallest_coreset_self_affinity: Float) -> Float {
        match self {
            TreeNode::Leaf(leaf) => {
                leaf.weight.0 * (leaf.self_affinity.0 + smallest_coreset_self_affinity)
            }
            TreeNode::Internal(internal) => {
                internal.weighted_self_affinity + internal.weight * smallest_coreset_self_affinity
            }
        }
    }

    /// Smoothed sampling mass: `contribution / cost + weight / coreset_star_weight`.
    ///
    /// The formula is identical for leaves and internal nodes, so it is expressed once in
    /// terms of [`TreeNode::contribution`] and [`TreeNode::weight`].
    fn smoothed_contribution(
        &self,
        smallest_coreset_self_affinity: Float,
        cost: Float,
        coreset_star_weight: Float,
    ) -> Float {
        let cost_share = self.contribution(smallest_coreset_self_affinity) / cost;
        let weight_share = self.weight() / coreset_star_weight;
        cost_share + weight_share
    }

    /// Sum of `weight * self_affinity` over the points in this subtree.
    fn weighted_self_affinities(&self) -> Float {
        match self {
            TreeNode::Leaf(leaf) => leaf.weight.0 * leaf.self_affinity.0,
            TreeNode::Internal(internal) => internal.weighted_self_affinity,
        }
    }

    /// Total weight of the points in this subtree.
    fn weight(&self) -> Float {
        match self {
            TreeNode::Leaf(leaf) => leaf.weight.0,
            TreeNode::Internal(internal) => internal.weight,
        }
    }
}


// MARK: -Tree struct

/// Binary tree over a set of weighted data points that supports drawing a point with
/// probability proportional to its (optionally smoothed) contribution, and deleting points.
/// Both operations walk a single root-to-leaf path using sums cached in the internal nodes.
#[derive(Debug)]
pub struct NonIncidentTree{
    /// `None` once every point has been deleted.
    root: Option<TreeNode>
}

impl NonIncidentTree{

    pub fn from_kernel_and_weights(kernel_matrix: SparseColMatRef<usize,Float>, weights: MatRef<Float>) -> Option<NonIncidentTree>{
        assert_eq!(kernel_matrix.nrows(), kernel_matrix.ncols());
        assert_eq!(kernel_matrix.nrows(), weights.nrows());
        assert_eq!(1, weights.ncols());

        match kernel_matrix.nrows(){
            0 => None,
            _ => Some(NonIncidentTree{
                root: Some(NonIncidentTree::_from_kernel_and_weights(0, kernel_matrix.nrows(), kernel_matrix, weights))
            })
        }

    }

    fn _from_kernel_and_weights(start_index: usize, end_index: usize, kernel_matrix: SparseColMatRef<usize,Float>, weights: MatRef<Float>) -> TreeNode{
        match end_index - start_index{
            0..=1 => {
                let index: Index = start_index.into();
                let weight = Weight(*weights.get(index.0, 0));
                let self_affinity = SelfAffinity(*kernel_matrix.get(index.0, index.0)
                .expect(&format!("Make sure self affinites are present down the diagonal of K. No self-affinity found for datapoint {:?}", index.0)));
                TreeNode::Leaf(NonIncidentLeafNode{index, weight, self_affinity})
            },
            _ => {
                // split the data points into two sets
                let midpoint = (start_index + end_index) / 2;
                let left_start_index = start_index;
                let left_end_index = midpoint;
                let right_start_index = midpoint;
                let right_end_index = end_index;
                // construct the left and right subtrees
                let left_set: FxHashSet<Index> = (left_start_index..left_end_index).map(|i| i.into()).collect();
                let left_subtree =  if left_set.len() >0 {Some(Box::new(NonIncidentTree::_from_kernel_and_weights(left_start_index, left_end_index, kernel_matrix, weights)) )} else {None};
                let right_set: FxHashSet<Index> = (right_start_index..right_end_index).map(|i| i.into()).collect();
                let right_subtree = if right_set.len()>0 {Some(Box::new(NonIncidentTree::_from_kernel_and_weights(right_start_index, right_end_index, kernel_matrix, weights)) )} else {None};
                // calculate the contribution of the internal node

                let weighted_self_affinity = (left_subtree.as_ref()).map_or(0.0, |l|l.weighted_self_affinities()) + right_subtree.as_ref().map_or(0.0, |r|r.weighted_self_affinities());
                let weight = left_subtree.as_ref().map_or(0.0, |l|l.weight()) + right_subtree.as_ref().map_or(0.0, |r|r.weight());

                TreeNode::Internal(
                    InternalNode { 
                        left_indices: left_set, 
                        right_indices: right_set, 
                        left_subtree: left_subtree, 
                        right_subtree: right_subtree, 
                        weighted_self_affinity: weighted_self_affinity,
                        weight: weight
                    }
                )
            }
        }
    }

    #[allow(dead_code)]
    pub fn from_data_points(data_points: &[Datapoint]) -> Option<NonIncidentTree>{
        match data_points.len(){
            0 => None,
            _ => Some(NonIncidentTree{
                root: Some(NonIncidentTree::_from_data_points(0, data_points.len(), data_points)),
            })
        }
    }

    pub fn contribution(&self, smallest_coreset_self_affinity: Float) -> Float{
        match self.root.as_ref(){
            None => 0.0,
            Some(root) => root.contribution(smallest_coreset_self_affinity)
        }
    }

    pub fn smoothed_contribution(&self, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float) -> Float{
        match self.root.as_ref(){
            None => 0.0,
            Some(root) => root.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight)
        }
    }

    fn _from_data_points(start_index: usize, end_index: usize, data_points:&[Datapoint]) -> TreeNode{
        match end_index - start_index{
            0..=1 => TreeNode::Leaf(NonIncidentLeafNode { index: start_index.into(), weight: data_points[start_index].weight, self_affinity: data_points[start_index].self_affinity }),
            _ => {
                // split the data points into two sets
                let midpoint = (start_index + end_index) / 2;
                let left_start_index = start_index;
                let left_end_index = midpoint;
                let right_start_index = midpoint;
                let right_end_index = end_index;
                // construct the left and right subtrees
                let left_set: FxHashSet<Index> = (left_start_index..left_end_index).map(|i| i.into()).collect();
                let left_subtree =  if left_set.len() >0 {Some(Box::new(NonIncidentTree::_from_data_points(left_start_index, left_end_index, data_points)))}  else {None};
                let right_set: FxHashSet<Index> = (right_start_index..right_end_index).map(|i| i.into()).collect();
                let right_subtree = if right_set.len()>0 {Some(Box::new(NonIncidentTree::_from_data_points(right_start_index, right_end_index, data_points)))} else {None};
                // calculate the contribution of the internal node

                let weighted_self_affinity = (left_subtree.as_ref()).map_or(0.0, |l|l.weighted_self_affinities()) + right_subtree.as_ref().map_or(0.0, |r|r.weighted_self_affinities());
                let weight = left_subtree.as_ref().map_or(0.0, |l|l.weight()) + right_subtree.as_ref().map_or(0.0, |r|r.weight());

                TreeNode::Internal(
                    InternalNode { 
                        left_indices: left_set, 
                        right_indices: right_set, 
                        left_subtree: left_subtree, 
                        right_subtree: right_subtree, 
                        weighted_self_affinity: weighted_self_affinity,
                        weight: weight
                    }
                )
            }
        }
    }

    pub fn delete_node(&mut self, index: Index) -> Result<(Weight,SelfAffinity),Error>{
        // given the index of a datapoint, delete its corresponding leaf node and update the path to the root
        // to reflect the deletion
        match self.root.as_mut() {
            None => return Err(Error::EmptyTree),
            Some( root) =>{
                let (weight, self_affinity, empty_tree) = NonIncidentTree::_delete_node(  root, index)?;
                if empty_tree{
                    self.root = None;
                }
                Ok((weight, self_affinity))
            }
        }

    }

    fn _delete_node(tree_node: &mut TreeNode, target_index: Index) -> Result<(Weight,SelfAffinity,bool),Error>{
        // return the contribution of the deleted node and a boolean indicating whether the current node is now a leaf node (and can itself be deleted)
        match tree_node{
            TreeNode::Leaf(NonIncidentLeafNode{index,weight,self_affinity}) =>{
                if target_index == *index{
                    Ok((*weight, (self_affinity.0 * weight.0).into(), true))
                }else{
                    Err(Error::NodeNotFound(target_index))
                }
            },
            TreeNode::Internal(InternalNode{
                left_indices,
                right_indices,
                left_subtree,
                right_subtree,
                weight,
                weighted_self_affinity}) =>{
                    match (left_indices.contains(&target_index),right_indices.contains(&target_index)){
                        (false,false) => Err(Error::NodeNotFound(target_index)),
                        (true,false) => {
                            // delete the node from the left subtree and update our current node.
                            // Depending on whether the left and right subtrees are now leaf nodes, we may need to tell the parent node to delete the current node.
                            // Regardless, we need to return the weight and self-affinity of the deleted node to the parent node.
                            let (deleted_weight, deleted_self_affinity, left_subtree_is_leaf) = NonIncidentTree::_delete_node(&mut *left_subtree.as_mut().unwrap(), target_index)?;
                            match (right_subtree.is_none(), left_subtree_is_leaf){
                                (true,true) =>{
                                    // both subtrees are now empty, so we can signal the parent node to delete the current node
                                    return Ok((deleted_weight, deleted_self_affinity, true))
                                },
                                (_,false) =>{
                                    // the left subtree is still an internal node, so we need to update the current node then pass the update up the tree
                                    *weight -= deleted_weight.0;
                                    *weighted_self_affinity -= deleted_self_affinity.0;
                                    left_indices.remove(&target_index);
                                    return Ok((deleted_weight, deleted_self_affinity, false))
                                },
                                (false, true) =>{
                                    // the right subtree is now the only subtree, so we delete the left subtree
                                    *weight -= deleted_weight.0;
                                    *weighted_self_affinity -= deleted_self_affinity.0;
                                    *left_subtree = None;
                                    left_indices.remove(&target_index);
                                    return Ok((deleted_weight, deleted_self_affinity, false))
                                }
                            }
                        },
                        (false,true) => {
                            // delete the node from the right subtree and update our current node.
                            let (deleted_weight, deleted_self_affinity, right_subtree_is_leaf) = NonIncidentTree::_delete_node(&mut *right_subtree.as_mut().unwrap(), target_index)?;
                            match (left_subtree.is_none(), right_subtree_is_leaf){
                                (true,true) =>{
                                    // both subtrees are now leaf nodes, so we can signal the parent node to delete the current node
                                    return Ok((deleted_weight, deleted_self_affinity, true))
                                },
                                (_,false) =>{
                                    // the right subtree is still an internal node, so we need to update the current node then pass the update up the tree
                                    *weight -= deleted_weight.0;
                                    *weighted_self_affinity -= deleted_self_affinity.0;
                                    right_indices.remove(&target_index);
                                    return Ok((deleted_weight, deleted_self_affinity, false))
                                },
                                (false, true) =>{
                                    // the left subtree is now the only subtree, so we delete the right subtree
                                    *weight -= deleted_weight.0;
                                    *weighted_self_affinity -= deleted_self_affinity.0;
                                    *right_subtree = None;
                                    right_indices.remove(&target_index);
                                    return Ok((deleted_weight, deleted_self_affinity, false))
                                
                                }
                            }
                        },
                        (true,true) => {
                            Err(Error::NodeInBothSubtrees(target_index))
                        }
                    }
                }
        }
    }

    pub fn sample_node_smoothed(&self, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float, rng: &mut ThreadRng) -> Result<(&NonIncidentLeafNode,Float),Error>{

        let mut node_ref = self.root.as_ref().ok_or(Error::EmptyTree)?;
        let mut prob = 1.0;
        while let TreeNode::Internal(node) = node_ref{
            let lhs_contribution = node.left_subtree.as_ref().map_or(0.0, |l|l.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight));
            let rhs_contribution = node.right_subtree.as_ref().map_or(0.0, |r|r.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight));
            let total_contribution = lhs_contribution + rhs_contribution;

            // go down the left subtree with probability proportional to the contribution of the left subtree
            let sample = rng.gen::<Float>() * total_contribution;
            if sample <= lhs_contribution{
                prob *= lhs_contribution / total_contribution;
                node_ref = node.left_subtree.as_ref().unwrap();
            }else{
                prob *= rhs_contribution / total_contribution;
                node_ref = node.right_subtree.as_ref().unwrap();
            }
        }
        // we have arrived at a leaf node. Return the index of the leaf node
        let leaf_node = match node_ref{
            TreeNode::Leaf(leaf_node) => leaf_node,
            _ => unreachable!()
        };
        Ok((&leaf_node,prob))
    }

    pub fn sample_node(&self, smallest_coreset_self_affinity: Float, rng: &mut ThreadRng) -> Result<&NonIncidentLeafNode,Error>{
        // sample a node from the tree proportional to its contribution
        let mut node_ref = self.root.as_ref().ok_or(Error::EmptyTree)?;

        while let TreeNode::Internal(node) = node_ref{
            let lhs_contribution = node.left_subtree.as_ref().map_or(0.0, |l|l.contribution(smallest_coreset_self_affinity));
            let rhs_contribution = node.right_subtree.as_ref().map_or(0.0, |r|r.contribution(smallest_coreset_self_affinity));
            let total_contribution = lhs_contribution + rhs_contribution;

            // go down the left subtree with probability proportional to the contribution of the left subtree
            let sample = rng.gen::<Float>() * total_contribution;
            if sample <= lhs_contribution{
                node_ref = node.left_subtree.as_ref().unwrap();
            }else{
                node_ref = node.right_subtree.as_ref().unwrap();
            }
        }
        // we have arrived at a leaf node. Return the index of the leaf node
        let leaf_node = match node_ref{
            TreeNode::Leaf(leaf_node) => leaf_node,
            _ => unreachable!()
        };
        Ok(&leaf_node)
    }

    #[allow(dead_code)]
    pub fn compute_sampling_probability(&self, index: Index, smallest_coreset_self_affinity: Float) -> Result<Float,Error>{
        // given that we sample a node from the tree,
        // compute the conditional probability of sampling the index-th node
        let mut node_ref = match self.root.as_ref(){
            None => return Ok(0.0),
            Some(root) => root
        };
        let mut probability = 1.0;

        while let TreeNode::Internal(node) = node_ref{
            let (target_subtree, other_subtree) = {
                if node.left_indices.contains(&index){
                    (&node.left_subtree,&node.right_subtree)
                }else if node.right_indices.contains(&index){
                    (&node.right_subtree, &node.left_subtree)
                }else{
                    return Ok(0.0);
                }
            };
            let target_contribution = target_subtree.as_ref().map_or(0.0, |t|t.contribution(smallest_coreset_self_affinity));
            let other_contribution = other_subtree.as_ref().map_or(0.0, |o|o.contribution(smallest_coreset_self_affinity));
            let total_contribution = target_contribution + other_contribution;
            probability *= target_contribution / total_contribution;
            node_ref = target_subtree.as_ref().unwrap();
            }
        Ok(probability)
    }

    #[allow(dead_code)]
    pub fn compute_smoothed_sampling_probability(&self, index: Index, smallest_coreset_self_affinity: Float, cost: Float, coreset_star_weight: Float) -> Result<Float,Error>{
        // given that we sample a node from the tree,
        // compute the conditional probability of sampling the index-th node
        let mut node_ref = match self.root.as_ref(){
            None => return Ok(0.0),
            Some(root) => root
        };
        let mut probability = 1.0;

        while let TreeNode::Internal(node) = node_ref{
            let (target_subtree, other_subtree) = {
                if node.left_indices.contains(&index){
                    (&node.left_subtree,&node.right_subtree)
                }else if node.right_indices.contains(&index){
                    (&node.right_subtree, &node.left_subtree)
                }else{
                    return Ok(0.0);
                }
            };
            let target_contribution = target_subtree.as_ref().map_or(0.0, |t|t.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight));
            let other_contribution = other_subtree.as_ref().map_or(0.0, |o|o.smoothed_contribution(smallest_coreset_self_affinity, cost, coreset_star_weight));
            let total_contribution = target_contribution + other_contribution;
            probability *= target_contribution / total_contribution;
            node_ref = target_subtree.as_ref().unwrap();
            }
        Ok(probability)
    }


}



// MARK: -Tests


#[cfg(test)]
mod tests {
    use super::*;

    /// Relative tolerance for comparing sums accumulated in different orders (the tree
    /// subtracts deleted mass, the brute force re-adds the survivors).
    const TOLERANCE: Float = 1e-9;

    /// Brute-force total contribution of every data point not listed in `deleted_indices`.
    fn compute_actual_total_contribution(
        data_points: &[Datapoint],
        deleted_indices: &HashSet<Index>,
        smallest_coreset_self_affinity: Float,
    ) -> Float {
        (0..data_points.len())
            .filter(|&i| !deleted_indices.contains(&i.into()))
            .map(|i| data_points[i].contribution(smallest_coreset_self_affinity))
            .sum::<Float>()
    }

    /// Asserts `expected` and `actual` agree up to `TOLERANCE`, relative to their magnitude.
    fn assert_close(expected: Float, actual: Float, what: &str) {
        let scale = expected.abs().max(actual.abs()).max(1.0);
        assert!(
            (expected - actual).abs() <= TOLERANCE * scale,
            "{}: expected {:?}, got {:?}",
            what,
            expected,
            actual
        );
    }

    /// Checks the tree's cached total contribution against a brute-force recomputation.
    fn test_total_contribution(
        tree: &NonIncidentTree,
        data_points: &[Datapoint],
        deleted_indices: &HashSet<Index>,
        smallest_coreset_self_affinity: Float,
    ) {
        let brute_force_total = compute_actual_total_contribution(data_points, deleted_indices, smallest_coreset_self_affinity);
        let tree_total = tree.contribution(smallest_coreset_self_affinity);
        println!("Brute-force vs tree total contribution: {:?} vs {:?}", brute_force_total, tree_total);
        assert_close(brute_force_total, tree_total, "total contribution");
    }

    /// Checks every point's tree sampling probability against `contribution / total`
    /// (and `0` for deleted points).
    fn test_sampling_probabilities(
        tree: &NonIncidentTree,
        data_points: &[Datapoint],
        deleted_indices: &HashSet<Index>,
        smallest_coreset_self_affinity: Float,
    ) {
        let total_contribution = compute_actual_total_contribution(data_points, deleted_indices, smallest_coreset_self_affinity);
        for (i, point) in data_points.iter().enumerate() {
            let index: Index = i.into();
            let expected = if deleted_indices.contains(&index) {
                0.0
            } else {
                point.contribution(smallest_coreset_self_affinity) / total_contribution
            };
            let actual = tree
                .compute_sampling_probability(index, smallest_coreset_self_affinity)
                .unwrap();
            assert!(
                (expected - actual).abs() < 1e-6,
                "sampling probability of point {}: expected {:?}, got {:?}",
                i,
                expected,
                actual
            );
        }
    }

    #[test]
    fn test_non_incident_tree_sample() {
        let data_points = [
            Datapoint { weight: Weight(1.0), self_affinity: SelfAffinity(1.0) },
            Datapoint { weight: Weight(2.0), self_affinity: SelfAffinity(2.0) },
            Datapoint { weight: Weight(3.0), self_affinity: SelfAffinity(3.0) },
            Datapoint { weight: Weight(4.0), self_affinity: SelfAffinity(4.0) },
            Datapoint { weight: Weight(5.0), self_affinity: SelfAffinity(5.0) },
        ];

        let mut tree = NonIncidentTree::from_data_points(&data_points).unwrap();
        let smallest_coreset_self_affinity = 2.0;
        let mut rng = rand::thread_rng();
        let mut deleted_indices = HashSet::new();
        let delete_order: [usize; 5] = [0, 4, 2, 1, 3];
        for i in delete_order {
            println!("{}", "#".repeat(40));
            let index: Index = i.into();
            tree.delete_node(index).unwrap();
            deleted_indices.insert(index);
            test_total_contribution(&tree, &data_points, &deleted_indices, smallest_coreset_self_affinity);
            test_sampling_probabilities(&tree, &data_points, &deleted_indices, smallest_coreset_self_affinity);

            // A sampled point must never be one that was already deleted; once every point is
            // gone, sampling must fail instead.
            match tree.sample_node(smallest_coreset_self_affinity, &mut rng) {
                Ok(leaf) => assert!(!deleted_indices.contains(&leaf.index())),
                Err(_) => assert_eq!(deleted_indices.len(), data_points.len()),
            }
        }
    }
}