mod common;
pub(crate) mod unstable;
mod sampling_tree;

use std::collections::BTreeMap;
use rustc_hash::FxHashMap;

pub (crate) use sampling_tree::SamplingTree;
use common::*;
use rand::rngs::SmallRng;
use rand::Rng;

pub use common::{Float, SelfAffinity};
use crate::NodeName;


/// State for drawing a weighted coreset from a graph data set.
///
/// The sampler borrows its inputs (adjacency matrix, degrees, node names, self
/// affinities) and its output buffers from the caller for lifetime `'a`; the
/// sampling methods append their results to those buffers instead of returning them.
#[derive(Debug)]
pub struct DefaultCoresetSampler<'a,T, const ARITY: usize>
    where T: Node<ARITY> + Clone{
    /// Tree built in `new` from the per-point `(weight, self_affinity)` pairs; it is
    /// sampled from and updated (via `update_delta`) by the methods in the impl.
    sampling_tree: SamplingTree<'a,T, ARITY>,
    /// Number of centres drawn by `sample_next_k`: one uniform draw followed by
    /// `num_clusters - 1` draws from the sampling tree.
    num_clusters: usize,
    /// Running sum of the degrees of every point added to the centre set by `repair`.
    coreset_star_weight: Weight,
    /// Number of draws made by `sample_rest`.
    coreset_size: usize,
    /// Random source used for every draw.
    rng: &'a mut SmallRng,
    /// Number of data points (`indices.len()` in `new`); upper bound of the uniform draw.
    number_of_data_points: usize,
    /// Weighted adjacency lists keyed by node name.
    adj_matrix: &'a FxHashMap<NodeName, BTreeMap<NodeName, Float>>,
    /// Degree of each data point, indexed by data-point index.
    degree_vector: &'a[Float],
    /// Node name of each data point, indexed by data-point index.
    names: &'a [NodeName],
    /// Maps a node name to its data-point index; `repair` ignores neighbours that
    /// are missing from this map.
    names_to_indices: &'a FxHashMap<NodeName, usize>,
    /// Self affinity of each data point (per the comment in `new`, A[i,i]/d[i]^2),
    /// indexed by data-point index.
    self_affinities: &'a [SelfAffinity],
    /// Data-point index of each coreset draw, appended by `sample_rest`.
    output_indices: &'a mut Vec<usize>,
    /// Weight of each coreset draw, parallel to `output_indices`.
    output_weights: &'a mut Vec<Float>,
    /// Node names of the distinct drawn points, appended by `sample_and_deduplicate`.
    /// Despite the field name these are names, not indices.
    unique_indices: &'a mut Vec<NodeName>,
    /// Summed weight of each entry of `unique_indices`.
    unique_weights: &'a mut Vec<Float>,
    /// Maps a drawn data-point index to its position in `unique_indices` / `unique_weights`.
    seen: &'a mut FxHashMap<usize, usize>,
    /// Index of the point with the smallest self affinity; repaired in as the first centre.
    x_star_index: Index,
    /// Set once `repair` has clamped a negative squared distance to zero.
    numerical_warning: bool,
    /// Number of negative squared distances clamped to zero by `repair`.
    num_warnings: usize,
    /// Number of `update_delta` calls issued by `repair`.
    total_distance_updates: usize,
}

