use faer::sparse::*;
use faer::prelude::*;
use rand::Rng;
use crate::improved::common::CoresetCrossTerm;
use crate::improved::common::DatapointWithCoresetCrossTerm;
use crate::improved::common::Error;
use crate::improved::common::SelfAffinity;
use crate::improved::common::Weight;
use crate::improved::incident_tree::IncidentTree;
use crate::improved::non_incident_tree::NonIncidentTree;
use crate::improved::common::{Index, Float};

use crate::improved::incident_tree;
use crate::improved::non_incident_tree;



/// Handle to the tree node that currently represents a data point.
///
/// Each variant wraps a raw pointer to a node of the corresponding tree; the
/// pointer is dereferenced by the accessor methods of this type, so it must
/// refer to a live node whenever those accessors are called.
#[derive(Debug)]
#[allow(unused)]
pub enum PointInTrees{
    /// The point is stored in the incident tree.
    Incident(*mut incident_tree::TreeNode),
    /// The point is stored in the non-incident tree.
    NonIncident(*mut non_incident_tree::TreeNode),
}


#[allow(unused)]
impl PointInTrees{
    /// Returns the index reported by the node this handle points to.
    pub fn index(&self)-> Index{
        // SAFETY: dereferences the stored raw pointer. Nothing in this method checks
        // that the node is still alive.
        // NOTE(review): presumably the owning tree keeps the node allocated while the
        // handle exists — confirm.
        match *self{
            PointInTrees::Incident(node) => unsafe{(*node).index()},
            PointInTrees::NonIncident(node) => unsafe{(*node).index()},
        }
    }

    /// Returns the wrapped pointer if this handle refers to the incident tree,
    /// `None` otherwise.
    pub fn as_incident(&self)-> Option<*mut incident_tree::TreeNode>{
        if let PointInTrees::Incident(node) = *self{
            Some(node)
        } else {
            None
        }
    }

    /// Returns the wrapped pointer if this handle refers to the non-incident tree,
    /// `None` otherwise.
    pub fn as_non_incident(&self)-> Option<*mut non_incident_tree::TreeNode>{
        if let PointInTrees::NonIncident(node) = *self{
            Some(node)
        } else {
            None
        }
    }

    /// Returns the weight reported by the node this handle points to.
    pub fn weight(&self)-> Weight{
        // SAFETY: same requirement as in `index` — the stored pointer must refer to a
        // live node; this is not verified here.
        match *self{
            PointInTrees::Incident(node) => unsafe{(*node).weight()},
            PointInTrees::NonIncident(node) => unsafe{(*node).weight()},
        }
    }
}

/// Sampler that seeds and draws a coreset from a sparse kernel matrix.
///
/// Data points live in exactly one of two trees: points that are neighbours of
/// (or members of) the coreset initialization set are kept in the incident tree,
/// all remaining points in the non-incident tree.
#[allow(unused)]
pub struct SamplingTree<'a>{
    /// Tree holding the points moved out of the non-incident tree, together with
    /// their coreset cross terms.
    incident_tree: IncidentTree,
    /// Tree holding the points that have not been moved to the incident tree.
    non_incident_tree: NonIncidentTree,
    /// Pointer map: entry `i` is the handle of data point `i` in whichever tree
    /// currently stores it.
    node_pointers: Vec<PointInTrees>,
    /// Smallest self-affinity among the points sampled into the initialization
    /// set so far (`0.0` before the first point is sampled).
    smallest_coreset_self_affinity: Float,
    /// Number of seed points sampled by `seed_sample_first_k`.
    number_of_clusters: usize,
    /// Accumulated weight of the points sampled into the initialization set.
    coreset_initialization_set_weight: Weight,
    /// Number of samples drawn by `construct_coreset`.
    coreset_size: usize,
    /// Random number generator used for every sampling step.
    rng: rand::rngs::ThreadRng,
    /// Dimension of the (square) kernel matrix.
    number_of_data_points: usize,
    /// Sparse kernel matrix in column-major form.
    kernel_matrix: SparseColMatRef<'a, usize, Float>,
    /// Column vector holding one weight per data point.
    weights: MatRef::<'a, Float>,
}

