use super::common::*;
use std::fmt::{Formatter, Debug};





/// One node of an `ARITY`-ary tree that is stored implicitly in a flat slice
/// and addressed by `ShiftedIndex`.
///
/// `contribution_or_delta` is interpreted by position: a leaf stores its
/// delta (its contribution is `weight * delta`), while an internal node stores
/// the summed contribution of its children.
// NOTE(review): the reason for `allow(dead_code)` is not visible from this file.
#[allow(dead_code)]
#[derive(Clone, Debug)]
pub struct TreeNode<const ARITY: usize> {
    /// Weight of the node; for internal nodes, the sum of the children's weights.
    pub weight: Weight,
    /// Delta for leaves, subtree contribution for internal nodes.
    pub contribution_or_delta: ContributionOrDelta,
}

#[allow(dead_code)]
impl<const ARITY: usize> TreeNode<ARITY> {
    /// Whether `shifted_idx` names an existing node that has no children.
    ///
    /// A node exists when its index is below `storage_len`; it is a leaf when
    /// even its first child would land at or past `storage_len`.
    #[inline(always)]
    fn is_leaf(storage_len: usize, shifted_idx: ShiftedIndex) -> bool {
        let first_child = Self::child_idx(shifted_idx, 0).0;
        let node_exists = shifted_idx.0 < storage_len;
        let has_children = first_child < storage_len;
        node_exists && !has_children
    }

    /// Whether `shifted_idx` names an existing node with at least one child,
    /// i.e. its first child lies below `storage_len`.
    #[inline(always)]
    fn is_internal(storage_len: usize, shifted_idx: ShiftedIndex) -> bool {
        let first_child = Self::child_idx(shifted_idx, 0).0;
        let node_exists = shifted_idx.0 < storage_len;
        let has_children = first_child < storage_len;
        node_exists && has_children
    }

    /// Contribution of this node read as a leaf: `contribution_or_delta` is
    /// treated as the delta, so the result is `weight * delta`.
    #[inline(always)]
    pub fn contribution_as_leaf(&self) -> Contribution {
        let delta = self.contribution_or_delta.0;
        let weighted = self.weight.0 * delta;
        weighted.into()
    }

    /// Contribution of this node read as an internal node: the stored
    /// `contribution_or_delta`, converted via `Into`.
    #[inline(always)]
    pub fn contribution_as_internal(&self) -> Contribution {
        let stored = self.contribution_or_delta;
        stored.into()
    }

    /// Smoothed contribution of the node at `shifted_idx`:
    /// `contribution / cost + weight / coreset_star_weight`.
    ///
    /// There is no guard against a zero `cost` or `coreset_star_weight`; such
    /// inputs produce non-finite values.
    ///
    /// # Panics
    /// Panics if `shifted_idx` does not name a node of `storage`.
    #[inline(always)]
    pub fn smoothed_contribution(
        storage: &[Self],
        shifted_idx: ShiftedIndex,
        cost: Contribution,
        coreset_star_weight: Weight,
    ) -> SmoothedContribution {
        // Contribution is computed first, so an invalid index panics there
        // before the direct slice access below.
        let cost_term = Self::contribution(storage, shifted_idx).0 / cost.0;
        let weight_term = storage[shifted_idx.0].weight.0 / coreset_star_weight.0;
        SmoothedContribution(cost_term + weight_term)
    }

    /// Returns the stored weight of this node.
    #[inline(always)]
    pub fn weight(&self) -> Weight {
        self.weight
    }

    /// Delta stored in a leaf.
    ///
    /// The caller must ensure this node is a leaf; on an internal node the
    /// stored value is a contribution, not a delta.
    #[inline(always)]
    pub fn delta(&self) -> Delta {
        let stored = self.contribution_or_delta;
        stored.into()
    }
}


impl <const ARITY: usize> Node<ARITY> for TreeNode<ARITY>{
    #[inline(always)]
    fn contribution(storage: &[Self], shifted_idx: ShiftedIndex) -> Contribution{
        if Self::is_leaf(storage.len(), shifted_idx) {
            storage[shifted_idx.0].contribution_as_leaf()
        } else if Self::is_internal(storage.len(), shifted_idx) {
            storage[shifted_idx.0].contribution_as_internal()
        } else {
            panic!("Invalid shifted index: {shifted_idx:?}");
        }
    }

    #[inline(always)]
    fn smoothed_contribution(storage: &[Self], shifted_idx:ShiftedIndex, cost: Contribution, coreset_star_weight: Weight) -> SmoothedContribution {
        Self::smoothed_contribution(storage, shifted_idx, cost, coreset_star_weight)
    }

    #[inline(always)]
    fn weight(&self) -> Weight {
        self.weight()
    }
    
    fn new_leaf(weight: Weight, self_affinity: SelfAffinity, min_self_affinity: SelfAffinity) -> Self {
        TreeNode{
            weight,
            contribution_or_delta: ContributionOrDelta((self_affinity.0 + min_self_affinity.0).into())
        }
    }

    fn new_internal(total_child_weight: Weight, total_child_contribution: Contribution) -> Self {
        TreeNode{
            weight: total_child_weight,
            contribution_or_delta: ContributionOrDelta(total_child_contribution.0.into())
        }
    }

