
use std::collections::HashSet;
use ordered_float::OrderedFloat;
use rand::{rngs::ThreadRng, Rng};

// import common data types from the sibling module
use crate::common::{CoresetCrossTerm, DatapointWithCoresetCrossTerm, Error, Float, Index, SelfAffinity, Weight, FxHashSet};






// MARK: -Node structs

/// A single datapoint stored at the bottom of an [`IncidentTree`].
#[derive(Debug, Clone)]
pub struct IncidentLeafNode{
    /// Caller-supplied identifier of the datapoint.
    index: Index,
    /// Weight the contribution is scaled by.
    weight: Weight,
    self_affinity: SelfAffinity,
    coreset_cross_term: CoresetCrossTerm,
}

impl IncidentLeafNode{
    /// Unsmoothed contribution: `weight * (self_affinity + coreset_cross_term)`.
    pub fn contribution(&self) -> Float{
        let affinity_total = self.self_affinity.0 + self.coreset_cross_term.0;
        self.weight.0 * affinity_total
    }

    /// Smoothed contribution: `contribution / cost + weight / coreset_star_weight`.
    pub fn smoothed_contribution(&self, cost: Float, coreset_star_weight: Float) -> Float{
        let normalised = self.contribution() / cost;
        let weight_share = self.weight.0 / coreset_star_weight;
        normalised + weight_share
    }

    /// Identifier this leaf was inserted under.
    pub fn index(&self) -> Index{
        self.index
    }

    /// Weight of the stored datapoint.
    pub fn weight(&self) -> Weight{
        self.weight
    }

    /// Self-affinity of the stored datapoint.
    pub fn self_affinity(&self) -> SelfAffinity{
        self.self_affinity
    }
}


/// Summary node of an [`IncidentTree`]: owns up to two child subtrees and caches totals over them.
#[derive(Debug)]
struct InternalNode{
    /// Indices of the leaves reachable through `left_subtree`.
    left_indices: FxHashSet<Index>,
    /// Indices of the leaves reachable through `right_subtree`.
    right_indices: FxHashSet<Index>,
    left_subtree: Option<Box<TreeNode>>,
    right_subtree: Option<Box<TreeNode>>,
    /// Cached total (unsmoothed) contribution of both subtrees.
    contribution: Float,
    /// Cached total leaf weight of both subtrees, accumulated on insertion.
    weight: Weight
}

impl InternalNode{
    /// Smoothed contribution of the whole subtree: `contribution / cost + weight / coreset_star_weight`.
    fn smoothed_contribution(&self, cost: Float, coreset_star_weight: Float) -> Float{
        let normalised = self.contribution / cost;
        let weight_share = self.weight.0 / coreset_star_weight;
        normalised + weight_share
    }
}

// MARK: -Node enum
#[derive(Debug)]
enum TreeNode{
    Leaf(IncidentLeafNode),
    Internal(InternalNode)
}

impl TreeNode{
    fn contribution(&self) -> Float{
        match self{
            TreeNode::Leaf(leaf) => leaf.contribution(),
            TreeNode::Internal(InternalNode{contribution, ..}) => *contribution
        }
    }
    fn type_str(&self) -> String{
        match self{
            TreeNode::Leaf(_) => "Leaf".to_string(),
            TreeNode::Internal(_) => "Internal".to_string()
        }
    }
    fn smoothed_contribution(&self, cost: Float, coreset_star_weight: Float) ->Float{
        match self{
            TreeNode::Leaf(leaf) => leaf.smoothed_contribution(cost, coreset_star_weight),
            TreeNode::Internal(internal) => internal.smoothed_contribution(cost, coreset_star_weight) 
        }
    }

}


// MARK: -Tree struct

/// Binary tree of weighted datapoints supporting insertion, deletion, coreset-cross-term
/// updates and sampling a leaf proportionally to its (smoothed) contribution.
///
/// Each internal node caches the contribution of its subtree so that sampling only walks a
/// single root-to-leaf path.
#[derive(Debug, Default)]
pub struct IncidentTree{
    /// `None` while the tree is empty.
    root: Option<TreeNode>
}


impl IncidentTree{
    
