use crate::common::{
    ArchivedAdj, ArchivedEdge, ArchivedNode, ArchivedOffsets, ArchivedTableType, Offsets,
    TableInfo,
};
use clap::Parser;
use half::bf16;
use itertools::izip;
use memmap2::Mmap;
use numpy::PyArray1;
use pyo3::IntoPyObjectExt;
use pyo3::PyObject;
use pyo3::PyResult;
use pyo3::Python;
use pyo3::{pyclass, pymethods};
use rand::prelude::*;
use rand::seq::SliceRandom;
use rand::seq::index;
use rkyv::rancor::Error;
use rkyv::vec::ArchivedVec;
use std::collections::{HashMap, HashSet};
use std::env::var;
use std::fs;
use std::io::{BufReader, Read};
use std::str;
use std::time::Instant;

/// Number of `f2p_nbr_idxs` slots reserved per cell; `add_single_cell` asserts that no node
/// has more f2p neighbour indices than this.
const MAX_F2P_NBRS: usize = 5;
/// Maximum keys per query for sparse feature attention: the number of `feat_attn_idx` /
/// `feat_attn_mask` entries reserved per cell.
const MAX_FEAT_ATTN_KEYS: usize = 128;

/// Owned, flattened output buffers for one batch of `batch_size` sequences of `seq_len` cells.
///
/// Unless noted otherwise each buffer holds `batch_size * seq_len` entries (one per cell);
/// `Vecs::new` pre-fills them with padding values.
struct Vecs {
    // ---- per-cell identifiers ----
    /// Node index of each cell (`-1` for padding).
    node_idxs: Vec<i32>,
    table_name_idxs: Vec<i32>,
    col_name_idxs: Vec<i32>,
    /// Class value index of each cell (`-1` by default).
    class_value_idxs: Vec<i32>,
    /// Semantic type code of each cell.
    sem_types: Vec<i32>,
    /// Timestamp of the cell's node (`i32::MIN` when absent or padding).
    timestamps: Vec<i32>,
    /// `MAX_F2P_NBRS` neighbour slots per cell (`-1` when unused).
    f2p_nbr_idxs: Vec<i32>,

    // ---- per-cell values ----
    /// Column-name embedding, `d_text` entries per cell.
    col_name_values: Vec<bf16>,
    /// Text-value embedding, `d_text` entries per cell.
    text_values: Vec<bf16>,
    number_values: Vec<bf16>,
    datetime_values: Vec<bf16>,
    boolean_values: Vec<bf16>,

    // ---- per-cell flags ----
    masks: Vec<bool>,
    is_targets: Vec<bool>,
    is_task_nodes: Vec<bool>,
    /// `true` until a cell is written.
    is_padding: Vec<bool>,

    // ---- sparse feature attention, `MAX_FEAT_ATTN_KEYS` entries per cell ----
    feat_attn_idx: Vec<i32>,
    feat_attn_mask: Vec<bool>,

    /// Number of sequences in the batch filled from fresh (non-wrapped) items.
    true_batch_size: usize,
}

/// Mutable views into one sequence (`seq_len` cells) of a `Vecs` batch, as yielded by
/// `Vecs::chunks_exact_mut`.
struct Slices<'a> {
    // Per-cell identifiers (one entry per cell).
    node_idxs: &'a mut [i32],
    table_name_idxs: &'a mut [i32],
    col_name_idxs: &'a mut [i32],
    class_value_idxs: &'a mut [i32],
    sem_types: &'a mut [i32],
    timestamps: &'a mut [i32],
    /// `MAX_F2P_NBRS` slots per cell.
    f2p_nbr_idxs: &'a mut [i32],

    // Per-cell values; the two embedding views hold `d_text` entries per cell.
    col_name_values: &'a mut [bf16],
    text_values: &'a mut [bf16],
    number_values: &'a mut [bf16],
    datetime_values: &'a mut [bf16],
    boolean_values: &'a mut [bf16],

    // Per-cell flags.
    masks: &'a mut [bool],
    is_targets: &'a mut [bool],
    is_task_nodes: &'a mut [bool],
    is_padding: &'a mut [bool],

    // Sparse feature attention, `MAX_FEAT_ATTN_KEYS` entries per cell.
    feat_attn_idx: &'a mut [i32],
    feat_attn_mask: &'a mut [bool],
}

impl Vecs {
    /// Allocates buffers for `batch_size` sequences of `seq_len` cells, pre-filled with
    /// padding values: `-1` node/neighbour/class indices, zeros, `is_padding = true`,
    /// `i32::MIN` timestamps and `-1` / `false` feature-attention entries.
    ///
    /// `d_text` is the width of each text embedding; `true_batch_size` is stored verbatim.
    fn new(batch_size: usize, seq_len: usize, true_batch_size: usize, d_text: usize) -> Self {
        let l = batch_size * seq_len;
        Self {
            node_idxs: vec![-1; l],
            f2p_nbr_idxs: vec![-1; l * MAX_F2P_NBRS],
            table_name_idxs: vec![0; l],
            col_name_idxs: vec![0; l],
            class_value_idxs: vec![-1; l],
            col_name_values: vec![bf16::ZERO; l * d_text],
            sem_types: vec![0; l],
            number_values: vec![bf16::ZERO; l],
            text_values: vec![bf16::ZERO; l * d_text],
            datetime_values: vec![bf16::ZERO; l],
            boolean_values: vec![bf16::ZERO; l],
            masks: vec![false; l],
            is_targets: vec![false; l],
            is_task_nodes: vec![false; l],
            is_padding: vec![true; l],
            timestamps: vec![i32::MIN; l],
            feat_attn_idx: vec![-1; l * MAX_FEAT_ATTN_KEYS],
            feat_attn_mask: vec![false; l * MAX_FEAT_ATTN_KEYS],
            true_batch_size,
        }
    }

