use std::collections::HashSet;

use faer::sparse::*;
use faer::prelude::*;
use rand::Rng;
use crate::common::CoresetCrossTerm;
use crate::common::DatapointWithCoresetCrossTerm;
use crate::common::Error;
use crate::common::SelfAffinity;
use crate::common::Weight;
use crate::incident_tree::IncidentTree;
use crate::non_incident_tree::NonIncidentTree;
use crate::common::{Index, Float};

use crate::incident_tree::IncidentLeafNode;
use crate::non_incident_tree::NonIncidentLeafNode;

/// A borrowed leaf node, tagged with the tree it was taken from.
///
/// The sampling routines of [`SamplingTree`] return this so callers can tell
/// whether a sampled point lives in the incident or the non-incident tree.
#[derive(Debug)]
pub enum PointInTrees<'a>{
    /// A leaf borrowed from the incident tree.
    Incident(&'a IncidentLeafNode),
    /// A leaf borrowed from the non-incident tree.
    NonIncident(&'a NonIncidentLeafNode),
}

impl PointInTrees<'_>{
    pub fn index(&self)-> Index{
        match self{
            PointInTrees::Incident(leaf) => leaf.index(),
            PointInTrees::NonIncident(leaf) => leaf.index(),
        }
    }

    pub fn weight(&self)-> Weight{
        match self{
            PointInTrees::Incident(leaf) => leaf.weight(),
            PointInTrees::NonIncident(leaf) => leaf.weight(),
        }
    }
}

/// Two-tree sampling structure used to seed and then draw a coreset.
///
/// Every data point lives in exactly one of the two trees: it starts in the
/// non-incident tree and is moved to the incident tree once it is selected
/// during seeding or is a kernel-matrix neighbour of a selected point.
#[allow(unused)]
pub struct SamplingTree<'a>{
    // Points that have been selected or neighbour a selected point.
    incident_tree: IncidentTree,
    // Points not yet moved to the incident tree.
    non_incident_tree: NonIncidentTree,
    // Indices currently stored in `incident_tree` (kept in sync on insert/delete).
    nodes_in_incident_tree: HashSet<Index>,
    // Indices currently stored in `non_incident_tree` (kept in sync on delete).
    nodes_in_non_incident_tree: HashSet<Index>,
    // Minimum self-affinity over the points selected so far during seeding.
    smallest_coreset_self_affinity: Float,
    // Number of seed points drawn by `seed_sample_first_k`.
    number_of_clusters: usize,
    // Running sum of the weights of the seed points selected so far.
    coreset_initialization_set_weight: Weight,
    // Number of samples drawn by `construct_coreset`.
    coreset_size: usize,
    // Random source for all sampling decisions.
    rng: rand::rngs::ThreadRng,
    // Dimension of the (square) kernel matrix.
    number_of_data_points: usize,
    // Sparse kernel matrix; column `j` lists the neighbours of point `j`.
    kernel_matrix: SparseColMatRef<'a, usize, Float>,
    // Column vector of per-point weights (one row per data point).
    weights: MatRef::<'a, Float>,
}


impl <'a> SamplingTree<'a>{
    pub fn initialize(number_of_clusters: usize, coreset_size: usize, kernel_matrix: SparseColMatRef<'a, usize, Float>,weights: MatRef::<'a, Float>)-> Self{
        // assert the kernel matrix is square and the weights are of the same size.
        // also assert that the weights are a vector.
        let dim = kernel_matrix.nrows();
        assert_eq!(dim, kernel_matrix.ncols());
        assert_eq!(dim, weights.nrows());
        assert_eq!(1, weights.ncols());
        
        assert!(coreset_size < dim);
        assert!(number_of_clusters < coreset_size);

        let number_of_data_points = dim;
    
        // We need to extract the self-affinities of the datapoints and their weights.
        // get a view of the diagonal of the kernel matrix
        let non_incident_tree = NonIncidentTree::from_kernel_and_weights(kernel_matrix, weights).unwrap();

        let incident_tree = IncidentTree::new();

        let smallest_coreset_self_affinity = 0.0;
        let number_of_clusters = number_of_clusters;
        
        // initially, all the nodes are in the non-incident tree:
        let nodes_in_incident_tree = HashSet::new();
        let nodes_in_non_incident_tree = (0..number_of_data_points).map(|i|i.into()).collect::<HashSet<Index>>();

        SamplingTree{
            incident_tree,
            non_incident_tree,
            nodes_in_incident_tree,
            nodes_in_non_incident_tree,
            smallest_coreset_self_affinity,
            number_of_clusters,
            coreset_initialization_set_weight: Weight(0.0),
            coreset_size,
            rng: rand::thread_rng(),
            number_of_data_points,
            kernel_matrix,
            weights,
        }
    }