    pub fn contribution(&self) -> Float{
        self.root.as_ref().map_or(0.0, |r| r.contribution())
    }

    pub fn smoothed_contribution(&self, cost: Float, coreset_star_weight: Float) -> Float{
        self.root.as_ref().map_or(0.0, |r| r.smoothed_contribution(cost, coreset_star_weight))
    }

    pub fn new() -> Self{
        Self{
            root: None
        }
    }

    
    // MARK: -Insert
    pub fn insert_node(&mut self, datapoint: DatapointWithCoresetCrossTerm, index: Index) -> Result<(), Error>{
        // println!("{:?}", "#".repeat(40));
        // println!("Inserting node with index: {}", index.0);
        
        match &mut self.root{
            None => {
                let leaf = IncidentLeafNode{
                    index,
                    weight: datapoint.weight,
                    self_affinity: datapoint.self_affinity,
                    coreset_cross_term: datapoint.coreset_cross_term
                };
                self.root = Some(TreeNode::Leaf(leaf));
                Ok(())
            },
            Some(tree_node) => {
                let res = IncidentTree::_insert_node(tree_node, datapoint, index);
                match res{
                    Ok(Some(internal_node)) => {
                        // This should only happen when the root node was a leaf node and
                        // we're promoting it to an internal node.
                        assert!(tree_node.type_str() == "Leaf");
                        self.root = Some(TreeNode::Internal(internal_node));
                        Ok(())
                    },
                    Ok(None) => {
                        Ok(())
                    },
                    Err(e) => Err(e)
                }
            }
        }
    }
    fn _insert_node(tree_node: &mut TreeNode, datapoint: DatapointWithCoresetCrossTerm, index: Index) -> Result<Option<InternalNode>, Error>{
        // If the tree_node is a leaf node, then we create a new internal node with the leaf_node as the left child and the 
        // new datapoint as the right child and pass it up the tree.
        // If tree_node is an internal node, we call _insert_node on the smaller child.
        // println!("inserting node {:?} into {:?}", index, tree_node);
        match tree_node{
            TreeNode::Leaf(leaf) =>{
                // we need to create a new internal node with this node and the new datapoint as children,
                // and pass it up the tree.
                let right_leaf = IncidentLeafNode{
                    index,
                    weight: datapoint.weight,
                    self_affinity: datapoint.self_affinity,
                    coreset_cross_term: datapoint.coreset_cross_term
                };
                let left_indices: FxHashSet<Index> = vec![leaf.index].into_iter().collect();
                let right_indices: FxHashSet<Index> = vec![index].into_iter().collect();
                let (right_leaf_weight, right_leaf_contribution) = (right_leaf.weight.clone(), right_leaf.contribution());
                let contribution = leaf.contribution() + right_leaf_contribution;
                let weight = leaf.weight + right_leaf_weight;
                
                // We have to clone leaf because it's behind a mutable reference.
                let internal_node = InternalNode{
                    left_indices,
                    right_indices,
                    left_subtree: Some(Box::new(TreeNode::Leaf(leaf.clone()))),
                    right_subtree: Some(Box::new(TreeNode::Leaf(right_leaf))),
                    contribution: contribution,
                    weight: weight
                };
                Ok(Some(internal_node))
            },
            TreeNode::Internal(internal) =>{
                // we're at an internal node and need to insert the datapoint into the tree.
                let (smaller_subtree, smaller_subtree_indices) = match internal.left_indices.len() <= internal.right_indices.len(){
                    true => (internal.left_subtree.as_mut().unwrap(), &mut internal.left_indices),
                    false => (internal.right_subtree.as_mut().unwrap(), &mut internal.right_indices)
                };
                // println!("Inserting into smaller subtree: {:?}", smaller_subtree_indices);
                let res = IncidentTree::_insert_node(smaller_subtree, datapoint, index);

                if let Ok(Some(new_internal)) = res{
                    // if the smaller subtree was a leaf node, then we need to replace it with the new internal node
                    *smaller_subtree = Box::new(TreeNode::Internal(new_internal));

                }
                // update the contribution/ weight of the internal node
                internal.contribution += datapoint.contribution();
                internal.weight = internal.weight + datapoint.weight;
                // update the smaller subtree indices:
                smaller_subtree_indices.insert(index);
                
                Ok(None)            
            }
        }
    }