    /// Splits every buffer into per-sequence chunks and yields them zipped together as
    /// `Slices`, one item per sequence, in batch order.
    fn chunks_exact_mut(
        &mut self,
        seq_len: usize,
        d_text: usize,
    ) -> impl Iterator<Item = Slices<'_>> {
        izip!(
            self.node_idxs.chunks_exact_mut(seq_len),
            self.f2p_nbr_idxs.chunks_exact_mut(seq_len * MAX_F2P_NBRS),
            self.table_name_idxs.chunks_exact_mut(seq_len),
            self.col_name_idxs.chunks_exact_mut(seq_len),
            self.class_value_idxs.chunks_exact_mut(seq_len),
            self.col_name_values.chunks_exact_mut(seq_len * d_text),
            self.sem_types.chunks_exact_mut(seq_len),
            self.number_values.chunks_exact_mut(seq_len),
            self.text_values.chunks_exact_mut(seq_len * d_text),
            self.datetime_values.chunks_exact_mut(seq_len),
            self.boolean_values.chunks_exact_mut(seq_len),
            self.masks.chunks_exact_mut(seq_len),
            self.is_targets.chunks_exact_mut(seq_len),
            self.is_task_nodes.chunks_exact_mut(seq_len),
            self.is_padding.chunks_exact_mut(seq_len),
            self.timestamps.chunks_exact_mut(seq_len),
            self.feat_attn_idx.chunks_exact_mut(seq_len * MAX_FEAT_ATTN_KEYS),
            self.feat_attn_mask.chunks_exact_mut(seq_len * MAX_FEAT_ATTN_KEYS),
        )
        .map(
            |(
                node_idxs,
                f2p_nbr_idxs,
                table_name_idxs,
                col_name_idxs,
                class_value_idxs,
                col_name_values,
                sem_types,
                number_values,
                text_values,
                datetime_values,
                boolean_values,
                masks,
                is_targets,
                is_task_nodes,
                is_padding,
                timestamps,
                feat_attn_idx,
                feat_attn_mask,
            )| Slices {
                node_idxs,
                f2p_nbr_idxs,
                table_name_idxs,
                col_name_idxs,
                class_value_idxs,
                col_name_values,
                sem_types,
                number_values,
                text_values,
                datetime_values,
                boolean_values,
                masks,
                is_targets,
                is_task_nodes,
                is_padding,
                timestamps,
                feat_attn_idx,
                feat_attn_mask,
            },
        )
    }

    /// Converts the batch into a list of Python `(name, numpy_array)` tuples, followed by
    /// `("true_batch_size", int)`.
    ///
    /// # Errors
    /// Returns the Python error raised by any conversion instead of panicking.
    fn into_pyobject(self, py: Python<'_>) -> PyResult<Vec<PyObject>> {
        Ok(vec![
            ("node_idxs", PyArray1::from_vec(py, self.node_idxs)).into_py_any(py)?,
            ("f2p_nbr_idxs", PyArray1::from_vec(py, self.f2p_nbr_idxs)).into_py_any(py)?,
            ("table_name_idxs", PyArray1::from_vec(py, self.table_name_idxs)).into_py_any(py)?,
            ("col_name_idxs", PyArray1::from_vec(py, self.col_name_idxs)).into_py_any(py)?,
            ("class_value_idxs", PyArray1::from_vec(py, self.class_value_idxs))
                .into_py_any(py)?,
            ("col_name_values", PyArray1::from_vec(py, self.col_name_values)).into_py_any(py)?,
            ("sem_types", PyArray1::from_vec(py, self.sem_types)).into_py_any(py)?,
            ("number_values", PyArray1::from_vec(py, self.number_values)).into_py_any(py)?,
            ("text_values", PyArray1::from_vec(py, self.text_values)).into_py_any(py)?,
            ("datetime_values", PyArray1::from_vec(py, self.datetime_values)).into_py_any(py)?,
            ("boolean_values", PyArray1::from_vec(py, self.boolean_values)).into_py_any(py)?,
            ("masks", PyArray1::from_vec(py, self.masks)).into_py_any(py)?,
            ("is_targets", PyArray1::from_vec(py, self.is_targets)).into_py_any(py)?,
            ("is_task_nodes", PyArray1::from_vec(py, self.is_task_nodes)).into_py_any(py)?,
            ("is_padding", PyArray1::from_vec(py, self.is_padding)).into_py_any(py)?,
            ("timestamps", PyArray1::from_vec(py, self.timestamps)).into_py_any(py)?,
            ("feat_attn_idx", PyArray1::from_vec(py, self.feat_attn_idx)).into_py_any(py)?,
            ("feat_attn_mask", PyArray1::from_vec(py, self.feat_attn_mask)).into_py_any(py)?,
            ("true_batch_size", self.true_batch_size).into_py_any(py)?,
        ])
    }
}

/// One preprocessed database, loaded by `Sampler::new` from `$HOME/scratch/pre/<db_name>/`.
struct Dataset {
    /// Memory-mapped `nodes.rkyv`.
    mmap: Mmap,
    /// Memory-mapped `p2f_adj.rkyv`.
    p2f_adj_mmap: Mmap,
    /// Memory-mapped `text_emb_<embedding_model>.bin`.
    text_mmap: Mmap,
    /// `offsets` field deserialized from `offsets.rkyv`
    /// (presumably per-node positions inside `mmap` — TODO confirm).
    offsets: Vec<i64>,
    /// Parsed `table_info.json`, keyed by `"table_name:Split"`.
    table_info: HashMap<String, TableInfo>,
}

/// A node that seeds one sequence: a node of one dataset that still has its target column.
struct Item {
    /// Node index inside the dataset.
    node_idx: i32,
    /// Position of the dataset in `Sampler::datasets`.
    dataset_idx: i32,
}

// Python-facing sampler over one or more preprocessed databases. Items are sharded across
// `world_size` ranks and grouped into batches of `batch_size` sequences of `ctx_len` cells.
#[pyclass]
pub struct Sampler {
    // ---- data ----
    datasets: Vec<Dataset>,
    items: Vec<Item>,
    /// Target column index per dataset (parallel to `datasets`).
    target_columns: Vec<i32>,
    /// Columns skipped on the target node and on nodes sharing its timestamp, per dataset.
    columns_to_drop: Vec<Vec<i32>>,

    // ---- batching / sharding ----
    batch_size: usize,
    rank: usize,
    world_size: usize,

    // ---- sequence construction ----
    /// Total context length available (cells per sequence).
    ctx_len: usize,
    /// Maximum cells per BFS collection.
    max_local_ctx_len: usize,
    /// Maximum number of DB nodes per BFS level.
    max_bfs_width: usize,
    /// Width of each text embedding.
    d_text: usize,