#[allow(unused)]
impl <'a> SamplingTree<'a>{
    /// Builds a sampling tree in which every data point starts in the non-incident tree.
    ///
    /// * `number_of_clusters` – number of seed points later drawn by `seed_sample_first_k`.
    /// * `coreset_size` – number of samples later drawn by `construct_coreset`.
    /// * `kernel_matrix` – square sparse kernel matrix.
    /// * `weights` – column vector with one weight per data point.
    /// * `graph_form` – when `true`, the self-affinity of point `i` is taken to be
    ///   `1 / weights[i]^2` (K = D^{-1} A D^{-1} with W = D) and the diagonal of
    ///   `kernel_matrix` is not read; when `false`, it is read from the diagonal.
    ///
    /// # Panics
    ///
    /// Panics if the kernel matrix is not square, if `weights` is not a column vector
    /// of matching length, if `coreset_size >= dim`, if
    /// `number_of_clusters > coreset_size`, or (with `graph_form == false`) if a
    /// diagonal entry is not stored in the kernel matrix.
    pub fn initialize(number_of_clusters: usize, coreset_size: usize, kernel_matrix: SparseColMatRef<'a, usize, Float>,weights: MatRef::<'a, Float>, graph_form: bool)-> Self{
        // The kernel matrix must be square and the weights a column vector of the same length.
        let dim = kernel_matrix.nrows();
        assert_eq!(dim, kernel_matrix.ncols());
        assert_eq!(dim, weights.nrows());
        assert_eq!(1, weights.ncols());

        assert!(coreset_size < dim);
        assert!(number_of_clusters <= coreset_size);

        let number_of_data_points = dim;

        // Describe every data point by (index, weight, self-affinity).
        let datapoints = (0..number_of_data_points).map(|i|{
            // SAFETY: `i < number_of_data_points == weights.nrows()` and
            // `weights.ncols() == 1`, both asserted above, so `(i, 0)` is in bounds.
            let weight = unsafe{*weights.get_unchecked(i,0)};
            let self_affinity = if graph_form{
                // If we know K is D^{-1}AD^{-1} and W=D then K[i,i] = 1/weight[i]^2,
                // so the kernel matrix does not need to be consulted.
                1.0/(weight*weight)
            } else {
                *kernel_matrix.get(i,i).expect("kernel matrix must store every diagonal entry")
            };
            (Index(i), Weight(weight), SelfAffinity(self_affinity))
        });

        // Insert all points into the non-incident tree and remember where each one landed.
        let mut non_incident_tree = NonIncidentTree::new();
        let node_pointers = non_incident_tree
            .insert_from_iterator(datapoints, number_of_data_points)
            .into_iter()
            .map(|pointer| PointInTrees::NonIncident(pointer.as_ptr()))
            .collect();

        SamplingTree{
            incident_tree: IncidentTree::new(),
            non_incident_tree,
            node_pointers,
            smallest_coreset_self_affinity: 0.0,
            number_of_clusters,
            coreset_initialization_set_weight: Weight(0.0),
            coreset_size,
            rng: rand::thread_rng(),
            number_of_data_points,
            kernel_matrix,
            weights,
        }
    }

    /// Consumes the sampler and returns it with `rng` installed as its random
    /// number generator.
    #[allow(unused)]
    pub fn with_rng(self, rng: rand::rngs::ThreadRng)-> Self{
        let mut tree = self;
        tree.rng = rng;
        tree
    }

    /// Returns the contribution reported by the incident tree.
    pub fn incident_contribution(&self)-> Float{
        let tree = &self.incident_tree;
        tree.contribution()
    }
    /// Returns the contribution reported by the non-incident tree, evaluated with
    /// the current smallest coreset self-affinity.
    pub fn non_incident_contribution(&self)-> Float{
        let smallest_self_affinity = self.smallest_coreset_self_affinity;
        self.non_incident_tree.contribution(smallest_self_affinity)
    }

    /// Returns the smoothed contribution of the incident tree.
    ///
    /// The tree is given the total (incident + non-incident) contribution and the
    /// accumulated weight of the coreset initialization set.
    pub fn incident_smooth_contribution(&self)-> Float{
        let total_cost = self.incident_contribution() + self.non_incident_contribution();
        let initialization_weight = self.coreset_initialization_set_weight.0;
        self.incident_tree.smoothed_contribution(total_cost, initialization_weight)
    }