    pub fn update_node_coreset_cross_term(&mut self, index: Index, coreset_cross_term: CoresetCrossTerm) -> Result<(), Error>{
        // update the coreset cross term of the node with the given index if the index is present and the 
        // incoming corest cross term is smaller than the current one. Bubble the difference up the tree.
        match &mut self.root{
            None => Err(Error::NodeNotFound(index)),
            Some(tree_node) => {
                let _ = IncidentTree::_update_node_corest_cross_term(tree_node, index, coreset_cross_term)?;
                Ok(())
            }
        }
    }

    fn _update_node_corest_cross_term(tree_node: &mut TreeNode, index: Index, coreset_cross_term: CoresetCrossTerm) -> Result<Float, Error>{
        match tree_node{
            TreeNode::Leaf(leaf) => {
                if leaf.index == index{
                    // Jump through the hoops of converting to and from OrderedFloat.
                    let new_cross_term: CoresetCrossTerm = OrderedFloat::min(leaf.coreset_cross_term.0.into(), coreset_cross_term.0.into()).0.into();
                    let contribution_diff = (new_cross_term.0 - leaf.coreset_cross_term.0)*leaf.weight.0;
                    leaf.coreset_cross_term = new_cross_term;
                    Ok(contribution_diff)
                } else{
                    Err(Error::NodeNotFound(index))
                }
            },
            TreeNode::Internal(node) =>{
                let target_subtree = {
                    if node.left_indices.contains(&index){
                        node.left_subtree.as_mut()
                    } else if node.right_indices.contains(&index){
                        node.right_subtree.as_mut()
                    } else{
                        return Err(Error::NodeNotFound(index))
                    }
                };
                let contribution_diff = IncidentTree::_update_node_corest_cross_term(target_subtree.unwrap(), index, coreset_cross_term)?;
                node.contribution += contribution_diff;
                Ok(contribution_diff)
            }
        }    
    }



    pub fn delete_node(&mut self, index: Index) -> Result<Float, Error>{
        // find and delete the node with the given index, updating the tree accordingly.
        match &mut self.root{
            None => Err(Error::NodeNotFound(index)),
            Some(tree_node) => {
                let (contribution, now_empty, _) = IncidentTree::_delete_node(tree_node, index)?;
                match now_empty{
                    true => {
                        self.root = None;
                        Ok(contribution)
                    },
                    false => Ok(contribution)
                }
            }
        }
    }
    fn _delete_node_from_internal(internal: &mut InternalNode, index: Index) -> Result<(Float, bool,bool),Error>{
        match (internal.left_indices.contains(&index),internal.right_indices.contains(&index)){
            (false, false) => Err(Error::NodeNotFound(index)),
            (true, true) => Err(Error::NodeInBothSubtrees(index)),
            (_, _) =>{
                let (target_subtree,target_indices,other_subtree,target_side) = match internal.left_indices.contains(&index){
                    true => (&mut internal.left_subtree, &mut internal.left_indices, internal.right_subtree.as_mut(), false),
                    false => (&mut internal.right_subtree, &mut internal.right_indices, internal.left_subtree.as_mut(), true)
                };
                let (contribution, target_now_empty, _) = IncidentTree::_delete_node(&mut target_subtree.as_mut().unwrap(), index)?;
                match (other_subtree.is_none(), target_now_empty){
                    (true, true) => {
                        // both subtrees are now empty, so we can signal to the parent that this node should be deleted.
                        Ok((contribution, true,target_side))
                    },
                    (_, false) =>{
                        // the target subtree is still not empty so we just update the contribution and pass it up the tree.
                        internal.contribution -= contribution;
                        target_indices.remove(&index);
                        Ok((contribution, false,target_side))
                    },
                    (false, true) =>{
                        // the target subtree is now empty, but the other subtree is not.
                        // We delete the target subtree, update the contribution and indices, and pass it up the tree.
                        *target_subtree = None;
                        target_indices.remove(&index);
                        internal.contribution -= contribution;
                        Ok((contribution, false,target_side))
                    }
                }

            }
        }
    }