/// Assumes a undirected graph where self loops are all 1
impl <'a,T, const ARITY: usize> DefaultCoresetSampler<'a,T, ARITY>
    where T: Node<ARITY> + Clone
{

    pub fn new(
        adj_matrix: &'a FxHashMap<NodeName, BTreeMap<NodeName, Float>>,
        degree_vector: &'a[Float],
        self_affinities: &'a[SelfAffinity],
        output_indices: &'a mut Vec<usize>,
        output_weights: &'a mut Vec<Float>,
        names_to_indices: &'a FxHashMap<NodeName, usize>,
        unique_indices: &'a mut Vec<NodeName>,
        unique_weights: &'a mut Vec<Float>,
        seen: &'a mut FxHashMap<usize, usize>,
        indices: &'a [NodeName],
        weights: &'a[Float],
        num_clusters: usize,
        coreset_size: usize,
        rng: &'a mut SmallRng,
        sampling_tree_storage: &'a mut Vec<T>) -> Self{

        let n = indices.len();
        debug_assert_eq!(n, degree_vector.len());

        let mut sampling_tree = SamplingTree::<T, ARITY>::new(sampling_tree_storage);

        // Find the node with the lowest self affinity. Aka lowest value of A[i,i]/d[i]^2
        // We assume the value of A[i,i] is 1.0


        let x_star = self_affinities.iter().enumerate().min_by(|a,b| a.1.0.partial_cmp(&b.1.0).unwrap()).unwrap().0;
        let min_self_affinity = self_affinities[x_star];

        // Populate the sampling  tree with weights and self affinities
        sampling_tree.insert_from_iterator(weights.iter().zip(self_affinities.iter()).map(|(w,self_affinity)|{
            (Weight(*w),*self_affinity)
        }),
             min_self_affinity
        );

        DefaultCoresetSampler{
            sampling_tree,
            num_clusters,
            coreset_star_weight: Weight(0.0),
            coreset_size,
            rng,
            number_of_data_points: n,
            adj_matrix,
            degree_vector: degree_vector,
            names: indices,
            names_to_indices: names_to_indices,
            self_affinities,
            output_indices: output_indices,
            output_weights: output_weights,
            unique_indices: unique_indices,
            unique_weights: unique_weights,
            seen: seen,
            x_star_index: Index(x_star),
            numerical_warning: false,
            num_warnings: 0,
            total_distance_updates: 0,
        }
    }



    
    fn index_to_node_name(&self, index: Index) -> &NodeName{
        &self.names[index.0]
    }


    fn repair(&mut self, point_added: Index){
        // We implicitly add the point to the init set and update it's neighbours:
        let point_added_degree: Float = self.degree_vector[point_added.0];
        let point_added_weight: Weight = point_added_degree.into();
        let point_added_self_affinity: SelfAffinity = self.self_affinities[point_added.0];

        self.coreset_star_weight += point_added_weight;

        // set the contribution of the added point to zero:
        self.sampling_tree.update_delta(point_added, Delta(0.0));
        self.total_distance_updates += 1;
        // Now we update the neighbours of the added point:
        self.adj_matrix.get(self.index_to_node_name(point_added)).unwrap()
            .iter().filter_map(|(neighbour_name,edge_weight)|{
                if let Some(neighbour_index) = self.names_to_indices.get(neighbour_name){
                    Some((Index(*neighbour_index),edge_weight))
                }else{
                    None
                }
            })
            .for_each(|(neighbour_index,edge_weight)|{
            // If the neighbour is the added point, skip it
            if neighbour_index == self.x_star_index{
                return;
            }
            // compute the distance^2 between the added point and the neighbour:
            let neighbour_degree: Float = self.degree_vector[neighbour_index.0];
            let neighbour_self_affinity: SelfAffinity = self.self_affinities[neighbour_index.0];
            let cross_term: CoresetCrossTerm = (edge_weight/(point_added_degree*neighbour_degree)).into();
            let mut distance2 =  point_added_self_affinity.0 + neighbour_self_affinity.0 - 2.0*cross_term.0;
            // update the delta of the neighbour:
            if distance2 < 0.0{
                self.numerical_warning = true;
                distance2 = 0.0;
                self.num_warnings += 1;
            }
            self.sampling_tree.update_delta(neighbour_index, Delta(distance2));
            self.total_distance_updates += 1;
        })
    }

    
    pub fn sample_first_point(&mut self){
        self.repair(self.x_star_index);
    }

    pub fn sample_next_k(&mut self) -> Result<(),Error>{
        // Now we run k-means++ to sample the next k points (total k+1 points)
        // first we uniformly sample the first point and repair:
        let uniform_sampled_index = Index(self.rng.random_range(0..self.number_of_data_points));
        self.repair(uniform_sampled_index);
        // Now we sample the next k-1 points and repair:

        for i in 0..self.num_clusters-1{
            let mut maybe_index = self.sampling_tree.sample(&mut self.rng);

            while let Err(Error::NumericalError) = maybe_index{
                // If we fail to sample, we rebuild the tree and try again:
                println!("Numerical error detected on round {}. Rebuilding tree", i);
                panic!();
                self.sampling_tree.rebuild_from_leaves();
                maybe_index = self.sampling_tree.sample(&mut self.rng);
            }
            let index = maybe_index.unwrap();
            self.repair(index);
        }
        Ok(())
    }

    pub fn sample_rest(&mut self) -> Result<(),Error>{
        // Now we have seeded the sampling distribution, we sample the actual coreset:

        let cost = T::contribution(&self.sampling_tree.storage, 0.into());

        for _ in 0..self.coreset_size{
            let (index,prob) = self.sampling_tree.sample_smoothed(
                &mut self.rng, cost, self.coreset_star_weight)?;
            let weight = self.sampling_tree.storage.get(self.sampling_tree.get_shifted_node_index(index).0).unwrap().weight();
            self.output_indices.push(index.0);
            self.output_weights.push(weight.0/(prob*self.coreset_size as Float));
        }

        // // sort the indices and weights in ascending order of indices
        // let mut combined: Vec<_> = coreset_indices.iter_mut().zip(coreset_weights.iter_mut()).collect();
        // combined.sort_by(|a,b| a.0.cmp(&b.0));

        Ok(())
    }

    pub fn sample(&mut self) -> Result<(),Error> {
        self.sample_first_point();
        self.sample_next_k()?;
        self.sample_rest()
    }

    pub fn sample_and_deduplicate(&mut self) -> Result<(),Error> {
        self.sample()?;
        for (i, index) in self.output_indices.iter().enumerate() {
            if !self.seen.contains_key(index) {
                self.seen.insert(*index, self.unique_indices.len());
                self.unique_indices.push(self.names[*index].clone());
                self.unique_weights.push(self.output_weights[i]);
            } else {
                let j = self.seen[index].clone();
                self.unique_weights[j] += self.output_weights[i];
            }
        }
        Ok(())
    }

    pub fn total_distance_update_warnings(&self) -> usize{
        self.num_warnings
    }
    pub fn total_distance_updates(&self) -> usize{
        self.total_distance_updates
    }


}