    /// Returns the smoothed contribution of the non-incident tree.
    ///
    /// The tree is given the smallest coreset self-affinity, the total
    /// (incident + non-incident) contribution and the accumulated weight of the
    /// coreset initialization set.
    pub fn non_incident_smooth_contribution(&self)-> Float{
        let total_cost = self.incident_contribution() + self.non_incident_contribution();
        let initialization_weight = self.coreset_initialization_set_weight.0;
        self.non_incident_tree.smoothed_contribution(
            self.smallest_coreset_self_affinity,
            total_cost,
            initialization_weight,
        )
    }

    /// Samples one point: first a tree is chosen with probability proportional to
    /// its contribution, then a node is sampled from that tree.
    ///
    /// # Errors
    ///
    /// Forwards the error returned by the chosen tree's `sample_node`.
    ///
    /// # Panics
    ///
    /// `gen_range` panics when the range `0.0..total` is empty, i.e. when the two
    /// contributions do not sum to a positive value.
    pub fn sample(&mut self) -> Result<PointInTrees,Error>{
        let from_incident = self.incident_contribution();
        let from_non_incident = self.non_incident_contribution();
        let total = from_incident + from_non_incident;
        let threshold = self.rng.gen_range(0.0..total);

        // NOTE(review): the sampled node is cast to `*mut` so it can later be handed
        // to the mutating tree methods — confirm the trees permit writes through
        // pointers obtained this way.
        if threshold < from_incident{
            // The draw fell into the incident tree's share.
            let (leaf, _) = self.incident_tree.sample_node(&mut self.rng)?;
            Ok(PointInTrees::Incident(leaf as *const _ as *mut incident_tree::TreeNode))
        } else {
            // Otherwise sample from the non-incident tree.
            let (leaf, _) = self.non_incident_tree
                .sample_node(self.smallest_coreset_self_affinity, &mut self.rng)?;
            Ok(PointInTrees::NonIncident(leaf as *const _ as *mut non_incident_tree::TreeNode))
        }
    }

    /// Samples one point according to the smoothed contributions and returns it
    /// together with the probability assigned to that draw.
    ///
    /// The returned probability is the share of the chosen tree in the total
    /// smoothed contribution multiplied by the conditional probability reported by
    /// the tree's `sample_node_smoothed`.
    ///
    /// # Panics
    ///
    /// Panics if the chosen tree fails to produce a node, and — via `gen_range` —
    /// if the total smoothed contribution does not form a non-empty range.
    pub fn sample_smooth_with_probs(&mut self) -> Result<(PointInTrees,Float),Error>{
        let incident_mass = self.incident_smooth_contribution();
        let non_incident_mass = self.non_incident_smooth_contribution();
        let total_mass = incident_mass + non_incident_mass;
        let threshold = self.rng.gen_range(0.0..total_mass);
        let initialization_weight = self.coreset_initialization_set_weight.0;

        // NOTE(review): `smoothed_contribution` is called with the raw cost
        // (incident + non-incident contribution), whereas `sample_node_smoothed`
        // receives the *smoothed* total in the corresponding argument position —
        // confirm against the tree implementations that this is intended.
        if threshold < incident_mass{
            // The draw fell into the incident tree's share.
            let (leaf, conditional_prob) = self.incident_tree.sample_node_smoothed(
                total_mass,
                initialization_weight,
                &mut self.rng).unwrap();
            let prob = incident_mass/total_mass * conditional_prob;
            Ok((PointInTrees::Incident(leaf as *const _ as *mut incident_tree::TreeNode), prob))
        } else {
            // Otherwise sample from the non-incident tree.
            let (leaf, conditional_prob) = self.non_incident_tree.sample_node_smoothed(
                self.smallest_coreset_self_affinity,
                total_mass,
                initialization_weight,
                &mut self.rng).unwrap();
            let prob = non_incident_mass/total_mass * conditional_prob;
            Ok((PointInTrees::NonIncident(leaf as *const _ as *mut non_incident_tree::TreeNode), prob))
        }
    }

    /// Forwards a new coreset cross term for `node_pointer` to the incident tree.
    pub fn update_incident_tree(&mut self, node_pointer: *mut incident_tree::TreeNode, coreset_cross_term: CoresetCrossTerm){
        let tree = &mut self.incident_tree;
        tree.update_node_coreset_cross_term(node_pointer, coreset_cross_term);
    }