    // ---- similar-node discovery ----
    /// If true, use random walks to find similar nodes.
    use_random_walk: bool,
    /// If true, use random sampling to find similar nodes.
    use_random_sampling: bool,
    /// If true, include connecting nodes from random walks.
    use_connecting_nodes: bool,
    /// Number of random walks to perform.
    num_walks: usize,
    /// Maximum length of each random walk.
    walk_length: usize,

    // ---- masking and randomness ----
    /// Probability of masking target-column cells of similar nodes.
    mask_prob: f64,
    epoch: u64,
    seed: u64,
}

#[pymethods]
impl Sampler {
    #[new]
    #[allow(clippy::too_many_arguments)]
    fn new(
        dataset_tuples: Vec<(String, i32, i32)>,
        batch_size: usize,
        rank: usize,
        world_size: usize,
        ctx_len: usize,
        max_local_ctx_len: usize,
        max_bfs_width: usize,
        use_random_walk: bool,
        use_random_sampling: bool,
        use_connecting_nodes: bool,
        num_walks: usize,
        walk_length: usize,
        mask_prob: f64,
        embedding_model: &str,
        d_text: usize,
        seed: u64,
        target_columns: Vec<i32>,
        columns_to_drop: Vec<Vec<i32>>,
    ) -> Self {
        let mut datasets = Vec::new();
        let mut items = Vec::new();
        for (i, (db_name, node_idx_offset, num_nodes)) in dataset_tuples.into_iter().enumerate() {
            let pre_path = format!("{}/scratch/pre/{}", var("HOME").unwrap(), db_name);
            let nodes_path = format!("{}/nodes.rkyv", pre_path);
            let file = fs::File::open(&nodes_path).unwrap();
            let mmap = unsafe { Mmap::map(&file).unwrap() };

            let text_path = format!("{}/text_emb_{}.bin", pre_path, embedding_model);
            let text_file = fs::File::open(&text_path).unwrap();
            let text_mmap = unsafe { Mmap::map(&text_file).unwrap() };

            let offsets_path = format!("{}/offsets.rkyv", pre_path);
            let file = fs::File::open(&offsets_path).unwrap();
            let mut bytes = Vec::new();
            BufReader::new(file).read_to_end(&mut bytes).unwrap();
            let archived = rkyv::access::<ArchivedOffsets, Error>(&bytes).unwrap();
            let offsets = rkyv::deserialize::<Offsets, Error>(archived).unwrap();
            let offsets = offsets.offsets;

            let p2f_adj_path = format!("{}/p2f_adj.rkyv", pre_path);
            let p2f_adj_file = fs::File::open(&p2f_adj_path).unwrap();
            let p2f_adj_mmap = unsafe { Mmap::map(&p2f_adj_file).unwrap() };
            let target = target_columns[i];

            // Load table_info.json
            let table_info_path = format!("{}/table_info.json", pre_path);
            let table_info_file = fs::File::open(&table_info_path).unwrap();
            let table_info: HashMap<String, TableInfo> =
                serde_json::from_reader(BufReader::new(table_info_file)).unwrap();

            datasets.push(Dataset {
                mmap,
                text_mmap,
                p2f_adj_mmap,
                offsets,
                table_info,
            });

            for j in node_idx_offset..node_idx_offset + num_nodes {
                let node = get_node(&datasets[i], j);
                // skip the node if target column was removed during preprocessing
                if node.col_name_idxs.iter().any(|&c| c == target) {
                    items.push(Item {
                        dataset_idx: i as i32,
                        node_idx: j,
                    });
                }
            }
        }

        let epoch = 0;
        Self {
            batch_size,
            rank,
            world_size,
            datasets,
            items,
            ctx_len,
            max_local_ctx_len,
            max_bfs_width,
            use_random_walk,
            use_random_sampling,
            use_connecting_nodes,
            num_walks,
            walk_length,
            mask_prob,
            epoch,
            d_text,
            seed,
            target_columns,
            columns_to_drop,
        }
    }

    fn len_py(&self) -> PyResult<usize> {
        Ok(self.len())
    }

    fn batch_py<'a>(&self, py: Python<'a>, batch_idx: usize) -> PyResult<Vec<PyObject>> {
        self.batch(batch_idx).into_pyobject(py)
    }

    fn shuffle_py(&mut self, epoch: u64) {
        self.epoch = epoch;
        let mut rng = StdRng::seed_from_u64(epoch.wrapping_add(self.seed));
        self.items.shuffle(&mut rng);
    }
}

impl Sampler {
    /// Batches per epoch for each rank: every step consumes `batch_size * world_size` items,
    /// and a final partial step counts as a full one.
    fn len(&self) -> usize {
        let items_per_step = self.batch_size * self.world_size;
        self.items.len().div_ceil(items_per_step)
    }

    /// Builds batch `batch_idx` for this rank.
    ///
    /// Rank `r` reads the `batch_size` items starting at
    /// `batch_idx * batch_size * world_size + r * batch_size`. Slots past the end of `items`
    /// wrap around to the beginning, so every sequence is filled. `true_batch_size` records
    /// how many slots hold fresh items; it is 0 when this rank's window starts past the end.
    /// If there are no items at all, an all-padding batch is returned.
    fn batch(&self, batch_idx: usize) -> Vecs {
        let start = batch_idx * self.batch_size * self.world_size + self.rank * self.batch_size;
        // On the last step a trailing rank's window can begin past the end of `items`;
        // plain subtraction would underflow (panic in debug, bogus size in release).
        let true_batch_size = self.batch_size.min(self.items.len().saturating_sub(start));

        let mut vecs = Vecs::new(self.batch_size, self.ctx_len, true_batch_size, self.d_text);
        if self.items.is_empty() {
            // Nothing to wrap around to; avoid the modulo by zero below.
            return vecs;
        }

        // Fill the sequences one after another.
        for (i, slices) in vecs.chunks_exact_mut(self.ctx_len, self.d_text).enumerate() {
            // Wrap around so slots beyond the fresh items still get a valid (repeated) item.
            let item = &self.items[(start + i) % self.items.len()];
            self.seq(item, slices);
        }
        vecs
    }