    fn _delete_node(treenode: &mut TreeNode, target_index: Index) -> Result<(Float, bool,bool),Error>{
        // This function returns the contribution of the deleted node and a boolean indicating to the parent node
        // if this subtree is now empty.
        match treenode{
            TreeNode::Leaf(leaf) => {
                if leaf.index == target_index{
                    Ok((leaf.contribution(), true,false))
                }else{
                    Err(Error::NodeNotFound(target_index))
                }
            },
            TreeNode::Internal(internal) => {
                match IncidentTree::_delete_node_from_internal(internal, target_index){
                    Ok((contribution, now_empty, target_side)) => Ok((contribution, now_empty,target_side)),
                    Err(e) => Err(e)
                }
            },
        }
    }

    pub fn sample_node<'a>(&'a self, rng: &mut ThreadRng) -> Result<&'a IncidentLeafNode, Error>{

        match &self.root{
            None => Err(Error::EmptyTree),
            Some(tree_node) => {
                let mut node_ref = tree_node;
                while let TreeNode::Internal(node) = node_ref{
                    let lhs_contribution = node.left_subtree.as_ref().map_or(0.0, |l| l.contribution());
                    let rhs_contribution = node.right_subtree.as_ref().map_or(0.0, |r| r.contribution());
                    let total_contribution = lhs_contribution + rhs_contribution;
                    let sample = rng.gen::<Float>() * total_contribution;
                    if sample <= lhs_contribution{
                        node_ref = node.left_subtree.as_ref().unwrap();
                    }else{
                        node_ref = node.right_subtree.as_ref().unwrap();
                    }
                }
                let leaf = match node_ref{
                    TreeNode::Leaf(leaf) => leaf,
                    TreeNode::Internal(_) => unreachable!()
                };
                Ok(&leaf)
            }
        }
    }

    pub fn sample_node_smoothed(&self,cost:Float, coreset_star_weight:Float, rng: &mut ThreadRng) -> Result<(&IncidentLeafNode, Float), Error>{
        let mut prob = 1.0;
        match &self.root{
            None => Err(Error::EmptyTree),
            Some(tree_node) => {
                let mut node_ref = tree_node;
                while let TreeNode::Internal(node) = node_ref{
                    let lhs_contribution = node.left_subtree.as_ref().map_or(0.0, |l| l.smoothed_contribution(cost,coreset_star_weight));
                    let rhs_contribution = node.right_subtree.as_ref().map_or(0.0, |r| r.smoothed_contribution(cost,coreset_star_weight));
                    let total_contribution = lhs_contribution + rhs_contribution;
                    let sample = rng.gen::<Float>() * total_contribution;
                    if sample <= lhs_contribution{
                        prob *= lhs_contribution/total_contribution;
                        node_ref = node.left_subtree.as_ref().unwrap();
                    }else{
                        prob *= rhs_contribution/total_contribution;
                        node_ref = node.right_subtree.as_ref().unwrap();
                    }
                }
                let leaf = match node_ref{
                    TreeNode::Leaf(leaf) => leaf,
                    TreeNode::Internal(_) => unreachable!()
                };
                Ok((leaf, prob))
            }
        }
    }
    #[allow(dead_code)]
    fn compute_sampling_probability(&self, index: Index) -> Result<Float, Error>{
        // given that we sample a node from the tree, compute the conditional probability of sampling the index'th node.
        match &self.root{
            None => return Ok(0.0),
            Some(tree_node) => {
                let mut node_ref = tree_node;
                let mut prob = 1.0;
                while let TreeNode::Internal(node) = node_ref{
                    let target_subtree = {
                        if node.left_indices.contains(&index){
                            node.left_subtree.as_ref()        
                        } else if node.right_indices.contains(&index){
                            node.right_subtree.as_ref()
                        } else{
                            return Ok(0.0)
                        }
                    };
                    let target_contribution = target_subtree.map_or(0.0, |t| t.contribution());
                    prob *= target_contribution/node.contribution;
                    node_ref = target_subtree.unwrap().as_ref();
                };
                match node_ref{
                    TreeNode::Leaf(leaf) => {
                        if leaf.index == index{
                            return Ok(prob)
                        }else{
                            return Ok(0.0)
                        }
                    },
                    TreeNode::Internal(_) => unreachable!()
                }

            }
        }
    }

    #[allow(dead_code)]
    fn compute_smoothed_sampling_probability(&self, index: Index, cost: Float, coreset_star_weight: Float) -> Result<Float, Error>{
        // given that we sample a node from the tree, compute the conditional probability of sampling the index'th node.

        match &self.root{
            None => return Ok(0.0),
            Some(tree_node) => {
                let mut node_ref = tree_node;
                let mut prob = 1.0;
                while let TreeNode::Internal(node) = node_ref{
                    let target_subtree = {
                        if node.left_indices.contains(&index){
                            node.left_subtree.as_ref()        
                        } else if node.right_indices.contains(&index){
                            node.right_subtree.as_ref()
                        } else{
                            return Ok(0.0)
                        }
                    };
                    let target_contribution = target_subtree.map_or(0.0, |t| t.smoothed_contribution(cost,coreset_star_weight));
                    prob *= target_contribution/node.smoothed_contribution(cost,coreset_star_weight);
                    node_ref = target_subtree.unwrap().as_ref();
                };
                match node_ref{
                    TreeNode::Leaf(leaf) => {
                        if leaf.index == index{
                            return Ok(prob)
                        }else{
                            return Ok(0.0)
                        }
                    },
                    TreeNode::Internal(_) => unreachable!()
                }
            }
        }
    }


}

