use std::fmt::{Formatter, Debug};
use crate::static_coreset::common::*;
use std::mem::MaybeUninit;


/// A complete binary tree laid out in a flat, borrowed vector.
///
/// The root lives at index 0, node `i` has its children at `2 * i + 1` and
/// `2 * i + 2`, and the leaves occupy the tail of the vector.
// NOTE(review): the reason for `allow(dead_code)` is not visible here — confirm it is still needed.
#[allow(dead_code)]
pub struct SamplingTree<'a, T> {
    /// Flat node storage (root first, leaves last). It is borrowed rather than owned,
    /// presumably so the caller's allocation can be reused across rebuilds.
    pub storage: &'a mut Vec<T>,
}


impl<T> Debug for SamplingTree<'_, T>
where
    T: Debug,
{
    /// Renders the tree as `SamplingTree { root: .. }`.
    ///
    /// Only the node at storage index 0 is shown; when the storage is empty the
    /// `root` field is rendered as `None`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let missing_root = Option::<()>::None;
        let shown: &dyn Debug = match self.storage.first() {
            Some(node) => node,
            None => &missing_root,
        };
        f.debug_struct("SamplingTree").field("root", shown).finish()
    }
}

impl <'a, T> SamplingTree<'a,T>
where T: Node
{   


    pub fn new(vec: &'a mut Vec<T>) -> Self{
        SamplingTree{
            storage: vec}
    }

    
    #[inline(always)]
    pub fn get_shifted_node_index(&self, node_index: Index) -> ShiftedIndex{
        let storage_size = self.storage.len();
        let num_leaves = (storage_size + 1)/2;
        let shift = num_leaves - 1;
        let shifted_node_index = node_index.0 + shift;
        ShiftedIndex(shifted_node_index)
    }

    #[inline(always)]
    pub fn get_node_index(&self, shifted_node_index: ShiftedIndex) -> Index{
        let storage_size = self.storage.len();
        let num_leaves = (storage_size + 1)/2;
        let shift = num_leaves - 1;
        let node_index = shifted_node_index.0 - shift;
        Index(node_index)
    }

    pub fn rebuild_from_leaves(&mut self){
        let num_nodes = self.storage.len();
        let num_leaves = (num_nodes + 1)/2;
        // Leaves are stored in the last num_leaves elements of the storage vector.
        // We proceed by updating the first num_leaves -1 elements in reverse order.
        (0..num_leaves-1).rev().for_each(|i|{
            let (left_child_idx,right_child_idx) = (2*i+1,2*i+2);
            let left_child_ref = &self.storage[left_child_idx];
            let right_child_ref = &self.storage[right_child_idx];
            self.storage[i] = T::from_children(left_child_ref, right_child_ref);
        });
    }

    pub fn insert_from_iterator<I>(&mut self, mut iterator: I, min_self_affinity:SelfAffinity) ->std::ops::Range<ShiftedIndex>
    where I: Iterator<Item = (Weight,SelfAffinity)> + std::iter::ExactSizeIterator
    {
        // Given an iterator of leaf node data, we create a balanced binary tree in a bottom up fashion.
        // This is to help with cache locality and branch prediction.

        // We will fill up an array with the nodes. Leafs will be stored at the end of the array.

        // The total number of nodes in a binary tree with num_leaves is 2*num_leaves - 1.
        // given an index i, the left child is 2*i + 1 and the right child is 2*i + 2.
        // The parent of a node at index i is (i-1)/2.
        let num_leaves = iterator.len();
        
        if num_leaves == 0{
            self.storage.clear();
            return ShiftedIndex(0)..ShiftedIndex(0);
        }
        let tree_len = 2*num_leaves - 1;

        // clear the buffer but don't deallocate:
        self.storage.clear();

        assert!(
            self.storage.capacity() >= tree_len,
            "The storage vector is not large enough. It has a capacity of {} but we need {}",
            self.storage.capacity(),
            tree_len
        );

        let ptr = self.storage.as_mut_ptr() as *mut MaybeUninit<T>;

        unsafe{
            // pretend the vector has length tree_len so we can fill it up
            self.storage.set_len(tree_len);

            // write the leaves first:
            for (i, (weight, self_affinity)) in iterator.enumerate(){
                let shifted_index = ShiftedIndex(num_leaves - 1 + i);
                let node = T::new(weight, self_affinity, min_self_affinity);
                ptr.add(shifted_index.0).write(MaybeUninit::new(node));
            }
            // write the rest of the nodes bottom up:
            for i in (0..num_leaves-1).rev(){
                ptr.add(i).write(MaybeUninit::new(T::from_children(
                    ptr.add(2*i + 1).read().assume_init_ref(),
                    ptr.add(2*i + 2).read().assume_init_ref())));
            }

        }
        return ShiftedIndex(num_leaves - 1)..ShiftedIndex(tree_len);
    }

    pub fn sample(&self, rng: &mut impl rand::Rng) -> Result<Index,Error>{
        self._sample(rng, false, Contribution(0.0), Weight(0.0)).map(|(index,_)| index)
    }

    pub fn sample_smoothed(&self, rng: &mut impl rand::Rng, cost: Contribution, coreset_star_weight: Weight) -> Result<(Index,Float),Error>{
        self._sample(rng, true, cost, coreset_star_weight)
    }

    pub fn _sample(&self, rng: &mut impl rand::Rng, smoothed:bool, cost:Contribution, coreset_star_weight:Weight) -> Result<(Index,Float),Error>{
        let shifted_idx_res = T::_sample(&self.storage, rng, smoothed, cost, coreset_star_weight);
        shifted_idx_res.map(|(shifted_idx,prob)|{
            let idx = self.get_node_index(shifted_idx);
            (idx, prob)
        })
    }

    // pub fn compute_sampling_probability(&self, idx: Index) -> Float{
    //     self._computed_sampling_probability(false, idx, Contribution(0.0), Weight(0.0)).unwrap()
    // }

    // pub fn compute_smoothed_sampling_probability(&self, idx: Index, cost: Contribution, coreset_star_weight: Weight) -> Float{
    //     self._computed_sampling_probability(true, idx, cost, coreset_star_weight).unwrap()
    // }

    pub fn _computed_sampling_probability(&self, smoothed: bool, idx: Index, cost: Contribution, coreset_star_weight: Weight) -> Result<Float,Error>{
        let shifted_idx = self.get_shifted_node_index(idx);
        T::_computed_sampling_probability(&self.storage, smoothed, shifted_idx, cost, coreset_star_weight)
    }

    pub fn update_delta(&mut self, idx: Index, new_delta:Delta){
        let shifted_idx = self.get_shifted_node_index(idx);
        assert!(new_delta.0 >= 0.0, "Delta: {} is negative", new_delta.0);
        T::update_delta(&mut self.storage, shifted_idx, new_delta);
    }

}