    #[allow(unused)]
    pub fn with_rng(mut self, rng: rand::rngs::ThreadRng)-> Self{
        self.rng = rng;
        self
    }

    /// Returns the current contribution of the incident tree, as reported by
    /// the tree itself.
    pub fn incident_contribution(&self)-> Float{
        self.incident_tree.contribution()
    }
    /// Returns the current contribution of the non-incident tree, evaluated
    /// with the smallest self-affinity among the points selected so far.
    pub fn non_incident_contribution(&self)-> Float{
        self.non_incident_tree.contribution(self.smallest_coreset_self_affinity)
    }

    /// Returns the incident tree's smoothed contribution.
    ///
    /// The tree is given the combined (incident + non-incident) contribution
    /// and the accumulated weight of the coreset initialization set.
    pub fn incident_smooth_contribution(&self) -> Float {
        let seeded_weight = self.coreset_initialization_set_weight.0;
        let total_cost = self.incident_contribution() + self.non_incident_contribution();
        self.incident_tree.smoothed_contribution(total_cost, seeded_weight)
    }

    /// Returns the non-incident tree's smoothed contribution.
    ///
    /// The tree is given the smallest selected self-affinity, the combined
    /// (incident + non-incident) contribution and the accumulated weight of
    /// the coreset initialization set.
    pub fn non_incident_smooth_contribution(&self) -> Float {
        let seeded_weight = self.coreset_initialization_set_weight.0;
        let total_cost = self.incident_contribution() + self.non_incident_contribution();
        self.non_incident_tree.smoothed_contribution(
            self.smallest_coreset_self_affinity,
            total_cost,
            seeded_weight,
        )
    }

    /// Samples one leaf, first choosing a tree with probability proportional
    /// to its contribution and then delegating to that tree's `sample_node`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the chosen tree's `sample_node`.
    ///
    /// # Panics
    ///
    /// When the total contribution is not positive a diagnostic line is
    /// printed, after which `gen_range` is still called on the empty range
    /// `0.0..total` and panics.
    pub fn sample(&mut self) -> Result<PointInTrees,Error>{
        let from_incident = self.incident_contribution();
        let from_non_incident = self.non_incident_contribution();
        let total = from_incident + from_non_incident;

        if total <= 0.0 {
            println!("Incident contribution: {:?}, Non-Incident contribution: {:?}, Total contribution: {:?}", from_incident, from_non_incident, total);
        }

        // A draw below the incident mass selects the incident tree; anything
        // else selects the non-incident tree.
        let draw = self.rng.gen_range(0.0..total);
        if draw < from_incident {
            self.incident_tree
                .sample_node(&mut self.rng)
                .map(|leaf| PointInTrees::Incident(leaf))
        } else {
            self.non_incident_tree
                .sample_node(self.smallest_coreset_self_affinity, &mut self.rng)
                .map(|leaf| PointInTrees::NonIncident(leaf))
        }
    }

    /// Samples one leaf according to the smoothed contributions and returns
    /// it together with the probability of that draw.
    ///
    /// The probability is the chance of picking the tree (its share of the
    /// total smoothed contribution) multiplied by the conditional probability
    /// reported by that tree's `sample_node_smoothed`.
    ///
    /// # Panics
    ///
    /// Panics if the chosen tree's `sample_node_smoothed` fails, or if the
    /// total smoothed contribution yields an empty range for `gen_range`.
    pub fn sample_smooth_with_probs(&mut self) -> Result<(PointInTrees,Float),Error>{
        let incident_mass = self.incident_smooth_contribution();
        let non_incident_mass = self.non_incident_smooth_contribution();
        let total_mass = incident_mass + non_incident_mass;
        let seeded_weight = self.coreset_initialization_set_weight.0;

        // NOTE(review): the smoothed-contribution helpers hand the trees the
        // *un-smoothed* total cost, while `sample_node_smoothed` below is
        // handed the *smoothed* total. Confirm against the tree
        // implementations that both are meant to receive the same quantity.
        let draw = self.rng.gen_range(0.0..total_mass);
        if draw < incident_mass {
            let (leaf, conditional_prob) = self
                .incident_tree
                .sample_node_smoothed(total_mass, seeded_weight, &mut self.rng)
                .unwrap();
            let tree_prob = incident_mass / total_mass;
            Ok((PointInTrees::Incident(leaf), tree_prob * conditional_prob))
        } else {
            let (leaf, conditional_prob) = self
                .non_incident_tree
                .sample_node_smoothed(
                    self.smallest_coreset_self_affinity,
                    total_mass,
                    seeded_weight,
                    &mut self.rng,
                )
                .unwrap();
            let tree_prob = non_incident_mass / total_mass;
            Ok((PointInTrees::NonIncident(leaf), tree_prob * conditional_prob))
        }
    }