/// Brute-force total contribution of the datapoints whose position is in `indices_present`;
/// used by the tests as ground truth for `IncidentTree::contribution`.
#[allow(dead_code)] // only called from the unit tests
fn compute_actual_total_contribution(data_points: &[DatapointWithCoresetCrossTerm], indices_present: &HashSet<Index>) -> Float{
    data_points
        .iter()
        .enumerate()
        .map(|(position, point)| {
            if indices_present.contains(&Index(position)) {
                point.contribution()
            } else {
                0.0
            }
        })
        .sum::<Float>()
}

/// Brute-force total smoothed contribution of the datapoints whose position is in
/// `indices_present`; used by the tests as ground truth.
#[allow(dead_code)] // only called from the unit tests
fn compute_actual_total_smoothed_contribution(data_points: &[DatapointWithCoresetCrossTerm], indices_present: &HashSet<Index>, cost: Float, coreset_star_weight: Float) -> Float{
    data_points
        .iter()
        .enumerate()
        .map(|(position, point)| {
            if indices_present.contains(&Index(position)) {
                point.smoothed_contribution(cost, coreset_star_weight.into())
            } else {
                0.0
            }
        })
        .sum::<Float>()
}

/// Asserts that the tree's cached total contribution matches a brute-force recomputation.
#[allow(dead_code)] // only called from the unit tests
fn test_total_contribution(tree: &IncidentTree, data_points: &[DatapointWithCoresetCrossTerm], indices_present: &HashSet<Index>){
    let expected = compute_actual_total_contribution(data_points, indices_present);
    let cached = tree.contribution();
    println!("Actual vs tree total contribution: {} vs {}", expected, cached);
    assert!((expected - cached).abs() < 1e-6);
}

/// Asserts that the tree's per-leaf sampling probabilities match `contribution / total`
/// computed by brute force (0 for absent indices).
#[allow(dead_code)] // only called from the unit tests
fn test_sampling_probabilities(tree: &IncidentTree, data_points: &[DatapointWithCoresetCrossTerm], indices_present: &HashSet<Index>){
    let total_contribution = compute_actual_total_contribution(data_points, indices_present);

    let mut target_probs: Vec<Float> = Vec::with_capacity(data_points.len());
    for (i, point) in data_points.iter().enumerate(){
        let idx: Index = i.into();
        let target = if indices_present.contains(&idx){
            point.contribution()/total_contribution
        } else {
            0.0
        };
        target_probs.push(target);
    }

    let mut tree_probs: Vec<Float> = Vec::with_capacity(data_points.len());
    for i in 0..data_points.len(){
        tree_probs.push(tree.compute_sampling_probability(i.into()).unwrap());
    }

    println!("Target probs: {:?}", target_probs);
    println!("Tree probs: {:?}", tree_probs);

    target_probs
        .iter()
        .zip(&tree_probs)
        .for_each(|(target_prob, tree_prob)| assert!((target_prob - tree_prob).abs() < 1e-6));
}