    /// Deletes `node_pointer` from the non-incident tree and returns the weight and
    /// self-affinity the tree reports for the deleted node.
    ///
    /// The result is always `Ok` in the current implementation.
    pub fn delete_from_non_incident_tree(&mut self, node_pointer: *mut non_incident_tree::TreeNode) -> Result<(Weight, SelfAffinity), Error>{
        let removed = self.non_incident_tree.delete_node(node_pointer);
        Ok(removed)
    }


    /// Inserts a data point with the given weight, self-affinity and coreset cross
    /// term into the incident tree and returns the pointer to the new node.
    pub fn insert_into_incident_tree(&mut self, index: Index, weight: Weight, self_affinity: SelfAffinity, coreset_cross_term: CoresetCrossTerm) -> *mut incident_tree::TreeNode{
        let datapoint = DatapointWithCoresetCrossTerm{
            weight,
            self_affinity,
            coreset_cross_term,
        };
        self.incident_tree.insert_node(datapoint, index)
    }



    /// Samples the first seed point uniformly at random and updates both trees.
    ///
    /// The sampled point is moved from the non-incident tree to the incident tree
    /// with a cross term that cancels its self-affinity (implicitly adding it to the
    /// coreset initialization set), and every other row stored in its kernel-matrix
    /// column is moved to the incident tree with cross term
    /// `self_affinity - 2 * K[neighbour, first_point]`.
    ///
    /// The self-affinity used here is the one stored in the tree for the sampled
    /// point, so the kernel-matrix diagonal is not read; this matches how
    /// `seed_sample_first_k` obtains it.
    ///
    /// # Panics
    ///
    /// Panics if the sampled point or one of its neighbours is not currently in the
    /// non-incident tree.
    pub fn sample_first_point_and_update(&mut self){
        // The first point is sampled uniformly at random.
        let first_point: Index = self.rng.gen_range(0..self.number_of_data_points).into();

        // Remove the first point from the non-incident tree. Its stored self-affinity
        // becomes the current smallest coreset self-affinity.
        let first_point_pointer = self.node_pointers[first_point.0]
            .as_non_incident()
            .expect("every point starts in the non-incident tree");
        let (weight, self_affinity) = self.delete_from_non_incident_tree(first_point_pointer).unwrap();
        let first_self_affinity = self_affinity.0;
        self.smallest_coreset_self_affinity = first_self_affinity;

        // Collect the neighbouring indices and their cross terms from the column of
        // the first point, skipping the point itself.
        let neighbour_indices = self.kernel_matrix.row_indices_of_col(first_point.0).map(|i|Index(i));
        let neighbour_cross_terms = self.kernel_matrix.values_of_col(first_point.0)
            .iter().map(|&cross_term|CoresetCrossTerm(first_self_affinity - 2.0*cross_term));
        let neighbours_with_cross_terms = neighbour_indices.zip(neighbour_cross_terms)
            .filter(|(i,_)| i!=&first_point).collect::<Vec<_>>();

        // Add the first point to the incident tree with a cross term that cancels its
        // self-affinity, implicitly adding it to the coreset initialization set.
        let coreset_cross_term = CoresetCrossTerm(-first_self_affinity);
        let new_pointer = self.insert_into_incident_tree(first_point, weight, self_affinity, coreset_cross_term);
        // Update the pointer map.
        self.node_pointers[first_point.0] = PointInTrees::Incident(new_pointer);

        // The initialization set now consists of exactly this point.
        self.coreset_initialization_set_weight = weight;

        // Move the neighbours from the non-incident tree to the incident tree.
        for (neighbour_index, coreset_cross_term) in neighbours_with_cross_terms{
            let neighbour_pointer = self.node_pointers[neighbour_index.0]
                .as_non_incident()
                .expect("neighbours of the first point are still in the non-incident tree");
            let (weight, self_affinity) = self.delete_from_non_incident_tree(neighbour_pointer).unwrap();
            let new_pointer = self.insert_into_incident_tree(neighbour_index, weight, self_affinity, coreset_cross_term);
            // Update the pointer map.
            self.node_pointers[neighbour_index.0] = PointInTrees::Incident(new_pointer);
        }
    }