    fn seq(&self, item: &Item, mut slices: Slices) {
        let dataset = &self.datasets[item.dataset_idx as usize];
        let target_column = self.target_columns[item.dataset_idx as usize];
        let columns_to_drop = &self.columns_to_drop[item.dataset_idx as usize];

        // Step 1: Get target node (the node we're predicting)
        let target_node_idx = item.node_idx;
        let target_node = get_node(dataset, target_node_idx);

        // Step 2: Find similar nodes (method depends on use_random_walk flag)
        let (mut similar_nodes, connecting_nodes_map) = if self.use_random_walk {
            // Use on-the-fly random walks to find similar nodes AND their connecting paths
            self.find_similar_nodes_via_walks(
                dataset,
                target_node_idx,
                target_node,
                self.num_walks,
                self.walk_length,
            )
        } else if self.use_random_sampling {
            // Use precomputed similar nodes (no connecting paths)
            let similar = self.find_similar_nodes(dataset, target_node_idx, target_node);
            (similar, HashMap::new())
        } else {
            (Vec::new(), HashMap::new())
        };

        // Step 4: Build sequence - target, then similar nodes with their BFS neighbors
        let mut visited = std::collections::HashSet::new();
        // Track depth at which nodes were visited across all BFS calls
        let mut visited_at_depth: HashMap<i32, usize> = HashMap::new();

        // Add target node to similar nodes list
        similar_nodes.insert(0, (target_node_idx, usize::MAX));

        // Create a HashSet to store similar nodes added
        let mut added_similar_nodes: HashSet<i32> = HashSet::new();

        // Collect all nodes first, track if they're from a similar node
        // Format: (node_idx, cell_i, col_idx)
        let mut cells_to_add: Vec<(i32, usize, i32)> = Vec::new();

        for (similar_idx, (similar_node_idx, _)) in similar_nodes.iter().enumerate() {

            added_similar_nodes.insert(*similar_node_idx);

            // Skip target node (already will be included via BFS)
            if *similar_node_idx != target_node_idx && self.use_random_walk && self.use_connecting_nodes {
                // Step 4a: Add connecting nodes if we have them
                if let Some(connecting_nodes) = connecting_nodes_map.get(similar_node_idx) {
                    for &path_node_idx in connecting_nodes {
                        if visited.contains(&path_node_idx) {
                            continue;
                        }
                        visited.insert(path_node_idx);

                        let node = get_node(dataset, path_node_idx);
                        for cell_i in 0..node.col_name_idxs.len() {
                            let col_idx: i32 = node.col_name_idxs[cell_i].into();

                            // Skip columns to drop
                            if (node.node_idx == target_node_idx && columns_to_drop.contains(&col_idx))
                                || (node.timestamp == target_node.timestamp && columns_to_drop.contains(&col_idx))
                            {
                                continue;
                            }

                            cells_to_add.push((path_node_idx, cell_i, col_idx));

                            if cells_to_add.len() == self.ctx_len {
                                break;
                            }
                        }

                        if cells_to_add.len() == self.ctx_len {
                            break;
                        }
                    }

                    if cells_to_add.len() == self.ctx_len {
                        break;
                    }
                }
            }

            // Step 4b: Add BFS neighbors around similar node
            // Create a fresh RNG for each similar node to ensure determinism across different ctx_len
            let mut rng = StdRng::seed_from_u64(
                self.epoch
                    .wrapping_add(target_node_idx as u64)
                    .wrapping_add(*similar_node_idx as u64)
                    .wrapping_add(similar_idx as u64)
                    .wrapping_add(self.seed),
            );
            let similar_bfs_nodes = self.bfs_collect_nodes(dataset, *similar_node_idx, &mut rng, self.max_local_ctx_len, &mut visited_at_depth);

            for bfs_node_idx in similar_bfs_nodes {
                if visited.contains(&bfs_node_idx) {
                    continue;
                }
                visited.insert(bfs_node_idx);

                let node = get_node(dataset, bfs_node_idx);
                for cell_i in 0..node.col_name_idxs.len() {
                    let col_idx: i32 = node.col_name_idxs[cell_i].into();

                    // Skip columns to drop
                    if (node.node_idx == target_node_idx && columns_to_drop.contains(&col_idx))
                        || (node.timestamp == target_node.timestamp && columns_to_drop.contains(&col_idx))
                    {
                        continue;
                    }

                    // Check if this node is in the similar nodes set
                    cells_to_add.push((bfs_node_idx, cell_i, col_idx));

                    if cells_to_add.len() == self.ctx_len {
                        break;
                    }
                }

                if cells_to_add.len() == self.ctx_len {
                    break;
                }
            }

            if cells_to_add.len() == self.ctx_len {
                break;
            }
        }

        // Sort by column index
        cells_to_add.sort_by_key(|&(_, _, col_idx)| col_idx);

        // Add cells to sequence in sorted order
        // Create RNG for add_single_cell (used for masking)
        let mut rng = StdRng::seed_from_u64(
            self.epoch
                .wrapping_add(target_node_idx as u64)
                .wrapping_add(self.seed),
        );
        let mut seq_i = 0;
        for (node_idx, cell_i, _col_idx) in cells_to_add.iter() {
            if seq_i >= self.ctx_len {
                break;
            }

            self.add_single_cell(
                dataset,
                *node_idx,
                *cell_i,
                target_node_idx,
                target_column,
                &mut rng,
                &mut seq_i,
                &mut slices,
                &added_similar_nodes
            );
        }

        // // Convert timestamps to contiguous integer ranks based on temporal ordering
        // // First collect (timestamp, original_index) pairs for non-padding cells
        // let mut timestamp_idx_pairs: Vec<(i32, usize)> = Vec::new();
        // for i in 0..seq_i {
        //     if !slices.is_padding[i] {
        //         timestamp_idx_pairs.push((slices.timestamps[i], i));
        //     }
        // }

        // // Sort by timestamp to establish temporal ordering
        // timestamp_idx_pairs.sort_by_key(|&(ts, _)| ts);

        // // Assign contiguous ranks: cells with same timestamp get same rank
        // let mut rank = 0;
        // let mut prev_timestamp = i32::MIN;
        // for (timestamp, idx) in timestamp_idx_pairs.iter() {
        //     if *timestamp != prev_timestamp && *timestamp != i32::MIN {
        //         if prev_timestamp != i32::MIN {
        //             rank += 1;
        //         }
        //         prev_timestamp = *timestamp;
        //     }
        //     slices.timestamps[*idx] = if *timestamp == i32::MIN { i32::MIN } else { rank };
        // }

        // // Construct feat attention indices
        // self.construct_feat_attn_indices(seq_i, &mut slices);
    }

