use crate::faster::DynamicCorest;
use crate::faster::common::*;
use crate::faster::TreeData;

use std::ops::{AddAssign};
use rustc_hash::{FxHashMap};


impl<const ARITY: usize> DynamicCorest<ARITY>
where
    ConstPow2<ARITY>: PowerOfTwo,
{
    // MARK: Tree Methods
    //
    // Layout (as implemented by the index helpers below): the tree is stored implicitly in the
    // parallel vecs of `TreeData`, with the root at slot 0 and the children of slot `i` at
    // `(i << ARITY_SHIFT) + 1 ..= (i << ARITY_SHIFT) + ARITY`. Graph nodes occupy leaf slots
    // (size 1, tracked by `node_location_map` / `node_location_map_reverse`); every internal
    // slot stores the summed sizes, volumes and inverse volumes of the leaves below it.
    //
    // NOTE(review): the index math mixes `<< ARITY_SHIFT` and `% ARITY`, so it assumes
    // `1 << ARITY_SHIFT == ARITY` — presumably guaranteed by the `PowerOfTwo` bound; confirm
    // where `ARITY_SHIFT` is defined.

    /// Returns the slot of the `child_index`-th (0-based) child of `parent_index`, or `None`
    /// if that slot lies past the end of the tree.
    #[inline(always)]
    pub fn child_index(
        &self,
        parent_index: ShiftedIndex,
        child_index: usize,
    ) -> Option<ShiftedIndex> {
        // +1 because the first slot is the root.
        let child_shifted_index = (parent_index.0 << Self::ARITY_SHIFT) + child_index + 1;
        (child_shifted_index < self.num_nodes_in_tree()).then(|| ShiftedIndex(child_shifted_index))
    }

    /// Returns the slot of `child_index`'s parent, or `None` for the root (slot 0).
    #[inline(always)]
    fn parent_index(child_index: ShiftedIndex) -> Option<ShiftedIndex> {
        // `checked_sub` yields `None` exactly for the root, which has no parent.
        child_index
            .0
            .checked_sub(1)
            .map(|offset| ShiftedIndex(offset >> Self::ARITY_SHIFT))
    }

    /// Returns which child (0-based) `node_index` is of its parent.
    ///
    /// Must not be called on the root (slot 0), which is nobody's child.
    #[inline(always)]
    fn which_child(&self, node_index: ShiftedIndex) -> usize {
        debug_assert!(
            node_index.0 != 0,
            "The root node is not the child of any node."
        );
        (node_index.0 - 1) % ARITY
    }

    /// Returns the number of slots (internal and leaf) currently in the tree.
    ///
    /// In debug builds this also checks that all per-slot vecs of the tree data have the same
    /// length.
    #[inline(always)]
    pub fn num_nodes_in_tree(&self) -> usize {
        let data = &self.tree_data;
        let n = data.sizes.len();
        debug_assert!(
            n == data.volumes.len()
                && n == data.inv_volumes.len()
                && n == data.deltas.len()
                && n == data.smoothing_term_deltas.len()
                && n == data.generational_flags.len(),
            "Inconsistent tree node counts: \
             sizes: {}, volumes: {}, inv_volumes: {}, deltas: {}, smoothing_term_deltas: {}, generational_flags: {}",
            n,
            data.volumes.len(),
            data.inv_volumes.len(),
            data.deltas.len(),
            data.smoothing_term_deltas.len(),
            data.generational_flags.len()
        );
        n
    }

    /// Moves the graph node stored in leaf slot `idx` into a new slot appended at the end of
    /// the tree, so that `idx` can act as an internal node.
    ///
    /// Intended to be called when the next free slot is `idx`'s first child (as
    /// `insert_node_into_tree` does). The aggregates stored at `idx` are left as they are:
    /// its subtree still contains exactly the same single leaf, so nothing needs to be
    /// propagated upwards.
    ///
    /// # Panics
    /// If no node is recorded at `idx` in `node_location_map_reverse`.
    #[inline(always)]
    pub fn promote_leaf_to_internal(&mut self, idx: ShiftedIndex) {
        debug_assert!(
            self.child_index(idx, 0).is_none() &&
            self.tree_data.sizes[idx.0] == 1,
            "Can only promote a leaf node that is about to become an internal node."
        );

        let new_index = ShiftedIndex(self.num_nodes_in_tree());

        // Re-point both maps at the new slot; `insert` overwrites the node's old forward entry.
        let node_id = self
            .node_location_map_reverse
            .remove(&idx)
            .expect("a leaf slot being promoted must hold a node");
        self.node_location_map.insert(node_id, new_index);
        self.node_location_map_reverse.insert(new_index, node_id);

        // Copy the leaf's data into the new slot.
        let volume = self.tree_data.volumes[idx.0];
        let inv_volume = self.tree_data.inv_volumes[idx.0];
        self.tree_data.sizes.push(1);
        self.tree_data.volumes.push(volume);
        self.tree_data.inv_volumes.push(inv_volume);

        // The deltas can simply start at 0.0, since we're not querying right now.
        self.tree_data.deltas.push(Delta::new(0.0.into()));
        self.tree_data
            .smoothing_term_deltas
            .push(SmoothingTermDelta::new(0.0.into()));
        // Same goes for the delta generational flag.
        self.tree_data.generational_flags.push(0);
    }

    /// Inserts `node` as a new leaf in the next free slot and updates all ancestor aggregates.
    ///
    /// The node's volume is `1.0 + weight`, because an implicit self loop of weight 1.0 is
    /// assumed. Its inverse volume is `.inv()` of that value (presumably the reciprocal).
    /// Its generational flag is set to `node_generation_counter`.
    ///
    /// If the new slot would be the first child of a leaf, that leaf is first promoted to an
    /// internal node (see `promote_leaf_to_internal`).
    pub fn insert_node_into_tree(&mut self, node: NodeIdentity, weight: EdgeWeight) {
        debug_assert!(
            !self.node_location_map.contains_key(&node),
            "Node {} already exists in the tree.",
            node
        );

        // If the parent of the slot the new node will occupy is still a leaf, promote it to an
        // internal node (its node moves into that slot, pushing our slot one further).
        if let Some(pidx) = Self::parent_index(ShiftedIndex(self.num_nodes_in_tree())) {
            if self.tree_data.sizes[pidx.0] == 1 {
                self.promote_leaf_to_internal(pidx);
            }
        }

        // Insert a fresh leaf into the tree.
        let new_index = ShiftedIndex(self.num_nodes_in_tree());
        self.node_location_map.insert(node, new_index);
        self.node_location_map_reverse.insert(new_index, node);

        self.tree_data.sizes.push(1);
        self.tree_data
            .volumes
            .push(convert(EdgeWeight::new(1.0.into()) + weight));
        self.tree_data
            .inv_volumes
            .push(convert((EdgeWeight::new(1.0.into()) + weight).inv()));
        self.tree_data.deltas.push(Delta::new(0.0.into()));
        self.tree_data
            .smoothing_term_deltas
            .push(SmoothingTermDelta::new(0.0.into()));
        self.tree_data
            .generational_flags
            .push(self.node_generation_counter);

        // Then propagate the new leaf's size, volume and inv_volume up the tree.
        Self::propagate_insert_up_tree(&mut self.tree_data, new_index);
    }

    /// Adds the leaf at `child_index` to every strict ancestor's aggregates: size +1, plus its
    /// volume and inv_volume.
    ///
    /// Deltas and generational flags are not touched; only sizes, volumes and inv_volumes are
    /// updated here.
    #[inline(always)]
    pub fn propagate_insert_up_tree(tree_data: &mut TreeData<ARITY>, child_index: ShiftedIndex) {
        let child_volume_increase = tree_data.volumes[child_index.0];
        let child_inv_volume_increase = tree_data.inv_volumes[child_index.0];

        Self::propagate_up_with_closure(tree_data, child_index, |this, parent_index| {
            this.sizes[parent_index.0] += 1;
            this.volumes[parent_index.0] += child_volume_increase;
            this.inv_volumes[parent_index.0] += child_inv_volume_increase;
        });
    }

    /// Calls `closure(tree_data, ancestor)` for every strict ancestor of `child_index`, from
    /// its parent up to and including the root. Does nothing when `child_index` is the root.
    #[inline(always)]
    pub fn propagate_up_with_closure<F>(
        tree_data: &mut TreeData<ARITY>,
        child_index: ShiftedIndex,
        mut closure: F,
    ) where
        F: FnMut(&mut TreeData<ARITY>, ShiftedIndex),
    {
        let mut current_index = Self::parent_index(child_index);
        while let Some(parent_index) = current_index {
            closure(tree_data, parent_index);
            current_index = Self::parent_index(parent_index);
        }
    }

    /// Swaps the slots of two nodes already in the tree.
    ///
    /// The slots' sizes, volumes and inv_volumes are exchanged and both location maps are
    /// updated. Deltas, smoothing-term deltas and generational flags are deliberately not
    /// swapped. Ancestor aggregates are NOT adjusted here; the caller must fix them up (as
    /// `delete_node_from_tree` does).
    ///
    /// # Panics
    /// If either node is not in `node_location_map`.
    #[inline(always)]
    pub fn swap_nodes_in_tree(&mut self, node_a: NodeIdentity, node_b: NodeIdentity) {
        debug_assert!(
            self.node_location_map.contains_key(&node_a) && self.node_location_map.contains_key(&node_b),
            "One or both nodes do not exist in the tree: {} and {}",
            node_a, node_b
        );
        let index_a = self
            .node_location_map
            .get(&node_a)
            .copied()
            .expect("node_a must be in the tree");
        let index_b = self
            .node_location_map
            .get(&node_b)
            .copied()
            .expect("node_b must be in the tree");

        self.tree_data.sizes.swap(index_a.0, index_b.0);
        self.tree_data.volumes.swap(index_a.0, index_b.0);
        self.tree_data.inv_volumes.swap(index_a.0, index_b.0);

        self.node_location_map.insert(node_a, index_b);
        self.node_location_map.insert(node_b, index_a);
        self.node_location_map_reverse.insert(index_a, node_b);
        self.node_location_map_reverse.insert(index_b, node_a);
    }

    /// Removes `node_to_delete` from the tree and updates all ancestor aggregates.
    ///
    /// The node is first swapped with the node in the last slot, so the slot that is physically
    /// removed is always the last one. Aggregates are then fixed in two passes:
    /// 1. Ancestors of the deleted node's old slot receive the volume / inv_volume difference
    ///    between the two swapped leaves.
    /// 2. Ancestors of the last slot lose one unit of size plus the last node's original
    ///    volume / inv_volume.
    ///
    /// Finally, if the new last slot is a lone first child, it is collapsed into its parent.
    ///
    /// # Panics
    /// If `node_to_delete` is not in the tree.
    pub fn delete_node_from_tree(&mut self, node_to_delete: NodeIdentity) {
        debug_assert!(
            self.node_location_map.contains_key(&node_to_delete),
            "Node {} does not exist in the tree.",
            node_to_delete
        );

        let index_to_delete = self
            .node_location_map
            .get(&node_to_delete)
            .copied()
            .expect("node to delete must be in the tree");
        // Non-empty: the node above exists, so there is at least one slot.
        let last_node_index = ShiftedIndex(self.num_nodes_in_tree() - 1);

        // The last slot's values (before any swap) are what its ancestors currently contain.
        let volume_to_remove_from_ancestors = self.tree_data.volumes[last_node_index.0];
        let inv_volume_to_remove_from_ancestors = self.tree_data.inv_volumes[last_node_index.0];

        if index_to_delete != last_node_index {
            let last_node_identity = self
                .node_location_map_reverse
                .get(&last_node_index)
                .copied()
                .expect("the last slot of a non-empty tree must hold a node");
            self.swap_nodes_in_tree(node_to_delete, last_node_identity);

            // `index_to_delete` now holds the former last node. Its ancestors still count the
            // deleted node's volume, so replace it by the former last node's volume. Sizes are
            // unchanged because both slots are leaves.
            let volume_diff = self.tree_data.volumes[last_node_index.0]
                - self.tree_data.volumes[index_to_delete.0];
            let inv_volume_diff = self.tree_data.inv_volumes[last_node_index.0]
                - self.tree_data.inv_volumes[index_to_delete.0];

            Self::propagate_up_with_closure(&mut self.tree_data, index_to_delete, |this, parent_index| {
                this.volumes[parent_index.0] -= volume_diff;
                this.inv_volumes[parent_index.0] -= inv_volume_diff;
            });
        }

        // Remove the last slot's contribution from its ancestors, then drop the slot.
        Self::propagate_up_with_closure(&mut self.tree_data, last_node_index, |this, parent_index| {
            this.sizes[parent_index.0] -= 1;
            this.volumes[parent_index.0] -= volume_to_remove_from_ancestors;
            this.inv_volumes[parent_index.0] -= inv_volume_to_remove_from_ancestors;
        });
        self.pop_last_tree_slot();

        self.node_location_map.remove(&node_to_delete);
        self.node_location_map_reverse.remove(&last_node_index);

        // If the new last slot is the first child of its parent, it is that parent's only
        // child. In that case move its node up into the parent, whose aggregates already equal
        // this single leaf's values, and drop the slot.
        let num_nodes = self.num_nodes_in_tree();
        if num_nodes == 0 {
            return;
        }
        let maybe_orphan_idx = ShiftedIndex(num_nodes - 1);
        // `parent_index` is `None` for the root, which can never be an orphan.
        if let Some(parent_index) = Self::parent_index(maybe_orphan_idx) {
            if self.which_child(maybe_orphan_idx) == 0 {
                let orphan_identity = self
                    .node_location_map_reverse
                    .remove(&maybe_orphan_idx)
                    .expect("the last slot of a non-empty tree must hold a node");
                self.node_location_map.insert(orphan_identity, parent_index);
                self.node_location_map_reverse.insert(parent_index, orphan_identity);

                self.pop_last_tree_slot();
            }
        }
    }

    /// Pops the last slot from every per-slot vec of the tree data.
    ///
    /// Touches neither the location maps nor any ancestor aggregates.
    #[inline(always)]
    fn pop_last_tree_slot(&mut self) {
        let data = &mut self.tree_data;
        data.sizes.pop();
        data.volumes.pop();
        data.inv_volumes.pop();
        data.deltas.pop();
        data.smoothing_term_deltas.pop();
        data.generational_flags.pop();
    }
}