    pub fn seed_sample_first_k(&mut self){
        // println!("{:?}", "#".repeat(20));
        // println!("Before sampling anything: ");
        // println!("Incident nodes: {:?}", &self.nodes_in_incident_tree);
        // println!("Non-Incident nodes: {:?}", &self.nodes_in_non_incident_tree);
         // sample the first point and update the trees and sets:
        //add it to the coreset init set, remove it from the non-incident tree and move its neighbours to the incident tree.
        self.sample_first_point_and_update();
        // println!("{:?}", "#".repeat(20));
        // println!("Incident nodes: {:?}", &self.nodes_in_incident_tree);
        // println!("Non-Incident nodes: {:?}", &self.nodes_in_non_incident_tree);
        
        // sample the next k-1 points and update the trees and sets based on which tree points are sampled from.
        for _ in 1..self.number_of_clusters{
            let sampled_point_in_trees = self.sample().unwrap();
            let (sampled_index, sampled_self_affinity) = match sampled_point_in_trees{
                PointInTrees::Incident(leaf_tree_node) => {
                    // update it so the contribution becomes zero. That is, update it's coreset cross term
                    // to cancel out the self-affinity.
                    // This implicitly adds the point to the coreset initialization set.
                    match unsafe{&*leaf_tree_node}{
                        incident_tree::TreeNode::Leaf(leaf) => {
                            let index = leaf.index;
                            let self_affinity = leaf.self_affinity;
                            let weight = leaf.weight;
                            let coreset_cross_term = CoresetCrossTerm(-self_affinity.0);
                            self.update_incident_tree(leaf_tree_node, coreset_cross_term);
                            // update the coreset initialization set weight.
                            self.coreset_initialization_set_weight += weight;
                            (index, self_affinity)
                        },
                        _ => panic!("Expected a leaf node."),
                    }
                },
                PointInTrees::NonIncident(leaf_tree_node) => {
                    match unsafe{&*leaf_tree_node}{
                        non_incident_tree::TreeNode::Leaf(leaf) => {
                        // remove the point from the non-incident tree
                        let index = leaf.index;
                        let self_affinity = leaf.self_affinity;
                        let weight = leaf.weight;
                        let _ = self.delete_from_non_incident_tree(leaf_tree_node).unwrap();
                        // add the non-incident point to the incident tree.
                        let coreset_cross_term = CoresetCrossTerm(-self_affinity.0);
                        assert!( self_affinity.0*weight.0 + weight.0*coreset_cross_term.0 == 0.0, "Cross term does not cancel correctly.");
                        let new_leaf_node_pointer = self.insert_into_incident_tree(index,weight, self_affinity, coreset_cross_term);
                        // update the pointer map.
                        self.node_pointers[index.0] = PointInTrees::Incident(new_leaf_node_pointer);
                        // update the coreset initialization set weight.
                        self.coreset_initialization_set_weight += weight;
                        (index, self_affinity)
                        },
                        _ => panic!("Expected a leaf node."),
                    }

                },
            };

            // update the smallest coreset self-affinity.
            let index_self_affinity = sampled_self_affinity.0;

            self.smallest_coreset_self_affinity = self.smallest_coreset_self_affinity.min(index_self_affinity); 

            // Now we need to find it's neighbouring indices and their cross terms in the kernel matrix.
            let neighbour_indices = self.kernel_matrix.row_indices_of_col(sampled_index.0).map(|i|Index(i))
                .filter(|i| i!=&sampled_index) // remove the sampled point from the neighbours.
                .collect::<Vec<_>>();
            let (neighbours_in_incident_tree, neighbours_in_non_incident_tree) : (Vec<Index>,Vec<Index>) = neighbour_indices.iter()
                .partition(|i|{
                    match &self.node_pointers[i.0]{
                        PointInTrees::Incident(_) => true,
                        PointInTrees::NonIncident(_) => false,
                    }
                });
            // Remove the neighbours from the non-incident tree and add them to the incident tree.
            for neighbour_index in neighbours_in_non_incident_tree{
                let neighbour_pointer = self.node_pointers[neighbour_index.0].as_non_incident().unwrap();
                let (weight, self_affinity) = self.delete_from_non_incident_tree(neighbour_pointer).unwrap();
                // Since these points are not in the incident tree, their cross terms correspond to the self-affinity of the point we just sampled.
                let coreset_cross_term = CoresetCrossTerm(index_self_affinity - 2.0*self.kernel_matrix.get(sampled_index.0, neighbour_index.0).unwrap());
                let pointer = self.insert_into_incident_tree(neighbour_index, weight, self_affinity, coreset_cross_term);
                // update the pointer map.
                self.node_pointers[neighbour_index.0] = PointInTrees::Incident(pointer);
            }

            // Update the cross terms of the neighbours in the incident tree.
            for neighbour_index in neighbours_in_incident_tree{
                let coreset_cross_term: CoresetCrossTerm = CoresetCrossTerm(index_self_affinity - 2.0*self.kernel_matrix.get(sampled_index.0, neighbour_index.0).unwrap());
                let pointer = self.node_pointers[neighbour_index.0].as_incident().unwrap();
                self.update_incident_tree(pointer, coreset_cross_term);
            }
            // println!("{:?}", "#".repeat(20));
            // println!("After sampling {:?}", sampled_index);
            // println!("Incident nodes: {:?}", &self.nodes_in_incident_tree);
            // println!("Non-Incident nodes: {:?}", &self.nodes_in_non_incident_tree);
        }

    }