    /// Returns randomly sampled same-table nodes when `use_random_sampling` is set, and an
    /// empty list otherwise. Random-walk discovery is not handled here; it lives in
    /// `find_similar_nodes_via_walks`.
    fn find_similar_nodes(
        &self,
        dataset: &Dataset,
        target_node_idx: i32,
        target_node: &ArchivedNode,
    ) -> Vec<(i32, usize)> {
        if !self.use_random_sampling {
            // No discovery method enabled here: no similar nodes.
            return Vec::new();
        }
        self.find_similar_nodes_random_sampling(dataset, target_node_idx, target_node)
    }

    /// Runs `num_walks` seeded random walks of at most `max_walk_length` steps from
    /// `source_idx` and records the nodes they reach that are in `source_node`'s table.
    ///
    /// Returns `(similar_nodes, connecting_nodes_map)`:
    /// - `similar_nodes`: every same-table node reached (the source excluded) as
    ///   `(node_idx, visit_count)`, sorted by visit count descending, then node index
    ///   ascending. The list is not truncated.
    /// - `connecting_nodes_map`: for each such node, the set of nodes stepped onto along the
    ///   shortest walk prefix that reached it. The set includes the similar node itself and
    ///   may include the source. On ties, the earlier walk's path is kept. The sets are std
    ///   `HashSet`s, so their iteration order is unspecified.
    fn find_similar_nodes_via_walks(
        &self,
        dataset: &Dataset,
        source_idx: i32,
        source_node: &ArchivedNode,
        num_walks: usize,
        max_walk_length: usize,
    ) -> (Vec<(i32, usize)>, HashMap<i32, HashSet<i32>>) {
        let mut rng = StdRng::seed_from_u64(
            self.seed
                .wrapping_add(self.epoch)
                .wrapping_add(source_idx as u64),
        );

        let mut visit_counts: HashMap<i32, usize> = HashMap::new();
        let mut best_step: HashMap<i32, usize> = HashMap::new();
        let mut connecting_paths: HashMap<i32, HashSet<i32>> = HashMap::new();

        for _walk in 0..num_walks {
            let mut position = source_idx;
            // Nodes stepped onto so far in this walk (the starting source is not recorded).
            let mut trail: Vec<i32> = Vec::new();

            for step in 0..max_walk_length {
                let here = get_node(dataset, position);
                let same_table = here.table_name_idx == source_node.table_name_idx;

                if same_table && position != source_idx {
                    *visit_counts.entry(position).or_insert(0) += 1;

                    // `step` is always below usize::MAX, so a first visit always counts as a
                    // new best; afterwards only a strictly shorter prefix replaces the path.
                    let best = best_step.entry(position).or_insert(usize::MAX);
                    if step < *best {
                        *best = step;
                        connecting_paths.insert(position, trail.iter().copied().collect());
                    }
                }

                let Some(next) =
                    self.select_random_neighbor(dataset, position, source_node, &mut rng)
                else {
                    break; // dead end
                };
                trail.push(next);
                position = next;
            }
        }

        let mut similar_nodes: Vec<(i32, usize)> = visit_counts.into_iter().collect();
        similar_nodes.sort_by(|x, y| y.1.cmp(&x.1).then(x.0.cmp(&y.0)));

        (similar_nodes, connecting_paths)
    }

    /// Picks a random neighbour of `current_idx`, or returns `None`.
    ///
    /// The candidates are:
    /// - all f2p edges of the current node;
    /// - a prefix of its p2f edges. When `target_node` has a timestamp, the prefix is found
    ///   with `partition_point` and covers edges with no timestamp or a timestamp `<=` the
    ///   target's; this relies on the p2f edges being ordered so those edges come first.
    ///
    /// An f2p pick is returned immediately. A p2f pick is returned only if its `table_type`
    /// is `Db` or its `table_name_idx` equals the target's. At most 10 draws are made.
    fn select_random_neighbor(
        &self,
        dataset: &Dataset,
        current_idx: i32,
        target_node: &ArchivedNode,
        rng: &mut StdRng,
    ) -> Option<i32> {
        let node = get_node(dataset, current_idx);
        let p2f_edges = get_p2f_edges(dataset, current_idx);
        let num_f2p = node.f2p_edges.len();

        let num_p2f = if target_node.timestamp.is_some() {
            let target_ts = target_node.timestamp;
            p2f_edges.as_slice().partition_point(|edge| {
                edge.timestamp.is_none() || edge.timestamp <= target_ts
            })
        } else {
            p2f_edges.len()
        };

        let total = num_f2p + num_p2f;
        if total == 0 {
            return None;
        }

        for _attempt in 0..10 {
            let pick = rng.random_range(0..total);
            match pick.checked_sub(num_f2p) {
                // The pick falls inside the f2p edges: always acceptable.
                None => return Some(node.f2p_edges[pick].node_idx.into()),
                Some(p2f_pos) => {
                    let edge = &p2f_edges[p2f_pos];
                    let allowed = edge.table_type == ArchivedTableType::Db
                        || edge.table_name_idx == target_node.table_name_idx;
                    if allowed {
                        return Some(edge.node_idx.into());
                    }
                }
            }
        }

        None
    }

