use std::fmt::{Formatter, Debug};
use crate::d_ary_static_coreset::common::*;
use std::mem::MaybeUninit;
// use std::ops::Range;

// use super::unstable::TreeNode;


/// An `ARITY`-ary sampling tree laid out implicitly inside a borrowed `Vec`.
///
/// The root lives at `storage[0]`. Internal nodes fill the prefix
/// `storage[..leaf_start]` and leaves fill the suffix `storage[leaf_start..]`.
// NOTE(review): the reason for allowing dead_code is not visible here — document or remove.
#[allow(dead_code)]
pub struct SamplingTree<'a, T, const ARITY: usize> {
    /// Backing buffer with every node: internal nodes first, leaves last.
    pub storage: &'a mut Vec<T>,
    /// Position in `storage` of the first leaf (equivalently, the number of internal nodes).
    pub leaf_start: usize,
}


impl<T: Debug, const ARITY: usize> Debug for SamplingTree<'_, T, ARITY> {
    /// Renders the tree as `SamplingTree { root: .. }`.
    ///
    /// Only the root node (the first element of `storage`) is shown; an empty
    /// storage is rendered as `root: None`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("SamplingTree");
        if let Some(root_node) = self.storage.first() {
            builder.field("root", root_node);
        } else {
            builder.field("root", &Option::<()>::None);
        }
        builder.finish()
    }
}

