// benches/random_ops_bench.rs
use criterion::{criterion_group, criterion_main, Criterion};
use rand::prelude::*;
use dynamic_csc::DynamicCoreset;
use rand::rngs::StdRng;
use std::collections::HashMap;
use dynamic_csc::Float;
use std::time::Instant;
use std::io::Write;
use std::io::BufWriter;

use jemallocator::Jemalloc;

use indicatif::ProgressBar;

/// Route every heap allocation made by this benchmark binary through jemalloc.
#[global_allocator]
static JEMALLOC_ALLOCATOR: Jemalloc = Jemalloc;

/// One edge update in a generated benchmark workload.
#[derive(Debug, Clone, PartialEq)]
pub enum Instruction {
    /// Insert the edge `(u, v)` with the given weight (weights accumulate on repeated inserts).
    Insert(String, String, Float),
    /// Delete the edge `(u, v)` entirely.
    Delete(String, String),
    /// Remove the given amount of weight from the edge `(u, v)`; the generator emits half of
    /// the edge's current weight here.
    DeleteHalf(String, String, Float),
}

/// Writes one `step, error, filtered_error, affinity_shift` line per record to `writer`,
/// then flushes it.
///
/// # Errors
/// Returns any I/O error raised while writing or flushing.
pub fn write_average_distance_errors<W: Write>(
    mut writer: W,
    average_distance_errors: &[(usize, Float, Float, Float)],
) -> std::io::Result<()> {
    for (i, error, filtered_error, affinity_shift) in average_distance_errors {
        writeln!(writer, "{}, {}, {}, {}", i, error, filtered_error, affinity_shift)?;
    }
    // Flush explicitly: a `BufWriter` flushed only on drop silently discards write errors.
    writer.flush()
}

/// Saves the recorded error metrics to `average_distance_errors.txt` in the current
/// working directory, one comma-separated line per record.
///
/// Each record is `(step, average distance error, filtered average distance error,
/// affinity shift)`.
///
/// # Panics
/// Panics if the file cannot be created, written or flushed.
pub fn save_average_distance_errors(
    average_distance_errors: Vec<(
        usize,
        Float, // average distance error
        Float, // filtered average distance error
        Float, // affinity shift
    )>,
) {
    let file = std::fs::File::create("average_distance_errors.txt")
        .expect("failed to create average_distance_errors.txt");
    write_average_distance_errors(BufWriter::new(file), &average_distance_errors)
        .expect("failed to write average_distance_errors.txt");
}

/// Pool of live undirected edges with O(1) insertion, removal and uniform random choice.
///
/// Each undirected edge `{u, v}` is stored once, under a canonical key whose lexicographically
/// smaller endpoint comes first. Random sampling indexes the dense `edges` vector. Iterating a
/// std `HashMap` instead would follow its randomly seeded hasher, so the sampled sequence would
/// differ between runs. Here the sequence is fully determined by the RNG seed.
struct EdgePool {
    /// Canonical edge key -> position of that edge in `edges`.
    positions: HashMap<(String, String), usize>,
    /// Dense `(canonical key, accumulated weight)` entries.
    edges: Vec<((String, String), Float)>,
}

impl EdgePool {
    /// Creates an empty pool with room for `capacity` edges.
    fn with_capacity(capacity: usize) -> Self {
        EdgePool {
            positions: HashMap::with_capacity(capacity),
            edges: Vec::with_capacity(capacity),
        }
    }

    /// Number of live edges.
    fn len(&self) -> usize {
        self.edges.len()
    }

    /// Whether there are no live edges.
    fn is_empty(&self) -> bool {
        self.edges.is_empty()
    }

    /// Adds `w` to the weight of `{u, v}`, creating the edge if it is not present yet.
    fn add(&mut self, u: &str, v: &str, w: Float) {
        let key = if u <= v {
            (u.to_owned(), v.to_owned())
        } else {
            (v.to_owned(), u.to_owned())
        };
        if let Some(&pos) = self.positions.get(&key) {
            self.edges[pos].1 += w;
        } else {
            self.positions.insert(key.clone(), self.edges.len());
            self.edges.push((key, w));
        }
    }

    /// Removes the edge stored at `pos` in O(1) by swapping the last entry into its slot.
    fn remove_at(&mut self, pos: usize) {
        let (key, _) = self.edges.swap_remove(pos);
        self.positions.remove(&key);
        // If another edge was moved into `pos`, record its new position.
        if let Some((moved_key, _)) = self.edges.get(pos) {
            if let Some(slot) = self.positions.get_mut(moved_key) {
                *slot = pos;
            }
        }
    }
}