    /// Samples up to 500 distinct nodes from the target node's table, each returned as
    /// `(node_idx, 1)`. The target itself and every `Test` split are excluded.
    ///
    /// The target's table is found via the `table_info` entry (keyed `"table_name:Split"`)
    /// whose node range contains `target_node_idx`. Sampling is seeded from `epoch`,
    /// `target_node_idx` and `seed`, and candidate ranges are visited in sorted order, so the
    /// result is reproducible. Returns an empty list if the target's table cannot be
    /// determined or no candidates remain.
    fn find_similar_nodes_random_sampling(
        &self,
        dataset: &Dataset,
        target_node_idx: i32,
        _target_node: &ArchivedNode,
    ) -> Vec<(i32, usize)> {
        const MAX_SAMPLES: usize = 500;

        let mut rng = StdRng::seed_from_u64(
            self.epoch
                .wrapping_add(target_node_idx as u64)
                .wrapping_add(self.seed),
        );

        // Find which table the target node belongs to (keys are "table_name:Split").
        let Some(target_table) = dataset.table_info.iter().find_map(|(key, info)| {
            let start = info.node_idx_offset;
            let end = start + info.num_nodes;
            let (table, _split) = key.rsplit_once(':')?;
            (start..end).contains(&target_node_idx).then_some(table)
        }) else {
            return Vec::new();
        };

        // Ranges of that table outside the Test split. HashMap iteration order is random per
        // process, so sort to make the index -> node mapping below reproducible.
        let mut node_ranges: Vec<(i32, i32)> = dataset
            .table_info
            .iter()
            .filter_map(|(key, info)| {
                let (table, split) = key.rsplit_once(':')?;
                (table == target_table && split != "Test")
                    .then(|| (info.node_idx_offset, info.node_idx_offset + info.num_nodes))
            })
            .collect();
        node_ranges.sort_unstable();

        let contains_target =
            |&(start, end): &(i32, i32)| (start..end).contains(&target_node_idx);
        // Candidates offered by a range: the range holding the target offers one fewer node.
        let candidates_in = |range: &(i32, i32)| -> usize {
            let len = usize::try_from(range.1 - range.0).unwrap_or(0);
            if contains_target(range) {
                len - 1
            } else {
                len
            }
        };

        let total: usize = node_ranges.iter().map(candidates_in).sum();
        if total == 0 {
            return Vec::new();
        }

        let num_samples = total.min(MAX_SAMPLES);
        index::sample(&mut rng, total, num_samples)
            .iter()
            .filter_map(|mut offset| {
                for range in &node_ranges {
                    let len = candidates_in(range);
                    if offset < len {
                        let mut node_idx = range.0 + offset as i32;
                        // Skip over the target: offsets at or past it shift up by one, which
                        // still stays inside the range because it offers `len - 1` nodes.
                        if contains_target(range) && node_idx >= target_node_idx {
                            node_idx += 1;
                        }
                        return Some((node_idx, 1));
                    }
                    offset -= len;
                }
                None
            })
            .collect()
    }


    /// Writes cell `cell_i` of node `node_idx` into position `*seq_i` of `slices`, then
    /// advances `*seq_i` by one.
    ///
    /// Target flagging (`is_targets` / `masks`) applies only to cells in `target_column`:
    /// - the target node's cell is always flagged;
    /// - cells of nodes in `added_similar_nodes` are flagged with probability `mask_prob`,
    ///   drawn from `rng`.
    ///
    /// # Panics
    /// Panics if the node has more than `MAX_F2P_NBRS` f2p neighbour indices, or if `*seq_i`
    /// is out of range for `slices`.
    #[allow(clippy::too_many_arguments)]
    fn add_single_cell(
        &self,
        dataset: &Dataset,
        node_idx: i32,
        cell_i: usize,
        target_node_idx: i32,
        target_column: i32,
        rng: &mut StdRng,
        seq_i: &mut usize,
        slices: &mut Slices,
        added_similar_nodes: &HashSet<i32>,
    ) {
        let pos = *seq_i;
        let width = self.d_text;
        let node = get_node(dataset, node_idx);
        let this_node: i32 = node.node_idx.into();
        let col: i32 = node.col_name_idxs[cell_i].into();

        // Identity and graph structure.
        slices.node_idxs[pos] = this_node;
        assert!(node.f2p_nbr_idxs.len() <= MAX_F2P_NBRS);
        let nbr_slots = &mut slices.f2p_nbr_idxs[pos * MAX_F2P_NBRS..];
        for (slot, f2p_nbr_idx) in nbr_slots.iter_mut().zip(node.f2p_nbr_idxs.iter()) {
            *slot = f2p_nbr_idx.into();
        }
        slices.table_name_idxs[pos] = node.table_name_idx.into();
        slices.col_name_idxs[pos] = col;
        slices.class_value_idxs[pos] = node.class_value_idx[cell_i].into();
        slices.sem_types[pos] = node.sem_types[cell_i].clone() as i32;

        // Embeddings: column name and text value, `d_text` entries each.
        let emb = pos * width..(pos + 1) * width;
        slices.col_name_values[emb.clone()].copy_from_slice(get_text_emb(dataset, col, width));
        let text_idx: i32 = node.text_values[cell_i].into();
        slices.text_values[emb].copy_from_slice(get_text_emb(dataset, text_idx, width));

        // Scalar values.
        slices.number_values[pos] = bf16::from_f32(node.number_values[cell_i].into());
        slices.datetime_values[pos] = bf16::from_f32(node.datetime_values[cell_i].into());
        slices.boolean_values[pos] = bf16::from_f32(node.boolean_values[cell_i].into());

        // Short-circuiting keeps the RNG untouched unless a similar node's target cell is hit.
        let is_target = col == target_column
            && (this_node == target_node_idx
                || (added_similar_nodes.contains(&this_node)
                    && rng.random::<f64>() < self.mask_prob));
        slices.is_targets[pos] = is_target;
        slices.masks[pos] = is_target;

        slices.is_task_nodes[pos] = node.is_task_node || col == target_column;
        slices.is_padding[pos] = false;
        slices.timestamps[pos] = node
            .timestamp
            .as_ref()
            .map_or(i32::MIN, |ts| (*ts).into());

        *seq_i = pos + 1;
    }