    /// Forwards a new coreset cross term for the node `index` to the incident
    /// tree.
    ///
    /// # Errors
    ///
    /// Returns whatever error the incident tree reports for the update.
    pub fn update_incident_tree(&mut self, index: Index, coreset_cross_term: CoresetCrossTerm) -> Result<(), Error>{
        self.incident_tree.update_node_coreset_cross_term(index, coreset_cross_term)
    }



    /// Removes the node `index` from the non-incident tree and from the
    /// matching index set, returning the weight and self-affinity the tree
    /// reported for it.
    ///
    /// # Errors
    ///
    /// Propagates the tree's deletion error; the index set is left untouched
    /// in that case.
    pub fn delete_from_non_incident_tree(&mut self, index: Index) -> Result<(Weight, SelfAffinity), Error>{
        let removed = self.non_incident_tree.delete_node(index)?;
        self.nodes_in_non_incident_tree.remove(&index);
        Ok(removed)
    }

    /// Removes the node `index` from the incident tree and from the matching
    /// index set, returning the value the tree reported for the deletion.
    ///
    /// # Errors
    ///
    /// Propagates the tree's deletion error; the index set is left untouched
    /// in that case.
    #[allow(unused)]
    pub fn delete_from_incident_tree(&mut self, index: Index) -> Result<Float, Error>{
        let removed_contribution = self.incident_tree.delete_node(index)?;
        self.nodes_in_incident_tree.remove(&index);
        Ok(removed_contribution)
    }

    /// Inserts a data point into the incident tree under `index` and records
    /// the index in the incident index set.
    ///
    /// # Errors
    ///
    /// Propagates the tree's insertion error; the index set is left untouched
    /// in that case.
    pub fn insert_into_incident_tree(&mut self, index: Index, weight: Weight, self_affinity: SelfAffinity, coreset_cross_term: CoresetCrossTerm) -> Result<(), Error>{
        let datapoint = DatapointWithCoresetCrossTerm {
            weight,
            self_affinity,
            coreset_cross_term,
        };
        self.incident_tree.insert_node(datapoint, index)?;
        self.nodes_in_incident_tree.insert(index);
        Ok(())
    }



    pub fn sample_first_point_and_update(&mut self){
        // first point is sampled uniformly at random, added to the coreset and removed from the non-incident tree.
        let first_point: Index = self.rng.gen_range(0..self.number_of_data_points).into();

        // get the self-affinity of the first point. This is the current smallest coreset self-affinity.
        self.smallest_coreset_self_affinity = *self.kernel_matrix.get(first_point.0, first_point.0).unwrap();

        // Now we need to find it's neighbouring indices and their cross terms in the kernel matrix.
        let neighbour_indices = self.kernel_matrix.row_indices_of_col(first_point.0).map(|i|Index(i));
        let neighbour_cross_terms = self.kernel_matrix.values_of_col(first_point.0)
            .iter().map(|&cross_term|CoresetCrossTerm(self.smallest_coreset_self_affinity - 2.0*cross_term));
        
        let neighbours_with_cross_terms = neighbour_indices.zip(neighbour_cross_terms)
            .filter(|(i,_)| i!=&first_point).collect::<Vec<_>>();
        // Now we have the first point, it's self-affinity and neighbours, we can update the trees as follows:
        // 1. remove the first point and it's neighbours from the non-incident tree.
        // 2. add the first point to the coreset initialization set.
        // 3. add it's neighbours to the incident tree.

        // Remove the first point from the non-incident tree and add it to the incident tree, implicitly adding it to the coreset initialization set.
        let (weight,self_affinity) = self.delete_from_non_incident_tree(first_point).unwrap();
        let coreset_cross_term = CoresetCrossTerm(-self_affinity.0);
        let _ = self.insert_into_incident_tree(first_point, weight, self_affinity, coreset_cross_term).unwrap();
        // update the coreset initialization set weight.
        self.coreset_initialization_set_weight = weight;


        // Remove the neighbours from the non-incident tree and add them to the incident tree.
        for (neighbour_index, coreset_cross_term) in neighbours_with_cross_terms{
            let (weight, self_affinity) = self.delete_from_non_incident_tree(neighbour_index).unwrap();
            let _ = self.insert_into_incident_tree(neighbour_index, weight, self_affinity, coreset_cross_term).unwrap();
        }
    }