/// Generates a reproducible stream of random edge updates over `num_nodes` nodes.
///
/// Nodes are named `Node0`, `Node1`, … Each of the `num_updates` steps does one of two things:
/// * With probability `insert_prob`, it picks two random nodes. If they coincide, the step is
///   skipped (no self loops). Otherwise it emits an [`Instruction::Insert`] with a weight drawn
///   from `[1, 10)`. Re-inserting an existing edge accumulates its weight.
/// * Otherwise, if at least one edge is live, it picks a live edge uniformly at random. With
///   probability `full_delete_prob` it emits an [`Instruction::Delete`] for that edge. Failing
///   that, it emits an [`Instruction::DeleteHalf`] carrying half of the edge's current weight.
///
/// The output depends only on the arguments: the same `seed` always yields the same sequence.
/// Fewer than `num_updates` instructions may be returned, because skipped self loops and
/// deletions attempted while no edge is live emit nothing.
pub fn generate_commands(seed: u64, num_nodes: usize, num_updates: usize, insert_prob: Float, full_delete_prob: f64) -> Vec<Instruction> {
    // With fewer than two nodes every candidate edge would be a self loop (and with zero nodes
    // sampling an endpoint would panic on an empty range), so there is nothing to generate.
    if num_nodes < 2 {
        return Vec::new();
    }

    let mut rng = StdRng::seed_from_u64(seed);

    let nodes: Vec<String> = (0..num_nodes).map(|i| format!("Node{}", i)).collect();

    // Live edges and their current weights.
    let mut live_edges = EdgePool::with_capacity(num_updates);

    // The operations we are going to perform.
    let mut operations: Vec<Instruction> = Vec::with_capacity(num_updates);

    let pb = ProgressBar::new(num_updates as u64);
    pb.set_style(
        indicatif::ProgressStyle::default_bar()
            .template("{spinner:.green} {bar:.green/yellow} {decimal_bytes_per_sec} {eta} [{elapsed_precise}] ")
            .unwrap(),
    );

    for _ in 0..num_updates {
        pb.inc(1);
        if rng.random_range(0.0..1.0) < insert_prob {
            // Insert an edge.
            let ui = rng.random_range(0..num_nodes);
            let vi = rng.random_range(0..num_nodes);
            if ui == vi {
                // Don't insert self loops.
                continue;
            }
            let w: Float = rng.random_range(1.0..10.0);
            let (u, v) = (&nodes[ui], &nodes[vi]);
            live_edges.add(u, v, w);
            operations.push(Instruction::Insert(u.clone(), v.clone(), w));
        } else if !live_edges.is_empty() {
            // Delete from a uniformly chosen live edge.
            let pos = rng.random_range(0..live_edges.len());
            let ((u, v), w) = live_edges.edges[pos].clone();
            if rng.random_range(0.0..1.0) < full_delete_prob {
                // With probability `full_delete_prob`, delete the entire edge.
                operations.push(Instruction::Delete(u, v));
                live_edges.remove_at(pos);
            } else {
                // Otherwise remove half of its current weight.
                let half_w: Float = w / 2.0;
                live_edges.edges[pos].1 -= half_w;
                operations.push(Instruction::DeleteHalf(u, v, half_w));
            }
        }
    }
    pb.finish_with_message("Done generating instructions");
    operations
}


// Benchmarks for the `dynamic_csc` crate (`DynamicCoreset` and `Float` are imported above).