    /// Draws `n` points with `sample_smooth_with_probs` and returns their indices
    /// together with their coreset weights.
    ///
    /// A point with weight `w` drawn with probability `p` receives the coreset
    /// weight `w / (p * n)`. Points are drawn independently, so an index may appear
    /// more than once. For `n == 0` two empty vectors are returned.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by `sample_smooth_with_probs`.
    pub fn sample_coreset(&mut self, n: usize) -> Result<(Vec<Index>,Vec<Float>), Error>{
        let mut coreset_indices = Vec::with_capacity(n);
        let mut coreset_weights = Vec::with_capacity(n);
        let sample_count = n as Float;

        for _ in 0..n{
            // Propagate sampling failures to the caller instead of panicking.
            let (leaf, prob) = self.sample_smooth_with_probs()?;
            coreset_indices.push(leaf.index());
            coreset_weights.push(leaf.weight().0/(prob*sample_count));
        }

        Ok((coreset_indices,coreset_weights))
    }

    pub fn construct_coreset(&mut self,) -> Result<(Vec<Index>,Vec<Float>), Error>{
        self.seed_sample_first_k();
        self.sample_coreset(self.coreset_size)
    }

}





#[cfg(test)]
mod tests{
    use super::*;

    /// Smoke test: seeds two points and draws a coreset of three on a 4x4 kernel,
    /// then checks the shape of the result.
    #[test]
    fn basic(){
        // create a sparse kernel matrix corresponding to the following dense matrix:
        // 1.0      0       0.4     0.8
        // 0        1.0     0.5     0
        // 0.4      0.5     1.0     0
        // 0.8      0       0       1.0
        let kernel_matrix = SparseColMat::<usize, Float>::try_new_from_triplets(
            4,
            4,
            &[
                (0, 0, 1.0),
                (0, 2, 0.4),
                (0, 3, 0.8),
                (1, 1, 1.0),
                (1, 2, 0.5),
                (2, 2, 1.0),
                (2, 0, 0.4),
                (2, 1, 0.5),
                (3, 0, 0.8),
                (3, 3, 1.0),
            ],
        ).unwrap();

        let weights = mat![
            [1.0],
            [3.0],
            [1.0],
            [1.0]
        ];
        // display the kernel matrix nicely
        println!("{:?}", &kernel_matrix.to_dense());

        let mut sampling_tree = SamplingTree::initialize(2, 3, kernel_matrix.as_ref(), weights.as_ref(),false);
        sampling_tree.seed_sample_first_k();

        let (coreset_indices, coreset_weights) = sampling_tree.sample_coreset(3).unwrap();
        println!("{:?}", coreset_indices);
        println!("{:?}", coreset_weights);

        // One index and one weight per requested sample, every index a valid data point.
        assert_eq!(coreset_indices.len(), 3);
        assert_eq!(coreset_weights.len(), 3);
        assert!(coreset_indices.iter().all(|index| index.0 < 4));
    }
}