impl <'a, T, const ARITY:usize> SamplingTree<'a,T, ARITY>
where T: Node<ARITY> + Clone
{
    /// Wraps `vec` as the backing storage of a tree.
    ///
    /// `vec` is not modified and `leaf_start` starts at 0. The tree has a
    /// meaningful layout only after it is filled, e.g. by `insert_from_iterator`.
    pub fn new(vec: &'a mut Vec<T>) -> Self {
        SamplingTree {
            storage: vec,
            leaf_start: 0,
        }
    }

    /// Converts a 0-based leaf index in `[0, real_leaves)` into its absolute
    /// position in `storage`, i.e. `leaf_start + idx`.
    #[inline(always)]
    pub fn get_shifted_node_index(&self, node_index: Index) -> ShiftedIndex {
        // Leaves sit directly after the `leaf_start` internal nodes.
        ShiftedIndex(node_index.0 + self.leaf_start)
    }

    /// Inverse of [`Self::get_shifted_node_index`]: strips the `leaf_start` offset.
    ///
    /// `shifted_node_index` must refer to a leaf (`>= leaf_start`). Otherwise the
    /// subtraction underflows: it panics when overflow checks are enabled
    /// (e.g. debug builds) and wraps otherwise.
    #[inline(always)]
    pub fn get_node_index(&self, shifted_node_index: ShiftedIndex) -> Index {
        Index(shifted_node_index.0 - self.leaf_start)
    }

    /// Recomputes every internal node from its children, via `T::from_children`.
    ///
    /// Internal nodes occupy `storage[..leaf_start]`. Visiting their indices in
    /// reverse refreshes children before parents, as long as every child index
    /// is greater than its parent's. That presumably holds for the heap layout
    /// produced by `insert_from_iterator` — confirm against `T::children_indices`.
    pub fn rebuild_from_leaves(&mut self) {
        let internals = self.leaf_start; // number of internal nodes
        for i in (0..internals).rev() {
            T::from_children(&mut self.storage, i.into());
        }
    }

    /// Number of internal nodes of the minimal `ARITY`-ary tree over `leaves` leaves.
    ///
    /// Let I = #internal nodes, L = #leaves, d = ARITY. Every node except the
    /// root has exactly one parent edge, so #edges = I + L - 1. If every internal
    /// node had exactly d children, #edges = d*I, giving
    ///   d*I = I + L - 1   =>   I = (L - 1) / (d - 1).
    /// The last internal node may have fewer than d children, so round up:
    ///   I = ceil((L - 1) / (d - 1)).
    /// The result is computed in integer arithmetic. A float formula (e.g. in
    /// `f32`, with only 24 mantissa bits) miscounts once L exceeds 2^24.
    ///
    /// # Panics
    /// Panics if `leaves > 1` and `ARITY < 2`, because no such tree exists.
    fn internal_node_count(leaves: usize) -> usize {
        if leaves <= 1 {
            // A single leaf is its own root; no internal nodes are needed.
            return 0;
        }
        assert!(ARITY >= 2, "SamplingTree requires ARITY >= 2, got {}", ARITY);
        // ceil((leaves - 1) / (ARITY - 1)) for leaves >= 2, written without
        // any intermediate sum that could overflow.
        (leaves - 2) / (ARITY - 1) + 1
    }

    /// Replaces the tree contents with one leaf per item of `iterator`, then
    /// builds all internal nodes bottom-up.
    ///
    /// The result is stored in the minimal layout: internal nodes first
    /// (`storage[..internals]`), followed by the leaves in iteration order.
    /// Each leaf is created by `T::new_leaf(weight, self_affinity, min_self_affinity)`.
    /// Each internal node is created by `T::new_internal` from the summed child
    /// weights and child contributions, with children as given by `T::children_indices`.
    ///
    /// Returns the range of absolute `storage` positions holding the leaves.
    /// An empty iterator clears the tree and returns an empty range.
    ///
    /// # Panics
    /// - If `storage` cannot hold the whole tree without reallocating.
    /// - If there is more than one leaf and `ARITY < 2`.
    /// - If `iterator` yields fewer items than its `ExactSizeIterator::len`
    ///   reported. Extra items beyond that count are ignored.
    pub fn insert_from_iterator<I>(
        &mut self,
        iterator: I,
        min_self_affinity: SelfAffinity,
    ) -> std::ops::Range<ShiftedIndex>
    where
        I: Iterator<Item = (Weight, SelfAffinity)> + ExactSizeIterator,
    {
        let real_leaves = iterator.len();
        if real_leaves == 0 {
            self.storage.clear();
            self.leaf_start = 0;
            return ShiftedIndex(0)..ShiftedIndex(0);
        }

        let internals = Self::internal_node_count(real_leaves);
        let tree_len = real_leaves + internals;

        self.storage.clear();
        // The buffer is expected to be pre-allocated: nothing below reallocates.
        assert!(
            self.storage.capacity() >= tree_len,
            "need {} slots, have {}",
            tree_len,
            self.storage.capacity()
        );

        // Reserve the internal-node prefix with placeholders. Every one of them
        // is overwritten in the bottom-up pass below. This is safe code: no
        // uninitialized element is ever observable, even if a constructor
        // panics or the iterator misreports its length.
        if internals > 0 {
            let placeholder = T::new_internal(Weight(0.0), Contribution(0.0));
            self.storage.resize(internals, placeholder);
        }

        // Leaves follow the internal nodes. `take` guards against an iterator
        // that yields more items than `len()` promised.
        self.storage.extend(
            iterator
                .take(real_leaves)
                .map(|(w, sa)| T::new_leaf(w, sa, min_self_affinity)),
        );
        assert_eq!(
            self.storage.len(),
            tree_len,
            "iterator reported {} items via ExactSizeIterator::len but yielded {}",
            real_leaves,
            self.storage.len() - internals
        );

        // Build internal nodes deepest-first, so every child is final before
        // its parent aggregates it.
        for i in (0..internals).rev() {
            let children_indices = T::children_indices(&self.storage, i.into());

            let total_child_weight = children_indices
                .iter()
                .map(|idx| self.storage[idx.0].weight())
                .sum::<Weight>();
            let total_contribution = children_indices
                .iter()
                .map(|idx| T::contribution(&self.storage, *idx))
                .sum::<Contribution>();

            self.storage[i] = T::new_internal(total_child_weight, total_contribution);
        }

        self.leaf_start = internals;
        // Range of the real leaves in absolute storage positions.
        ShiftedIndex(internals)..ShiftedIndex(tree_len)
    }

    /// Draws a leaf with `T::_sample` in non-smoothed mode and returns its
    /// 0-based leaf index.
    pub fn sample(&self, rng: &mut impl rand::Rng) -> Result<Index,Error>{
        self._sample(rng, false, Contribution(0.0), Weight(0.0)).map(|(index,_)| index)
    }

    /// Draws a leaf with `T::_sample` in smoothed mode.
    ///
    /// Returns the 0-based leaf index together with the probability reported by
    /// `T::_sample`. How `cost` and `coreset_star_weight` are used is defined by
    /// `T::_sample`.
    pub fn sample_smoothed(&self, rng: &mut impl rand::Rng, cost: Contribution, coreset_star_weight: Weight) -> Result<(Index,Float),Error>{
        self._sample(rng, true, cost, coreset_star_weight)
    }

    /// Shared sampling entry point.
    ///
    /// Delegates to `T::_sample` over the whole storage, then converts the
    /// returned absolute position back into a 0-based leaf index.
    pub fn _sample(&self, rng: &mut impl rand::Rng, smoothed:bool, cost:Contribution, coreset_star_weight:Weight) -> Result<(Index,Float),Error>{
        let shifted_idx_res = T::_sample(&self.storage, rng, smoothed, cost, coreset_star_weight);
        shifted_idx_res.map(|(shifted_idx,prob)|{
            let idx = self.get_node_index(shifted_idx);
            (idx, prob)
        })
    }

    // pub fn compute_sampling_probability(&self, idx: Index) -> Float{
    //     self._computed_sampling_probability(false, idx, Contribution(0.0), Weight(0.0)).unwrap()
    // }

    // pub fn compute_smoothed_sampling_probability(&self, idx: Index, cost: Contribution, coreset_star_weight: Weight) -> Float{
    //     self._computed_sampling_probability(true, idx, cost, coreset_star_weight).unwrap()
    // }

    /// Returns the probability that `T` assigns to sampling leaf `idx`
    /// (0-based leaf index), in smoothed or plain mode.
    pub fn _computed_sampling_probability(&self, smoothed: bool, idx: Index, cost: Contribution, coreset_star_weight: Weight) -> Result<Float,Error>{
        let shifted_idx = self.get_shifted_node_index(idx);
        T::_computed_sampling_probability(&self.storage, smoothed, shifted_idx, cost, coreset_star_weight)
    }

    /// Sets the delta of leaf `idx` (0-based leaf index) via `T::update_delta`.
    ///
    /// # Panics
    /// Panics if `new_delta` is negative or NaN (`NaN >= 0.0` is false).
    pub fn update_delta(&mut self, idx: Index, new_delta:Delta){
        let shifted_idx = self.get_shifted_node_index(idx);
        assert!(new_delta.0 >= 0.0, "Delta: {} is negative", new_delta.0);
        T::update_delta(&mut self.storage, shifted_idx, new_delta);
    }

}