    fn update_delta(storage: &mut Vec<Self>, shifted_index: ShiftedIndex, new_delta: Delta) {

        let mut shifted_node_index = shifted_index;

        

        if Self::is_leaf(storage.len(), shifted_node_index) == false {
            panic!("update_delta should only be called on leaf nodes");
        }else{
                let leaf = storage.get_mut(shifted_node_index.0).unwrap();
                if leaf.contribution_or_delta.0 <= new_delta.0{
                    return;
                }
                let delta_diff = leaf.contribution_or_delta.0 - new_delta.0;
                let contribution_diff = delta_diff*leaf.weight.0;
                leaf.contribution_or_delta = new_delta.0.into();
                let mut parent = Self::parent_idx(shifted_node_index);
                while let Ok(parent_idx) = parent{
                    let parent_node = storage.get_mut(parent_idx.0).unwrap();
                    parent_node.contribution_or_delta.0 -= contribution_diff;
                    shifted_node_index = parent_idx;
                    parent = Self::parent_idx(shifted_node_index);
                }
        }
    }

    fn from_children(storage: &mut[Self], shifted_index: ShiftedIndex) {
        let children_indices = Self::children_indices(storage, shifted_index);
        let total_weight = children_indices.iter()
            .map(|&idx| storage[idx.0].weight.0)
            .sum::<Float>();
        let total_contribution = children_indices.iter()
            .map(|&idx| Self::contribution(storage, idx))
            .sum::<Contribution>();
        storage[shifted_index.0].weight = Weight(total_weight);
        storage[shifted_index.0].contribution_or_delta = ContributionOrDelta(total_contribution.0);
    }

    #[inline(never)]
    fn _sample(
        storage: &[Self],
        rng: &mut impl rand::Rng,
        smoothed: bool,
        cost: Contribution,
        coreset_star_weight: Weight,
    ) -> Result<(ShiftedIndex, Float), Error> {
        if storage.is_empty() {
            return Err(Error::EmptyTree);
        }
    
        let mut cur = ShiftedIndex(0);
        let mut prob = Float::from(1.0);

    
        if smoothed {
            // —————— Smoothed branch ——————
            while Self::is_internal(storage.len(), cur) {
                // gather smoothed contributions from all ARITY children
                let mut total = Float::from(0.0);
                let mut contribs = [SmoothedContribution(0.0); ARITY];
                for k in 0..ARITY {
                    let ci = Self::child_idx(cur, k);
                    let ci = match ci.0 < storage.len(){
                        true => Some(ci),
                        false => None,
                    };
                    let c = ci.map(|x|Self::smoothed_contribution(&storage, x, cost, coreset_star_weight));
                    contribs[k] = c.unwrap_or(SmoothedContribution(0.0));
                    total += c.unwrap_or(SmoothedContribution(0.0)).0;
                }
    
                if total <= 0.0 {
                    return Err(Error::NumericalError);
                }
    
                // draw one random value in [0, total)
                let x = rng.random_range(0.0..total);
    
                // choose which branch
                let mut accum = Float::from(0.0);
                let mut chosen = 0;
                for k in 0..ARITY {
                    accum += contribs[k].0;
                    if x <= accum {
                        prob *= contribs[k].0 / total;
                        chosen = k;
                        break;
                    }
                }
                // descend
                cur = Self::child_idx(cur, chosen);
            }
        } else {
            // —————— Raw branch ——————
            while Self::is_internal(storage.len(), cur) {
                // gather raw contributions from all ARITY children
                let mut total = Float::from(0.0);
                let mut contribs = [Contribution(0.0); ARITY];
                for k in 0..ARITY {
                    let ci = Self::child_idx(cur, k);
                    let ci = match ci.0 < storage.len(){
                        true => Some(ci),
                        false => None,
                    };
                    let c = ci.map(|x|Self::contribution(&storage, x));
                    contribs[k] = c.unwrap_or(Contribution(0.0));
                    total += c.unwrap_or(Contribution(0.0)).0;
                }
    
                if total <= 0.0 {
                    return Err(Error::NumericalError);
                }
    
                // draw one random value in [0, total)
                let x = rng.random_range(0.0..total);
    
                // choose which branch
                let mut accum = Float::from(0.0);
                let mut chosen = 0;
                for k in 0..ARITY {
                    accum += contribs[k].0;
                    if x <= accum {
                        prob *= contribs[k].0 / total;
                        chosen = k;
                        break;
                    }
                }
                // descend
                cur = Self::child_idx(cur, chosen);
            }
        }
        Ok((cur, prob))
    }

    fn _computed_sampling_probability(storage: &[Self], smoothed: bool, shifted_idx: ShiftedIndex, cost: Contribution, coreset_star_weight: Weight) ->Result<Float,Error> {

        let mut shifted_node_index = shifted_idx;
        let mut prob: Float = 1.0;

        match smoothed{
            true =>{
                while let Ok(parent_idx) = TreeNode::<ARITY>::parent_idx(shifted_node_index){
                    let parent_contribution = Self::smoothed_contribution(&storage, parent_idx, cost, coreset_star_weight);
                    let child_contribution = Self::smoothed_contribution(&storage, shifted_node_index, cost, coreset_star_weight);
                    prob *= child_contribution.0/parent_contribution.0;
                    shifted_node_index = parent_idx;
                }
            },
            false =>{
                while let Ok(parent_idx) = TreeNode::<ARITY>::parent_idx(shifted_node_index){
                    let parent_contribution = Self::contribution(&storage, parent_idx);
                    let child_contribution = Self::contribution(&storage, shifted_node_index);
                    prob *= child_contribution.0/parent_contribution.0;
                    shifted_node_index = parent_idx;
                }
            }
        }
        Ok(prob)
    }
    
}