/// Criterion benchmark that replays a fixed, seeded stream of random edge updates against a
/// freshly constructed [`DynamicCoreset`].
///
/// The instruction stream is generated once, outside the timed closure. Each timed iteration:
/// * applies every instruction (insert / full delete / half delete);
/// * records `(step, average distance error, filtered error, affinity shift)` whenever
///   `get_average_distance_error` returns a value;
/// * every `LABEL_EVERY` steps, extracts the coreset graph, draws a seeded random labelling
///   for it, and labels the full graph from that labelling;
/// * prints summary statistics and writes the recorded errors to
///   `average_distance_errors.txt`.
///
/// The printing, progress bars and file output run inside `b.iter`, so they are part of what
/// Criterion measures.
///
/// # Panics
/// Panics in any of these cases:
/// * a coreset update or coreset-graph extraction returns an error;
/// * `coreset_tree` is empty;
/// * a root-coreset index has no entry in the adjacency list.
fn bench_random_ops_seeded(c: &mut Criterion) {
    println!("Running seeded benchmark");
    let seed = 501;

    // Workload: a random stream of edge updates over `num_nodes` nodes.
    let num_nodes = 10_000;
    let num_updates = 500_000;
    let insert_prob = 0.8;
    let full_delete_prob = 0.25;

    println!("Generating instructions");
    let commands = generate_commands(seed, num_nodes, num_updates, insert_prob, full_delete_prob);

    // Configuration passed to `DynamicCoreset::new`.
    let coreset_size = 2048;
    let num_clusters = 50;
    let affinity_shift = 0.1;
    let num_threads = 4;
    let degree_threshold = 4.05;

    // Refreshing the progress-bar message formats a `String`. Doing it only periodically keeps
    // that allocation out of most of the measured loop.
    const MESSAGE_EVERY: usize = 1_000;
    // Interval (in processed instructions) between full-graph labelling passes.
    const LABEL_EVERY: usize = 10_000;

    c.bench_function("random_ops_seeded", |b| {
        b.iter(|| {
            let start = Instant::now();
            let mut dynamic_coreset = DynamicCoreset::new(
                coreset_size,
                num_clusters,
                affinity_shift,
                degree_threshold,
                num_threads,
                100_000,
            );
            let mut average_distance_errors = Vec::<(usize, Float, Float, Float)>::new();

            let pb = ProgressBar::new(commands.len() as u64);
            pb.set_style(
                indicatif::ProgressStyle::default_bar()
                    .template("{spinner:.green} {bar:.green/yellow} {decimal_bytes_per_sec} {eta} [{elapsed_precise}] {msg} ")
                    .unwrap(),
            );

            for (i, command) in commands.iter().enumerate() {
                match command {
                    Instruction::Insert(u, v, w) => {
                        dynamic_coreset.rust_insert_edge(u, v, *w).unwrap();
                    }
                    Instruction::Delete(u, v) => {
                        dynamic_coreset.rust_delete_entire_edge(u, v).unwrap();
                    }
                    Instruction::DeleteHalf(u, v, w) => {
                        // Remove the given amount of weight from the edge.
                        dynamic_coreset.rust_delete_edge(u, v, *w).unwrap();
                    }
                }

                if i % MESSAGE_EVERY == 0 {
                    pb.set_message(format!("shift: {:.3}", dynamic_coreset.affinity_shift));
                }

                // Record the error metrics whenever the coreset reports them.
                let average_dist_error_maybe = dynamic_coreset.get_average_distance_error();
                if let Some(avg_dist_error) = average_dist_error_maybe {
                    let filtered_avg_dist_error = dynamic_coreset.get_filtered_average_distance_error();
                    let current_shift = dynamic_coreset.affinity_shift;
                    average_distance_errors.push((i, avg_dist_error, filtered_avg_dist_error, current_shift));
                }
                pb.inc(1);

                if i > 0 && i % LABEL_EVERY == 0 {
                    let coreset_graph = dynamic_coreset.rust_extract_coreset_graph().unwrap();

                    // Seeded random labelling of the coreset nodes. Re-seeding each time gives the
                    // same label sequence at every checkpoint.
                    let mut rng = StdRng::seed_from_u64(seed);
                    let coreset_labels = (0..coreset_graph.shape().0)
                        .map(|_| rng.random_range(0..num_clusters))
                        .collect::<Vec<_>>();
                    let coreset_indices = &dynamic_coreset.coreset_tree[0].indices;
                    let coreset_weights = &dynamic_coreset.coreset_tree[0].weights;

                    let full_labels_and_dists = dynamic_coreset.rust_label_full_graph(
                        coreset_indices,
                        coreset_weights,
                        &coreset_labels,
                        num_clusters,
                    );
                    // The result is otherwise unused; keep the optimiser from discarding the work
                    // being benchmarked.
                    std::hint::black_box(full_labels_and_dists);
                }
            }
            pb.finish_with_message("Done processing instructions");
            println!("Elapsed time: {:?}", start.elapsed());
            println!("Number of coresets computed: {:?}", dynamic_coreset.num_coresets_computed);

            let root_coreset = &dynamic_coreset.coreset_tree[0];
            let indices = &root_coreset.indices;
            let weights = &root_coreset.weights;
            // Print at most the first 10 indices and weights. A fixed `[0..10]` slice would panic
            // on a smaller coreset.
            println!("Indices: {:?}", &indices[..indices.len().min(10)]);
            println!("Weights: {:?}", &weights[..weights.len().min(10)]);

            // Total number of entries in the full adjacency list at the end.
            let total_edges = dynamic_coreset.adjacency.iter().map(|(_, v)| v.len()).sum::<usize>();
            println!("Total number of edges in the full adjacency list at the end: {:?}", total_edges);

            // For each root-coreset index, count its neighbours that are also root-coreset indices.
            let coreset_indices_set = indices.iter().cloned().collect::<std::collections::HashSet<_>>();
            let total_edges_of_coreset_indices = indices
                .iter()
                .map(|i| {
                    dynamic_coreset
                        .adjacency
                        .get(i)
                        .expect("root coreset index missing from the adjacency list")
                        .iter()
                        .filter(|(name, _)| coreset_indices_set.contains(name))
                        .count()
                })
                .sum::<usize>();
            println!("Total number of edges in the coreset graph at the end: {:?}", total_edges_of_coreset_indices);

            // Save the average distance errors to a file.
            save_average_distance_errors(average_distance_errors);
        })
    });
}

// The Criterion macros: define the group of benchmarks and the main entry point.
//
// Each measured iteration replays the entire generated instruction stream, including the
// periodic full-graph labelling and file output. Criterion's default of 100 samples would
// therefore take a very long time. 10 is the smallest sample size Criterion accepts.
criterion_group! {
    name = benches;
    config = Criterion::default().sample_size(10);
    targets = bench_random_ops_seeded
}
criterion_main!(benches);
