use std::collections::HashSet;

use crate::faster::{DynamicCorest, TreeData};
use crate::faster::common::*;
use crate::CoresetError;
use crate::faster::extraction_impls::CoresetInfo;
use faer::linalg::cholesky::ldlt::update;
use faer::sparse::{
    csr_symbolic::generic::SymbolicSparseRowMat, SparseColMat, SparseRowMat, SymbolicSparseColMat};


use faer::traits::num_traits::real::Real;
use rand::{rngs::StdRng, Rng};
use rayon::prelude::*;

use itertools::izip;
use rustc_hash::{FxHashMap, FxHashSet};

// MARK: CoresetInfo struct


impl <const ARITY: usize> DynamicCorest<ARITY>
where ConstPow2<ARITY>: PowerOfTwo
{
    //MARK: Serial contribution methods

    /// Brings the generation-tagged delta and smoothing-term delta of `node_idx` up to date
    /// and returns them as `(delta, smoothing_term_delta)`.
    ///
    /// If the node's stored generation differs from `current_gen`, both stored values are
    /// multiplied by zero (i.e. discarded) before being returned. In every case the node's
    /// stored generation is then set to `current_gen`.
    ///
    /// # Panics
    /// If `node_idx.0` is out of bounds for `generational_flags`, `deltas` or
    /// `smoothing_term_deltas`.
    #[inline(always)]
    fn update_and_get_delta_and_smoothing_term_from_tree_data(node_idx: ShiftedIndex, current_gen: usize, tree_data: &mut TreeData<ARITY>) -> (Delta, SmoothingTermDelta){
        let saved_gen = tree_data.generational_flags.get_mut(node_idx.0).unwrap();
        let delta_term = tree_data.deltas.get_mut(node_idx.0).unwrap();
        let smoothing_term = tree_data.smoothing_term_deltas.get_mut(node_idx.0).unwrap();

        // 1 keeps the stored values, 0 discards values written in an older generation.
        let keep = Float::from((*saved_gen == current_gen) as u8);
        *delta_term *= keep.into();
        *smoothing_term *= keep.into();

        *saved_gen = current_gen;
        (*delta_term, *smoothing_term)
    }

    /// Returns the delta stored for `node_idx`, multiplied by zero if it was written in a
    /// generation other than `current_gen`. Does not modify `tree_data`.
    #[inline(always)]
    fn get_delta_from_tree_data(node_idx: ShiftedIndex, current_gen: usize, tree_data: &TreeData<ARITY> ) -> Delta{
        let saved_gen = tree_data.generational_flags[node_idx];
        let delta_term = tree_data.deltas[node_idx];
        // A delta from an older generation counts as zero.
        delta_term * Float::from((saved_gen == current_gen) as u8)
    }

    /// Returns the smoothing-term delta stored for `node_idx`, multiplied by zero if it was
    /// written in a generation other than `current_gen`. Does not modify `tree_data`.
    #[inline(always)]
    fn get_smoothing_term_delta(node_idx: ShiftedIndex, current_gen: usize, tree_data: &TreeData<ARITY>) -> SmoothingTermDelta{
        let saved_gen = tree_data.generational_flags[node_idx];
        let smoothing_term = tree_data.smoothing_term_deltas[node_idx];
        // A smoothing term from an older generation counts as zero.
        smoothing_term * Float::from((saved_gen == current_gen) as u8)
    }

    /// Weighted kernel distance from a point `u` (degree `deg_u`) to a seed `v` (degree
    /// `deg_v`) joined by an edge of weight `w`:
    /// `1/deg_u + sigma + deg_u/deg_v^2 + sigma*deg_u/deg_v - 2*w/deg_v`.
    ///
    /// When `v` is being added as a seed, this is the updated contribution of `u` with
    /// respect to `v`. The value can be negative; `repair` clips it at zero.
    #[inline(always)]
    fn get_weighted_kernel_distance(deg_u: Float, deg_v: Float, w: EdgeWeight, coreset_info: &CoresetInfo) -> Contribution{
        let deg_u_inv = Float::from(1.0)/deg_u;
        let deg_v_inv = Float::from(1.0)/deg_v;
        let deg_v_inv_squared = deg_v_inv * deg_v_inv;
        (
            deg_u_inv + coreset_info.sigma + deg_u*deg_v_inv_squared
            + coreset_info.sigma * deg_u * deg_v_inv
            - Float::from(2.0)* w.into_float() * deg_v_inv
        ).into()
    }

    /// Contribution of node `node_idx` (a leaf or an aggregated internal node) in the
    /// generation `coreset_info.generation`:
    /// `sigma*size + inv_volume + volume*x_star_term - delta`. The result is not clipped.
    #[inline(always)]
    pub fn contribution(&self, node_idx: ShiftedIndex, coreset_info: &CoresetInfo) -> Contribution{
        Self::contribution_from_tree_data(
            node_idx, coreset_info, &self.tree_data
        )
    }

    /// Same as [`Self::contribution`] but operating on a `TreeData` directly, so it can be
    /// used while other fields of `self` are borrowed.
    #[inline(always)]
    fn contribution_from_tree_data(
        node_idx: ShiftedIndex, coreset_info: &CoresetInfo,
        tree_data: &TreeData<ARITY>) -> Contribution{
        // x_star_term must correspond to (1/deg(x^*)^2) + sigma/deg(x^*).
        // Delta that is valid in the current generation (zero otherwise):
        let delta_f: Float = Self::get_delta_from_tree_data(node_idx, coreset_info.generation, tree_data).into_float();
        let size_f = Float::from(tree_data.sizes[node_idx] as Float_Dtype);
        let inv_v_f = tree_data.inv_volumes[node_idx].into_float();
        let vol_f = tree_data.volumes[node_idx].into_float();
        Contribution::from(coreset_info.sigma* size_f + inv_v_f + vol_f * coreset_info.x_star_term - delta_f)
    }

    /// The "raw" contribution `sigma*size + inv_volume + volume*x_star_term` of `node_idx`,
    /// ignoring the stored delta entirely (whatever its generation).
    #[inline(always)]
    fn contribution_ignoring_delta(&self, node_idx: ShiftedIndex, coreset_info: &CoresetInfo) -> Contribution{
        // x_star_term must correspond to (1/deg(x^*)^2) + sigma/deg(x^*).
        let size_f = Float::from(self.tree_data.sizes[node_idx] as Float_Dtype);
        let inv_v_f = self.tree_data.inv_volumes[node_idx].into_float();
        let vol_f = self.tree_data.volumes[node_idx].into_float();
        Contribution::from(coreset_info.sigma* size_f + inv_v_f + vol_f * coreset_info.x_star_term)
    }



    //MARK: Series contribution methods

    /// Computes the contributions of up to `ARITY` sibling nodes given as parallel slices.
    ///
    /// For each of the first `min(sizes.len(), ARITY)` entries it stores
    /// `max(0, sigma*size + inv_volume + volume*x_star_term - delta)` (evaluated with fused
    /// multiply-adds) in `output`; the remaining slots of `output` are set to zero. A delta
    /// whose generation differs from `info.generation` counts as zero.
    #[inline(always)]
    fn contributions_from_arrays(
        output: &mut [Float; ARITY],
        sizes:  &[usize],
        volumes: &[Volume],
        inv_volumes: &[InvVolume],
        deltas: &[Delta],
        generations: &[usize],
        info: &CoresetInfo,
    ) {
        debug_assert_eq!(sizes.len(), volumes.len());
        debug_assert_eq!(sizes.len(), inv_volumes.len());
        debug_assert_eq!(sizes.len(), deltas.len());
        debug_assert_eq!(sizes.len(), generations.len());

        let filled = sizes.len().min(ARITY);
        for o in &mut output[filled..] {
            *o = Float::from(0.0);
        }

        for (o, s, v, iv, del, generation) in izip!(
            &mut output[..filled],
            sizes,
            volumes,
            inv_volumes,
            deltas,
            generations,
        ) {
            let size_f  = Float::from(*s as Float_Dtype);
            let inv_v_f = iv.into_float();
            let vol_f   = v.into_float();
            let delta_f = del.into_float()
                        * Float::from((*generation == info.generation) as u8);

            let total = info.sigma.mul_add(
                size_f,
                inv_v_f + vol_f.mul_add(info.x_star_term, -delta_f),
            );
            *o = Real::max(total,Float::from(0.0));
        }
    }

    /// Smoothed counterpart of [`Self::contributions_from_arrays`].
    ///
    /// For each of the first `min(sizes.len(), ARITY)` entries it stores
    /// `max(0, contribution / total_contribution + volume / seed_weight)`, where
    /// `contribution = sigma*size + inv_volume + volume*x_star_term - delta` (a delta from an
    /// older generation counts as zero). The remaining slots of `output` are set to zero.
    ///
    /// NOTE(review): `smoothing_term_deltas` and `total_weight` belong to an alternative
    /// smoothing (`volume / total_weight + smoothing_term_delta`) that is currently disabled.
    /// Apart from the length check below they do not affect the output.
    #[inline(always)]
    fn smoothed_contributions_from_arrays(
        output: &mut [Float; ARITY],
        sizes:  &[usize],
        volumes: &[Volume],
        inv_volumes: &[InvVolume],
        deltas: &[Delta],
        smoothing_term_deltas: &[SmoothingTermDelta],
        generations: &[usize],
        total_contribution: Contribution,
        seed_weight: Volume,
        total_weight: Volume,
        info: &CoresetInfo,
    ) {
        debug_assert_eq!(sizes.len(), volumes.len());
        debug_assert_eq!(sizes.len(), inv_volumes.len());
        debug_assert_eq!(sizes.len(), deltas.len());
        debug_assert_eq!(sizes.len(), smoothing_term_deltas.len());
        debug_assert_eq!(sizes.len(), generations.len());
        // Only used by the disabled alternative smoothing (see NOTE above).
        let _ = total_weight;

        let filled = sizes.len().min(ARITY);
        for o in &mut output[filled..] {
            *o = Float::from(0.0);
        }

        let inv_total_contribution = Float::from(1.0) / total_contribution.into_float();
        let inv_seed  = Float::from(1.0) / seed_weight.into_float();

        for (o, s, v, iv, del, generation) in izip!(
            &mut output[..filled],
            sizes,
            volumes,
            inv_volumes,
            deltas,
            generations,
        ) {
            let size_f  = Float::from(*s as Float_Dtype);
            let inv_v_f = iv.into_float();
            let vol_f   = v.into_float();
            let delta_f = del.into_float()
                        * Float::from((*generation == info.generation) as u8);

            // contribution * (1 / total_contribution) + volume / seed_weight
            let total = info.sigma.mul_add(
                size_f,
                inv_v_f + vol_f.mul_add(info.x_star_term, -delta_f))
                .mul_add(inv_total_contribution, vol_f* inv_seed);
            *o = Real::max(total, Float::from(0.0));
        }
    }

    //MARK: Repair

    /// Turns `point_added` into a new seed and repairs the per-node deltas and
    /// smoothing-term deltas accordingly.
    ///
    /// 1. The added point's delta is set to its delta-free ("raw") contribution, which makes
    ///    its own contribution zero. Its delta therefore grows by its previous contribution,
    ///    and that same amount is added to every ancestor's delta.
    /// 2. For each neighbour, the weighted kernel distance to the added point is computed and
    ///    clipped at zero. If the neighbour's current contribution is larger, its delta is
    ///    increased so that its contribution becomes that clipped distance, and the increase
    ///    is added to every ancestor's delta.
    /// 3. For the added point and every neighbour updated in step 2, the smoothing-term delta
    ///    is set to `deg / weight_of_seed - deg / total_weight` (`weight_of_seed` being the
    ///    summed degree of those points, `total_weight` the root volume), and the change is
    ///    added to every ancestor's smoothing-term delta.
    ///
    /// Returns the number of neighbours of `point_added` and how many of their (unclipped)
    /// kernel distances to it were `<= 0`.
    ///
    /// # Panics
    /// If `point_added` is missing from `node_location_map`, `degrees` or `adjacency`, or a
    /// neighbour is missing from `node_location_map` or `degrees`.
    pub fn repair(&mut self, point_added: NodeIdentity, coreset_info: &CoresetInfo) -> SamplingStats{
        let generation = coreset_info.generation;
        let point_added_index = self.node_location_map.get(&point_added).unwrap().clone();

        // Step 1: zero the added point's own contribution.
        // Contribution including any delta valid in this generation, and without any delta.
        let contribution = self.contribution(point_added_index, coreset_info);
        let raw_contribution = self.contribution_ignoring_delta(point_added_index, coreset_info).into_float();

        // Bring the point's generation-tagged values up to date before overwriting its delta.
        // Writing only the generation flag would revive a smoothing-term delta left over from
        // an older generation, which step 3 would then subtract as if it were current.
        let _ = Self::update_and_get_delta_and_smoothing_term_from_tree_data(
            point_added_index, generation, &mut self.tree_data);
        // delta == raw contribution  =>  the point's own contribution is now zero.
        *self.tree_data.deltas.get_mut(point_added_index.0).unwrap() = Delta::from(raw_contribution);
        debug_assert_eq!(self.contribution(point_added_index, coreset_info).into_float(), 0.0);

        // The point's delta went from `raw - contribution` to `raw`, i.e. it grew by exactly
        // `contribution`. Ancestors receive the same delta increment as the node itself (as in
        // the neighbour update below). Adding `raw - contribution` would add the *old* delta
        // instead, and leave ancestors untouched for a point without a prior delta.
        let delta_increase = Delta::from(contribution.into_float());
        Self::propagate_up_with_closure(&mut self.tree_data, point_added_index, |this, node_idx|{
            let up_to_date_delta = Self::update_and_get_delta_and_smoothing_term_from_tree_data(node_idx, generation, this).0;
            *this.deltas.get_mut(node_idx.0).unwrap() = up_to_date_delta + delta_increase;
        });

        // Step 2: neighbours whose contribution decreases because of the new seed.
        let deg_point_added = self.degrees.get(&point_added).unwrap().1.into_float();

        let neighbours = self.adjacency.get(&point_added).unwrap();
        let num_neighbours = neighbours.len();

        // Points whose closest seed is now the added point (with their degrees), and the sum
        // of those degrees.
        let mut updated_neighbours = vec![(point_added_index, deg_point_added)];
        let mut weight_of_seed = deg_point_added;

        let num_distance_clippings: usize = neighbours.iter().map(|(&neighbour,&weight)|{
            let neighbour_index = self.node_location_map.get(&neighbour).unwrap().clone();

            let current_contribution = Self::contribution_from_tree_data(
                neighbour_index, coreset_info,
                &self.tree_data
            );

            let deg_neighbour = self.degrees.get(&neighbour).unwrap().1.into_float();
            let contribution_to_added_point = Self::get_weighted_kernel_distance(
                deg_neighbour, deg_point_added, weight, coreset_info);

            let clipped_contribution_to_added_point = contribution_to_added_point.max(
                Contribution::from(Float::from(0.0))
            );

            if current_contribution > clipped_contribution_to_added_point{
                // The added point is closer to this neighbour than any other seed.
                let contribution_diff = current_contribution - clipped_contribution_to_added_point;
                updated_neighbours.push((neighbour_index, deg_neighbour));
                weight_of_seed += deg_neighbour;

                // Make the stored delta current, then raise it by the contribution difference.
                let _ = Self::update_and_get_delta_and_smoothing_term_from_tree_data(
                    neighbour_index, generation,
                    &mut self.tree_data);
                self.tree_data.deltas[neighbour_index] += convert(contribution_diff);

                // The neighbour's contribution should now equal the clipped distance. It is
                // recomputed as `E - (d + ((E - d) - c))`, which is not exact in floating point
                // (e.g. 1.0 - (1.0 - 0.1) != 0.1), so compare with a relative tolerance.
                debug_assert!(
                    {
                        let got = Self::contribution_from_tree_data(
                            neighbour_index, coreset_info, &self.tree_data
                        ).into_float();
                        let want = clipped_contribution_to_added_point.into_float();
                        let scale = Float::from(1.0)
                            + Real::abs(current_contribution.into_float())
                            + Real::abs(self.tree_data.deltas[neighbour_index].into_float());
                        Real::abs(got - want) <= Float::from(1e-4) * scale
                    },
                    "neighbour contribution does not match the clipped kernel distance"
                );

                // Ancestors receive the same delta increment.
                Self::propagate_up_with_closure(&mut self.tree_data, neighbour_index, |this, node_idx|{
                    let up_to_date_delta = Self::update_and_get_delta_and_smoothing_term_from_tree_data( node_idx, generation, this).0;
                    *this.deltas.get_mut(node_idx.0).unwrap() = up_to_date_delta + convert(contribution_diff);
                });
            }

            // Count how many kernel distances had to be clipped to zero.
            usize::from(contribution_to_added_point <= Contribution::from(Float::from(0.0)))
        }).sum();

        // Step 3: smoothing-term deltas of the points now assigned to the added point.
        let total_weight = self.tree_data.volumes[0].clone();
        for &(neighbour_index, weight) in &updated_neighbours{
            let (_, existing_smoothing_term) = Self::update_and_get_delta_and_smoothing_term_from_tree_data(
                neighbour_index, generation, &mut self.tree_data
            );

            let base_term = weight / total_weight.into_float(); // smoothing term when x^* is the only seed
            let new = weight / weight_of_seed; // up-to-date smoothing term

            // Only the difference from the base term is stored.
            let new_smoothing_delta = SmoothingTermDelta::new(new - base_term);
            debug_assert!(new_smoothing_delta.into_float() >= Float::from(0.0));

            // Amount by which every ancestor's smoothing-term delta must change.
            let diff_to_add = new_smoothing_delta - existing_smoothing_term;
            *self.tree_data.smoothing_term_deltas.get_mut(neighbour_index.0).unwrap() = new_smoothing_delta;

            Self::propagate_up_with_closure(&mut self.tree_data, neighbour_index, |this, node_idx|{
                let (_, up_to_date_smoothing) = Self::update_and_get_delta_and_smoothing_term_from_tree_data( node_idx, generation, this);
                *this.smoothing_term_deltas.get_mut(node_idx.0).unwrap() = up_to_date_smoothing + diff_to_add;
            });
        }

        SamplingStats{
            num_samples: num_neighbours,
            num_clippings: num_distance_clippings,
        }
    }

    /// Seeds the coreset with `x_star` by running [`Self::repair`] on it.
    ///
    /// NOTE(review): `x_star` is chosen by the caller; an earlier comment stated it is the
    /// node with the highest degree — confirm at the call site.
    pub fn sample_first_point(&mut self,x_star: NodeIdentity, coreset_info: &CoresetInfo) -> SamplingStats{
        self.repair(x_star, coreset_info)
    }

    /// Returns `1/deg(x*)^2 + sigma/deg(x*)`, the factor that multiplies a node's volume in
    /// the contribution formulas (`x_star_term` of [`CoresetInfo`]).
    pub fn compute_x_star_term(sigma: Float, x_star_deg: NodeDegree) -> Float{
        let inv = x_star_deg.inv().into_float();
        (inv * inv) + (sigma * inv)
    }

    /// Draws a child slot with probability proportional to `weights`.
    ///
    /// Slot `i` owns the half-open interval `[cdf[i-1], cdf[i])` of `[0, sum)`, so a slot with
    /// zero weight owns an empty interval and is never selected. (Comparing with `>=` would
    /// select it when the draw lands exactly on its left boundary, e.g. a draw of `0.0` with
    /// `weights[0] == 0`.)
    ///
    /// Returns the chosen slot and its selection probability `weights[slot] / sum(weights)`.
    ///
    /// # Errors
    /// `DynamicCoresetError::NoData` if no slot is found for the draw.
    ///
    /// # Panics
    /// If the weights do not have a positive, finite sum; `curr` is only used in the message.
    #[inline(always)]
    fn draw_child_proportionally(
        weights: &[Float; ARITY],
        curr: ShiftedIndex,
        rng: &mut StdRng,
    ) -> Result<(usize, f64), DynamicCoresetError> {
        let weight_sum: Float = weights.iter().sum();
        let raw_sum = weight_sum.0;
        // Written as `!(sum > 0 && finite)` so NaN and infinity are rejected here instead of
        // reaching `random_range` and panicking there with a less helpful message.
        if !(raw_sum > 0.0 && raw_sum.is_finite()) {
            let approx_depth = (curr.0 as f64).log(ARITY as f64) as usize;
            panic!(
                "numerical instability: child contribution sum is zero, negative or non-finite \
                 (node index: {curr:?}, approx. depth: {approx_depth}): {weights:?}"
            );
        }
        let draw = Float::from(rng.random_range(0.0 as Float_Dtype..raw_sum));

        // Cumulative distribution over the slots (padding slots hold zero weight).
        let mut cdf = *weights;
        for i in 1..ARITY {
            cdf[i] += cdf[i - 1];
        }
        let slot = cdf.iter().position(|&x| x > draw)
            .ok_or(DynamicCoresetError::NoData)?;

        let p = (weights[slot] / weight_sum).into_inner() as f64;
        Ok((slot, p))
    }

    // MARK: Sampling

    /// Samples a node by descending from the root, choosing each child with probability
    /// proportional to its clipped contribution ([`Self::contributions_from_arrays`]).
    ///
    /// Returns the sampled node and the probability of the path taken (product of the
    /// per-level selection probabilities). `init_weight` and `total_cost` are currently unused.
    ///
    /// # Errors
    /// `DynamicCoresetError::NoData` if the tree is empty.
    ///
    /// # Panics
    /// If the children of a visited internal node have a non-positive or non-finite total
    /// contribution.
    #[inline(always)]
    pub fn sample(
        &mut self,
        coreset_info: &CoresetInfo,
        init_weight: NodeDegree,
        total_cost: Contribution,
        rng: &mut StdRng) -> Result<(NodeIdentity,Float),DynamicCoresetError>{

        if self.num_nodes_in_tree() == 0{
            return Err(DynamicCoresetError::NoData)
        }

        let mut curr = ShiftedIndex(0);
        let mut prob = 1.0f64;
        let mut buffer = [Float::from(0.0); ARITY];

        // Nodes with size > 1 are internal; descend until a leaf is reached.
        while self.tree_data.sizes[curr] > 1{
            let first_child_idx = self.child_index(curr, 0).unwrap();
            // The last child may not exist; then the children run up to the last tree node.
            let last_child_idx = self.child_index(curr, ARITY-1)
                .unwrap_or(ShiftedIndex(self.num_nodes_in_tree()-1));
            let children = first_child_idx.0..=last_child_idx.0;

            Self::contributions_from_arrays(
                &mut buffer,
                &self.tree_data.sizes[children.clone()],
                &self.tree_data.volumes[children.clone()],
                &self.tree_data.inv_volumes[children.clone()],
                &self.tree_data.deltas[children.clone()],
                &self.tree_data.generational_flags[children],
                coreset_info,
            );

            let (child_index, child_prob) = Self::draw_child_proportionally(&buffer, curr, rng)?;
            prob *= child_prob;
            curr = self.child_index(curr, child_index).unwrap();
        }

        let node_id = self.node_location_map_reverse.get(&curr).unwrap().clone();

        Ok((node_id,Float::from(prob as Float_Dtype)))
    }

    /// Like [`Self::sample`], but children are weighted with
    /// [`Self::smoothed_contributions_from_arrays`] using `total_cost` as the total
    /// contribution and `init_weight` as the seed weight.
    ///
    /// Returns the sampled node and the probability of the path taken.
    ///
    /// # Errors
    /// `DynamicCoresetError::NoData` if the tree is empty.
    ///
    /// # Panics
    /// If the children of a visited internal node have a non-positive or non-finite total
    /// weight.
    #[inline(always)]
    pub fn sample_smoothed(
        &mut self,
        coreset_info: &CoresetInfo,
        init_weight: NodeDegree,
        total_weight: NodeDegree,
        total_cost: Contribution,
        rng: &mut StdRng) -> Result<(NodeIdentity,Float),DynamicCoresetError>{

        if self.num_nodes_in_tree() == 0{
            return Err(DynamicCoresetError::NoData)
        }

        let mut curr = ShiftedIndex(0);
        let mut prob = 1.0f64;
        let mut buffer = [Float::from(0.0); ARITY];

        // Nodes with size > 1 are internal; descend until a leaf is reached.
        while self.tree_data.sizes[curr] > 1{
            let first_child_idx = self.child_index(curr, 0).unwrap();
            // The last child may not exist; then the children run up to the last tree node.
            let last_child_idx = self.child_index(curr, ARITY-1)
                .unwrap_or(ShiftedIndex(self.num_nodes_in_tree()-1));
            let children = first_child_idx.0..=last_child_idx.0;

            Self::smoothed_contributions_from_arrays(
                &mut buffer,
                &self.tree_data.sizes[children.clone()],
                &self.tree_data.volumes[children.clone()],
                &self.tree_data.inv_volumes[children.clone()],
                &self.tree_data.deltas[children.clone()],
                &self.tree_data.smoothing_term_deltas[children.clone()],
                &self.tree_data.generational_flags[children],
                total_cost, convert(init_weight), convert(total_weight),
                coreset_info,
            );

            let (child_index, child_prob) = Self::draw_child_proportionally(&buffer, curr, rng)?;
            prob *= child_prob;
            curr = self.child_index(curr, child_index).unwrap();
        }

        let node_id = self.node_location_map_reverse.get(&curr).unwrap().clone();

        Ok((node_id,Float::from(prob as Float_Dtype)))
    }


}