    /// Selects `number_of_clusters` seed points and updates both trees after
    /// each selection.
    ///
    /// The first seed is drawn by `sample_first_point_and_update`. Each
    /// remaining seed is drawn with `sample`; it is then given a cross term of
    /// minus its self-affinity (moving it to the incident tree first if it was
    /// in the non-incident tree), its weight is added to the coreset
    /// initialization set weight, and its kernel-matrix neighbours are moved
    /// to, or updated in, the incident tree.
    ///
    /// # Panics
    ///
    /// Panics if sampling fails, if a tree update, deletion or insertion
    /// fails, or if a required kernel-matrix entry is not stored.
    pub fn seed_sample_first_k(&mut self){
        // Sample the first point and update the trees and sets: add it to the
        // coreset initialization set, remove it from the non-incident tree and
        // move its neighbours to the incident tree.
        self.sample_first_point_and_update();
        
        // Sample the next k-1 points and update the trees and sets based on which tree each point is sampled from.
        for _ in 1..self.number_of_clusters{
            let sampled_point_in_trees = self.sample().unwrap();
            let (sampled_index, sampled_self_affinity) = match sampled_point_in_trees{
                PointInTrees::Incident(leaf) => {
                    // Already in the incident tree: set its coreset cross term
                    // to cancel out the self-affinity.
                    // This implicitly adds the point to the coreset initialization set.
                    let index = leaf.index();
                    let self_affinity = leaf.self_affinity();
                    let weight = leaf.weight();
                    let coreset_cross_term = CoresetCrossTerm(-self_affinity.0);
                    self.update_incident_tree(index, coreset_cross_term).unwrap();
                    
                    // Update the coreset initialization set weight.
                    self.coreset_initialization_set_weight += weight;

                    (index, self_affinity)
                },
                PointInTrees::NonIncident(leaf) => {
                    // Remove the point from the non-incident tree.
                    let index = leaf.index();
                    let self_affinity = leaf.self_affinity();
                    let weight = leaf.weight();
                    let _ = self.delete_from_non_incident_tree(index).unwrap();

                    // Add the point to the incident tree with a cross term
                    // that cancels its self-affinity.
                    let coreset_cross_term = CoresetCrossTerm(-self_affinity.0);
                    assert!( self_affinity.0*weight.0 + weight.0*coreset_cross_term.0 == 0.0, "Cross term does not cancel correctly.");
                    let _ = self.insert_into_incident_tree(index,weight, self_affinity, coreset_cross_term).unwrap();
                    
                    // Update the coreset initialization set weight.
                    self.coreset_initialization_set_weight += weight;

                    (index, self_affinity)
                },
            };

            // Keep the running minimum of the selected points' self-affinities.
            let index_self_affinity = sampled_self_affinity.0;

            self.smallest_coreset_self_affinity = self.smallest_coreset_self_affinity.min(index_self_affinity); 

            // Find the sampled point's neighbours: the stored rows of its
            // column in the kernel matrix.
            let neighbour_indices = self.kernel_matrix.row_indices_of_col(sampled_index.0).map(|i|Index(i))
                .filter(|i| i!=&sampled_index) // remove the sampled point from the neighbours.
                .collect::<Vec<_>>();
            // Split the neighbours by the tree they currently live in; anything
            // not recorded as incident is treated as non-incident.
            let (neighbours_in_incident_tree, neighbours_in_non_incident_tree) : (Vec<Index>,Vec<Index>) = neighbour_indices.iter()
                .partition(|i|self.nodes_in_incident_tree.contains(i));
            // Remove the neighbours from the non-incident tree and add them to the incident tree.
            for neighbour_index in neighbours_in_non_incident_tree{
                let (weight, self_affinity) = self.delete_from_non_incident_tree(neighbour_index).unwrap();
                // These points were not in the incident tree, so their cross term is determined solely by the point just sampled.
                // NOTE(review): the lookup is (sampled, neighbour) while the
                // neighbours came from the sampled point's column — this
                // presumably relies on the kernel matrix being symmetric.
                let coreset_cross_term = CoresetCrossTerm(index_self_affinity - 2.0*self.kernel_matrix.get(sampled_index.0, neighbour_index.0).unwrap());
                let _ = self.insert_into_incident_tree(neighbour_index, weight, self_affinity, coreset_cross_term).unwrap();
            }

            // Update the cross terms of the neighbours already in the incident tree.
            // NOTE(review): the new cross term is passed unconditionally, also
            // for neighbours that are themselves selected points; presumably
            // `update_node_coreset_cross_term` keeps the smaller value — confirm.
            for neighbour_index in neighbours_in_incident_tree{
                let coreset_cross_term: CoresetCrossTerm = CoresetCrossTerm(index_self_affinity - 2.0*self.kernel_matrix.get(sampled_index.0, neighbour_index.0).unwrap());
                let _ = self.update_incident_tree(neighbour_index, coreset_cross_term).unwrap();
            }
        }

    }