    /// Collects node indices around `start_idx` to use as local context.
    ///
    /// Traversal uses two frontiers:
    /// * nodes reached through a node's `f2p_edges` go on a LIFO stack and are
    ///   always expanded first;
    /// * nodes reached through p2f edges (`get_p2f_edges`) are bucketed by depth.
    ///   When the stack is empty, a uniformly random node is removed from the
    ///   shallowest non-empty bucket.
    ///
    /// p2f edges are restricted to those whose timestamp is absent, or not later
    /// than the seed node's timestamp. This relies on the edges being sorted by
    /// timestamp so the valid ones form a prefix. Edges into non-`Db` tables are
    /// followed only when they point at the seed node's own table. Edges into
    /// `Db` tables are subsampled to at most `self.max_bfs_width` per expanded
    /// node.
    ///
    /// `visited_at_depth` records the depth at which each node was expanded. A
    /// node is skipped unless it is reached at a strictly smaller depth than
    /// before; in that case it is expanded, and pushed to the result, again. The
    /// map is updated in place, so its contents carry over to the caller.
    ///
    /// Returns when the running cell count reaches `max_local_ctx_len` (the node
    /// that reaches the budget is not included) or when both frontiers are empty.
    fn bfs_collect_nodes(
        &self,
        dataset: &Dataset,
        start_idx: i32,
        rng: &mut StdRng,
        max_local_ctx_len: usize,
        visited_at_depth: &mut HashMap<i32, usize>,
    ) -> Vec<i32> {
        let mut result = Vec::new();

        let start_node = get_node(dataset, start_idx);
        let mut num_cells = 0;

        // f2p_ftr: stack of (depth, node_idx) reached through f2p edges.
        // p2f_ftr: one bucket of node indices per depth, reached through p2f edges.
        let mut f2p_ftr: Vec<(usize, i32)> = Vec::new();
        let mut p2f_ftr: Vec<Vec<i32>> = vec![vec![start_idx]];

        loop {
            // The f2p stack has priority; otherwise take a random node from the
            // shallowest non-empty p2f bucket (no allocation needed to find it).
            let (depth, node_idx) = if let Some(entry) = f2p_ftr.pop() {
                entry
            } else {
                let Some(depth) = p2f_ftr.iter().position(|bucket| !bucket.is_empty()) else {
                    return result;
                };
                let bucket = &mut p2f_ftr[depth];
                let r = rng.random_range(0..bucket.len());
                (depth, bucket.swap_remove(r))
            };

            // Skip nodes already expanded at this depth or shallower.
            if visited_at_depth
                .get(&node_idx)
                .is_some_and(|&prev_depth| prev_depth <= depth)
            {
                continue;
            }

            let node = get_node(dataset, node_idx);

            num_cells += node.col_name_idxs.len();
            if num_cells >= max_local_ctx_len {
                return result;
            }

            visited_at_depth.insert(node_idx, depth);
            result.push(node_idx);

            for edge in node.f2p_edges.iter() {
                f2p_ftr.push((depth + 1, edge.node_idx.into()));
            }

            // Edges are sorted by timestamp, so the ones visible from the seed
            // node form a prefix found by binary search.
            let p2f_edges = get_p2f_edges(dataset, node_idx).as_slice();
            let num_valid = p2f_edges.partition_point(|edge| {
                edge.timestamp.is_none()
                    || (start_node.timestamp.is_some()
                        && edge.timestamp <= start_node.timestamp)
            });

            // Ensure the bucket for depth + 1 exists. Trailing empty buckets are
            // harmless: node selection only considers non-empty ones.
            if p2f_ftr.len() < depth + 2 {
                p2f_ftr.resize_with(depth + 2, Vec::new);
            }

            let mut db_p2f_ftr: Vec<i32> = Vec::new();
            for edge in &p2f_edges[..num_valid] {
                if edge.table_type == ArchivedTableType::Db {
                    db_p2f_ftr.push(edge.node_idx.into());
                } else if edge.table_name_idx == start_node.table_name_idx {
                    // Edges into task tables are followed only when the seed node
                    // belongs to that same table.
                    p2f_ftr[depth + 1].push(edge.node_idx.into());
                }
            }

            // Subsample DB edges to at most `max_bfs_width`.
            if db_p2f_ftr.len() > self.max_bfs_width {
                let picked = index::sample(rng, db_p2f_ftr.len(), self.max_bfs_width).into_vec();
                p2f_ftr[depth + 1].extend(picked.into_iter().map(|i| db_p2f_ftr[i]));
            } else {
                p2f_ftr[depth + 1].extend_from_slice(&db_p2f_ftr);
            }
        }
    }

    /// Fills the sparse feature-attention key list for every non-padding query
    /// position in `0..seq_len`.
    ///
    /// A query at position `q` attends to every non-padding position whose node
    /// is either `q`'s own node or one of `q`'s f2p neighbours. The neighbours are
    /// read from `slices.f2p_nbr_idxs`, where `-1` ends the list. Each candidate
    /// node contributes its cells at most once, so the key list matches the
    /// `same_node OR kv_in_f2p` mask. This holds even when the query's own node
    /// is also listed as a neighbour, or one node appears several times among the
    /// neighbours.
    ///
    /// Keys go into the first `num_keys` entries of the query's
    /// `MAX_FEAT_ATTN_KEYS`-wide row of `feat_attn_idx`, and `feat_attn_mask` is
    /// set to `true` for those entries. The rest of the row is left untouched.
    ///
    /// # Panics
    /// Panics if a query has more than `MAX_FEAT_ATTN_KEYS` keys.
    fn construct_feat_attn_indices(&self, seq_len: usize, slices: &mut Slices) {
        // node_idx -> positions (in sequence order) of that node's cells.
        let mut node_to_cells: HashMap<i32, Vec<usize>> = HashMap::new();
        for cell_idx in 0..seq_len {
            if !slices.is_padding[cell_idx] {
                node_to_cells
                    .entry(slices.node_idxs[cell_idx])
                    .or_default()
                    .push(cell_idx);
            }
        }

        // Buffers reused across queries to avoid per-query allocation.
        let mut key_indices: Vec<usize> = Vec::with_capacity(MAX_FEAT_ATTN_KEYS);
        let mut seen_nodes: Vec<i32> = Vec::with_capacity(1 + MAX_F2P_NBRS);

        for q_idx in 0..seq_len {
            if slices.is_padding[q_idx] {
                continue;
            }
            key_indices.clear();
            seen_nodes.clear();

            let q_node_idx = slices.node_idxs[q_idx];
            let q_f2p_nbrs =
                &slices.f2p_nbr_idxs[q_idx * MAX_F2P_NBRS..(q_idx + 1) * MAX_F2P_NBRS];

            // Own node first (same_node), then valid f2p neighbours (kv_in_f2p).
            // Valid neighbours are packed at the front and followed by -1.
            let candidate_nodes = std::iter::once(q_node_idx)
                .chain(q_f2p_nbrs.iter().copied().take_while(|&nbr| nbr != -1));
            for node_idx in candidate_nodes {
                // Dedupe nodes so no key index is emitted twice.
                if seen_nodes.contains(&node_idx) {
                    continue;
                }
                seen_nodes.push(node_idx);
                if let Some(cells) = node_to_cells.get(&node_idx) {
                    key_indices.extend_from_slice(cells);
                }
            }

            let num_keys = key_indices.len();
            assert!(
                num_keys <= MAX_FEAT_ATTN_KEYS,
                "query {q_idx} has {num_keys} feat-attention keys (max {MAX_FEAT_ATTN_KEYS})"
            );

            let base = q_idx * MAX_FEAT_ATTN_KEYS;
            for (dst, &kv_idx) in slices.feat_attn_idx[base..base + num_keys]
                .iter_mut()
                .zip(&key_indices)
            {
                *dst = kv_idx as i32;
            }
            slices.feat_attn_mask[base..base + num_keys].fill(true);
        }
    }
}