/// Asserts that the tree's per-leaf smoothed sampling probabilities match
/// `smoothed_contribution / total` computed by brute force (0 for absent indices).
#[allow(dead_code)] // only called from the unit tests
fn test_smoothed_sampling_probabilities(tree: &IncidentTree, data_points: &[DatapointWithCoresetCrossTerm], indices_present: &HashSet<Index>, cost: Float, coreset_star_weight: Float){
    let total_contribution = compute_actual_total_smoothed_contribution(data_points, indices_present, cost, coreset_star_weight);

    let mut target_probs: Vec<Float> = Vec::with_capacity(data_points.len());
    for (i, point) in data_points.iter().enumerate(){
        let idx: Index = i.into();
        let target = if indices_present.contains(&idx){
            point.smoothed_contribution(cost, coreset_star_weight.into())/total_contribution
        } else {
            0.0
        };
        target_probs.push(target);
    }

    let mut tree_probs: Vec<Float> = Vec::with_capacity(data_points.len());
    for i in 0..data_points.len(){
        tree_probs.push(tree.compute_smoothed_sampling_probability(i.into(), cost, coreset_star_weight).unwrap());
    }

    println!("Target probs: {:?}", target_probs);
    println!("Tree probs: {:?}", tree_probs);

    target_probs
        .iter()
        .zip(&tree_probs)
        .for_each(|(target_prob, tree_prob)| assert!((target_prob - tree_prob).abs() < 1e-6));
}


#[cfg(test)]
mod tests{
    use super::*;

    /// Exercises insertion, cross-term updates and deletion, checking after every step that the
    /// tree's totals and sampling probabilities agree with a brute-force recomputation.
    #[test]
    fn test_incident_tree_sample(){
        // Datapoint k (1-based) has weight, self-affinity and cross term all equal to k.
        let mut data_points: Vec<DatapointWithCoresetCrossTerm> = (1..=5)
            .map(|k| {
                let value = k as Float;
                DatapointWithCoresetCrossTerm{
                    weight: Weight(value),
                    self_affinity: SelfAffinity(value),
                    coreset_cross_term: CoresetCrossTerm(value),
                }
            })
            .collect();

        let mut tree = IncidentTree::new();
        let order_to_add: [usize; 5] = [0, 1, 2, 3, 4];
        let order_to_delete: [usize; 5] = [2, 3, 1, 0, 4];

        let mut indices_present = HashSet::new();

        println!("testing insertion");
        for &i in &order_to_add{
            tree.insert_node(data_points[i], i.into()).unwrap();
            println!(" gives {:?}", tree);
            indices_present.insert(i.into());
            test_total_contribution(&tree, &data_points, &indices_present);
            test_sampling_probabilities(&tree, &data_points, &indices_present);
            test_smoothed_sampling_probabilities(&tree, &data_points, &indices_present, 1.0, 15.0);
        }

        println!("testing updates");
        // Lower each cross term to 1/(i+1) and mirror the change in the brute-force data.
        for &i in &order_to_add{
            let new_cross_term = CoresetCrossTerm(1.0/((i+1) as Float));
            data_points[i].coreset_cross_term = new_cross_term;
            tree.update_node_coreset_cross_term(i.into(), new_cross_term).unwrap();
            test_total_contribution(&tree, &data_points, &indices_present);
            test_sampling_probabilities(&tree, &data_points, &indices_present);
            test_smoothed_sampling_probabilities(&tree, &data_points, &indices_present, 1.0, 15.0);
        }

        println!("testing deletion");
        for &i in &order_to_delete{
            println!("{:?}", "#".repeat(40));
            tree.delete_node(i.into()).unwrap();
            indices_present.remove(&i.into());
            test_total_contribution(&tree, &data_points, &indices_present);
            test_sampling_probabilities(&tree, &data_points, &indices_present);
        }
    }
}