    /// Draws `n` points with `sample_smooth_with_probs` and returns their
    /// indices alongside their coreset weights.
    ///
    /// Each coreset weight is the sampled leaf's weight divided by
    /// `prob * n`, where `prob` is the probability with which that leaf was
    /// drawn. The same index may appear more than once, since draws are
    /// independent. With `n == 0` both returned vectors are empty.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `sample_smooth_with_probs`
    /// instead of panicking on it.
    pub fn sample_coreset(&mut self, n: usize) -> Result<(Vec<Index>,Vec<Float>), Error>{
        let mut coreset_indices = Vec::with_capacity(n);
        let mut coreset_weights = Vec::with_capacity(n);

        for _ in 0..n {
            // Propagate sampling failures to the caller; this function
            // already returns a `Result`.
            let (leaf, prob) = self.sample_smooth_with_probs()?;
            let coreset_weight = leaf.weight().0 / (prob * n as Float);
            coreset_indices.push(leaf.index());
            coreset_weights.push(coreset_weight);
        }

        Ok((coreset_indices, coreset_weights))
    }

    pub fn construct_coreset(&mut self,) -> Result<(Vec<Index>,Vec<Float>), Error>{
        self.seed_sample_first_k();
        self.sample_coreset(self.coreset_size)
    }

}





#[cfg(test)]
mod tests{
    use super::*;

    /// Smoke test: seeds two points on a 4-point kernel and draws a coreset
    /// of three samples, checking that one index and one weight is returned
    /// per requested sample.
    #[test]
    fn basic(){
        // Sparse kernel matrix corresponding to the following dense matrix:
        // 1.0      0       0.4     0.8
        // 0        1.0     0.5     0
        // 0.4      0.5     1.0     0
        // 0.8      0       0       1.0
        let kernel_matrix = SparseColMat::<usize, Float>::try_new_from_triplets(
            4,
            4,
            &[
                (0, 0, 1.0),
                (0, 2, 0.4),
                (0, 3, 0.8),
                (1, 1, 1.0),
                (1, 2, 0.5),
                (2, 2, 1.0),
                (2, 0, 0.4),
                (2, 1, 0.5),
                (3, 0, 0.8),
                (3, 3, 1.0),
            ],
        ).unwrap();

        // Unit weights, given as a 4 x 1 column.
        let weights = mat![
            [1.0],
            [1.0],
            [1.0],
            [1.0]
        ];
        // Show the dense form of the kernel matrix in the test output.
        println!("{:?}", &kernel_matrix.to_dense());

        let mut sampling_tree = SamplingTree::initialize(2, 3, kernel_matrix.as_ref(), weights.as_ref());
        sampling_tree.seed_sample_first_k();

        let (coreset_indices, coreset_weights) = sampling_tree.sample_coreset(3).unwrap();
        println!("{:?}", coreset_indices);
        println!("{:?}", coreset_weights);

        // One entry per requested sample in both outputs.
        assert_eq!(coreset_indices.len(), 3);
        assert_eq!(coreset_weights.len(), 3);
    }
}