/// Returns the archived node at position `idx`.
///
/// The node's bytes are `dataset.mmap[offsets[idx]..offsets[idx + 1]]`.
///
/// # Panics
/// Panics if `idx` or `idx + 1` is out of range for `dataset.offsets`, or if the
/// resulting byte range is out of range for the mmap.
fn get_node(dataset: &Dataset, idx: i32) -> &ArchivedNode {
    let slot = idx as usize;
    let start = dataset.offsets[slot] as usize;
    let end = dataset.offsets[slot + 1] as usize;
    let node_bytes = &dataset.mmap[start..end];
    // SAFETY: validation is skipped for speed. This assumes the byte range holds a
    // correctly serialized `ArchivedNode` — NOTE(review): confirm the writer
    // guarantees this; `rkyv::access::<ArchivedNode, Error>` is the checked variant.
    unsafe { rkyv::access_unchecked::<ArchivedNode>(node_bytes) }
}

/// Returns the p2f edge list of node `idx` from the archived adjacency in
/// `dataset.p2f_adj_mmap`.
///
/// The archive root is re-resolved (unchecked) on every call.
///
/// # Panics
/// Panics if `idx` is out of range for the adjacency list.
fn get_p2f_edges(dataset: &Dataset, idx: i32) -> &ArchivedVec<ArchivedEdge> {
    // SAFETY: assumes the whole mmap is a valid serialized `ArchivedAdj` —
    // NOTE(review): confirm the writer guarantees this; validation is skipped.
    let adjacency = unsafe { rkyv::access_unchecked::<ArchivedAdj>(&dataset.p2f_adj_mmap[..]) };
    &adjacency.adj[idx as usize]
}

/// Returns row `idx` (`d_text` values wide) of the bf16 text-embedding table
/// stored in `dataset.text_mmap`.
///
/// # Panics
/// Panics if the mmap cannot be viewed entirely as `bf16` (non-empty unaligned
/// prefix or leftover trailing byte), or if row `idx` lies past the end of the table.
fn get_text_emb(dataset: &Dataset, idx: i32, d_text: usize) -> &[bf16] {
    // SAFETY: `bf16` is a 2-byte value type for which any bit pattern is assumed
    // valid; the assert below guarantees the whole buffer was reinterpreted.
    let (head, table, tail) = unsafe { dataset.text_mmap.align_to::<bf16>() };
    assert!(head.is_empty() && tail.is_empty());
    let row = idx as usize;
    let start = row * d_text;
    &table[start..start + d_text]
}

// Command-line arguments for the standalone sampler benchmark run by `main`.
// Plain `//` comments are used on purpose: clap's derive turns `///` doc comments
// into `--help` text, which would change the CLI's output.
#[derive(Parser)]
pub struct Cli {
    // Database name, passed to `Sampler::new` inside the single dataset tuple.
    #[arg(default_value = "rel-f1")]
    db_name: String,
    // Batch size forwarded to `Sampler::new`.
    #[arg(default_value = "128")]
    batch_size: usize,
    // Sequence length forwarded to `Sampler::new` as its `ctx_len` argument.
    #[arg(default_value = "1024")]
    seq_len: usize,
    // Number of timed `Sampler::batch` calls.
    #[arg(default_value = "1000")]
    num_trials: usize,
}

/// Standalone benchmark for the sampler.
///
/// Builds a `Sampler` for `cli.db_name` and times `cli.num_trials` calls to
/// `Sampler::batch` on uniformly random batch indices. Prints the mean and
/// (population) standard deviation in milliseconds.
pub fn main(cli: Cli) {
    let tic = Instant::now();
    let sampler = Sampler::new(
        vec![(cli.db_name, 0, 10)], // dataset_tuples
        cli.batch_size,             // batch_size
        0,                          // rank
        1,                          // world_size
        cli.seq_len,                // ctx_len
        128,                        // max_local_ctx_len
        128,                        // max_bfs_width
        true,                       // use_random_walk
        false,                      // use_random_sampling
        true,                       // use_connecting_nodes
        100,                        // num_walks
        10,                         // walk_length
        0.3,                        // mask_prob
        "all-MiniLM-L12-v2",        // embedding_model
        384,                        // d_text
        0,                          // seed
        vec![-1; 1],                // target_columns
        vec![Vec::<i32>::new()],    // columns_to_drop
    );
    println!("Sampler loaded in {:?}", tic.elapsed());

    // Avoid dividing by zero (which would print NaN statistics).
    if cli.num_trials == 0 {
        println!("No trials requested (num_trials = 0); skipping batch timing.");
        return;
    }

    let mut sum_ms = 0.0_f64;
    let mut sum_sq_ms = 0.0_f64;
    let mut rng = rand::rng();
    for _ in 0..cli.num_trials {
        let tic = Instant::now();
        let batch_idx = rng.random_range(0..sampler.len());
        let _batch = sampler.batch(batch_idx);
        // Fractional milliseconds: `as_millis()` truncates and biases short timings.
        let elapsed_ms = tic.elapsed().as_secs_f64() * 1000.0;
        sum_ms += elapsed_ms;
        sum_sq_ms += elapsed_ms * elapsed_ms;
    }
    let n = cli.num_trials as f64;
    let mean = sum_ms / n;
    // Rounding in E[x^2] - E[x]^2 can produce a tiny negative variance; clamp so
    // the square root is never NaN.
    let std = (sum_sq_ms / n - mean * mean).max(0.0).sqrt();
    println!("Mean: {} ms,\tStd: {} ms", mean, std);
}
