#include "partitioning.h"

#include <algorithm>
#include <array>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <iterator>
#include <limits>
#include <numeric>
#include <random>
#include <unordered_set>
#include <utility>
#include <vector>

#include <omp.h>

// Returns an iterator to a uniformly chosen element of [start, end), drawing
// randomness from `g`. An empty range returns `end` (callers must not
// dereference it); building a distribution over [0, -1] would be undefined.
// The distribution uses the iterator's own difference_type, so large ranges
// are not narrowed to int.
// Adapted from:
// https://stackoverflow.com/questions/6942273/how-to-get-a-random-element-from-a-c-container
template<typename Iter, typename RandomGenerator>
Iter select_randomly(Iter start, Iter end, RandomGenerator& g) {
    using diff_t = typename std::iterator_traits<Iter>::difference_type;
    const diff_t count = std::distance(start, end);
    if (count <= 0) {
        return end;
    }
    std::uniform_int_distribution<diff_t> dis(0, count - 1);
    std::advance(start, dis(g));
    return start;
}

// Convenience overload: picks a uniformly random element of [start, end)
// using a Mersenne Twister seeded once per thread from std::random_device.
// The engine is thread_local rather than a shared static because callers use
// this from inside OpenMP parallel loops, and std::mt19937 is not safe to
// advance concurrently from several threads.
template<typename Iter>
Iter select_randomly(Iter start, Iter end) {
    thread_local std::mt19937 gen{std::random_device{}()};
    return select_randomly(start, end, gen);
}

// Finds the maximum value in `v` and the indices of every element equal to it
// (ascending order).
//
// The running maximum is seeded from the first comparable element rather than
// a sentinel. Inputs whose maximum is -inf, or equal to a sentinel, are
// therefore still reported. NaN elements (x != x) never compare as a maximum
// and are skipped.
//
// Returns {numeric_limits<T>::lowest(), {}} when `v` is empty or contains
// only NaNs.
// Based on:
// https://stackoverflow.com/questions/61350093/finding-multiple-max-elements-in-a-vector-c
template<typename T>
std::pair<T, std::vector<std::size_t>> find_max_elements(std::vector<T> const& v) {
    std::vector<std::size_t> indices;
    T current_max = std::numeric_limits<T>::lowest();
    for (std::size_t i = 0; i < v.size(); ++i) {
        const T& x = v[i];
        if (!(x == x)) {
            continue;  // NaN: unordered, cannot be a maximum
        }
        if (indices.empty() || x > current_max) {
            current_max = x;
            indices.clear();
            indices.push_back(i);
        } else if (x == current_max) {
            indices.push_back(i);
        }
    }
    return std::make_pair(current_max, std::move(indices));
}

// Returns the permutation of positions 0..v.size()-1 that orders `v`
// ascending. A stable sort is used so that equal values keep their original
// relative order (ties never reorder indices needlessly).
// Indices are stored as int16_t; inputs longer than 32767 elements would
// produce wrapped (invalid) indices.
template <typename T>
std::vector<int16_t> sort_indexes(const std::vector<T> &v) {
  std::vector<int16_t> order;
  order.reserve(v.size());
  for (std::size_t pos = 0; pos < v.size(); ++pos) {
    order.push_back(static_cast<int16_t>(pos));
  }
  const auto less_by_value = [&v](int16_t lhs, int16_t rhs) {
    return v[lhs] < v[rhs];
  };
  std::stable_sort(order.begin(), order.end(), less_by_value);
  return order;
}

// Picks the preferred partition(s) of a vertex from its per-partition scores.
//
// v:       score of each partition (index = partition id).
// my_part: the vertex's current partition.
//
// Returns:
//   * (my_part, -1) if my_part already has the highest score (ties included);
//   * (a, b) if my_part does not, and two or more partitions tie for the
//     highest score: a and b are the two lowest-indexed of them;
//   * (best, second) if exactly one partition `best` has the highest score:
//     `second` is the lowest-indexed partition with the next-highest score,
//     or -1 when no other (non-NaN) score exists;
//   * (my_part, -1) if find_max_elements reports no maximum (e.g. v is empty).
//
// v is not modified. It is taken by non-const reference only to keep the
// existing signature; earlier versions overwrote the best entry in place.
std::pair<int16_t, int16_t> grinnder_scoring(std::vector<double> & v, int16_t my_part) {
    const std::vector<std::size_t> max_indices = find_max_elements(v).second;

    const bool my_part_is_best =
        std::find(max_indices.begin(), max_indices.end(),
                  static_cast<std::size_t>(my_part)) != max_indices.end();
    if (max_indices.empty() || my_part_is_best) {
        return std::make_pair(my_part, static_cast<int16_t>(-1));
    }

    // Several partitions tie for the top score: take the first two.
    if (max_indices.size() > 1) {
        return std::make_pair(static_cast<int16_t>(max_indices[0]),
                              static_cast<int16_t>(max_indices[1]));
    }

    // Unique maximum: scan for the runner-up without mutating v.
    const std::size_t best = max_indices[0];
    std::size_t second = v.size();  // v.size() means "no runner-up found"
    for (std::size_t i = 0; i < v.size(); ++i) {
        if (i == best || std::isnan(v[i])) {
            continue;
        }
        if (second == v.size() || v[i] > v[second]) {
            second = i;
        }
    }
    const int16_t second_part =
        (second == v.size()) ? static_cast<int16_t>(-1) : static_cast<int16_t>(second);
    return std::make_pair(static_cast<int16_t>(best), second_part);
}

// Finds the partition preferred most often in `next_prefer` and the positions
// that prefer it.
//
// next_prefer: a preferred partition id per entry. Values outside
//              [0, num_parts), such as a -1 "no preference" sentinel, are
//              ignored rather than written out of bounds.
// num_parts:   number of partitions.
//
// Returns {priority_part, indices}. priority_part is the most frequent valid
// preference; ties go to the lowest partition id, and it is 0 when nothing is
// valid. indices are the positions i (ascending) with
// next_prefer[i] == priority_part.
std::pair<int16_t, std::vector<int64_t>> grinnder_priority(std::vector<int64_t> const& next_prefer, int16_t num_parts) {
    const std::size_t parts = num_parts > 0 ? static_cast<std::size_t>(num_parts) : 0;

    // std::vector instead of a (non-standard) variable-length array.
    std::vector<int64_t> frequency(parts, 0);
    for (const int64_t preferred : next_prefer) {
        if (preferred >= 0 && static_cast<std::size_t>(preferred) < parts) {
            ++frequency[static_cast<std::size_t>(preferred)];
        }
    }

    const int16_t priority_part = static_cast<int16_t>(std::distance(
        frequency.begin(), std::max_element(frequency.begin(), frequency.end())));

    std::vector<int64_t> priority_indices;
    for (std::size_t i = 0; i < next_prefer.size(); ++i) {
        if (next_prefer[i] == priority_part) {
            priority_indices.push_back(static_cast<int64_t>(i));
        }
    }
    return std::make_pair(priority_part, std::move(priority_indices));
}


// Computes, for every vertex, the number of *distinct* partitions other than
// its own that appear among its neighbours. A vertex whose neighbours all share
// its partition (or that has no neighbours) scores 0; several neighbours in the
// same foreign partition count once. Runs in parallel over vertices with OpenMP.
//
// rowptr:    CSR row pointer, size num_vertices + 1.
// col:       CSR neighbour list; every entry must be a valid index into `partition`.
// partition: partition id per vertex (length num_vertices).
// Returns a vector of length num_vertices.
std::vector<int64_t> compute_reusability_omp(
    const std::vector<int64_t>& rowptr,
    const std::vector<int64_t>& col,
    const std::vector<int16_t>& partition) {

    const int64_t num_vertices = static_cast<int64_t>(partition.size());
    std::vector<int64_t> unique_reusability(partition.size(), 0);

    #pragma omp parallel
    {
        // Per-thread scratch buffer, reused across vertices (cleared, capacity
        // kept) instead of allocating a hash set for every vertex.
        std::vector<int16_t> foreign_parts;

        #pragma omp for schedule(dynamic)
        for (int64_t i = 0; i < num_vertices; i++) {
            const int16_t own_part = partition[i];
            foreign_parts.clear();
            for (int64_t j = rowptr[i]; j < rowptr[i + 1]; j++) {
                const int16_t neighbor_part = partition[col[j]];
                if (neighbor_part != own_part) {
                    foreign_parts.push_back(neighbor_part);
                }
            }
            // Count distinct values: sort, then collapse duplicates.
            std::sort(foreign_parts.begin(), foreign_parts.end());
            unique_reusability[i] = static_cast<int64_t>(
                std::unique(foreign_parts.begin(), foreign_parts.end()) - foreign_parts.begin());
        }
    }
    return unique_reusability;
}



// Returns the mean over all vertices of compute_reusability_omp(), i.e. the
// average number of distinct foreign partitions among a vertex's neighbours.
// The sum is computed with an OpenMP reduction. A graph with no vertices
// yields 0.0 instead of the 0/0 NaN.
double compute_average_reusability_omp(
    const std::vector<int64_t>& rowptr,
    const std::vector<int64_t>& col,
    const std::vector<int16_t>& partition) {

    const std::vector<int64_t> reusability = compute_reusability_omp(rowptr, col, partition);
    if (reusability.empty()) {
        return 0.0;
    }

    const int64_t count = static_cast<int64_t>(reusability.size());
    int64_t total = 0;
    #pragma omp parallel for reduction(+:total)
    for (int64_t i = 0; i < count; i++) {
        total += reusability[i];
    }
    return static_cast<double>(total) / static_cast<double>(count);
}

// Synchronous Spinner partitioner: label propagation with a capacity penalty
// over a CSR graph.
//
// rowptr:       CPU int64 CSR row pointer (num_vertices + 1 entries).
// col:          CPU int64 CSR neighbour indices (|E| entries).
// num_parts:    number of partitions k, 1..32768.
// capacity:     slack factor; each partition targets at most
//               max_C = capacity * |E| / k edges.
// beta:         weight in the score beta + (neighbour fraction) - beta * penalty.
// max_iter:     hard cap on the number of iterations.
// halting_eps / halting_window: stop once the relative change of the total
//               score stays below halting_eps for halting_window consecutive
//               iterations.
// log:          print progress to std::cout.
// num_threads:  OpenMP threads for the parallel phases (must be > 0).
//
// Each iteration does three things:
//   1. The load of each partition is the sum of its vertices' degrees, and its
//      penalty is load / max_C.
//   2. Every vertex is scored against every partition, serially, using the
//      labels from the start of the iteration. A vertex whose partition is not
//      among the best is queued for a randomly chosen best partition.
//   3. Queued vertices move into each partition. When their total degree
//      exceeds the partition's remaining capacity, only a random fraction moves.
//
// Returns an int16 tensor with the partition id of every vertex.
torch::Tensor
spinner_cpu(torch::Tensor rowptr, torch::Tensor col,
            int num_parts, float capacity, float beta, int max_iter,
            float halting_eps, int halting_window,
            bool log, int num_threads) {

    AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
    AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
    AT_ASSERTM(num_parts > 0, "num_parts must be greater than 0");
    AT_ASSERTM(num_parts <= 32768, "num_parts must be less than 32768");
    AT_ASSERTM(num_threads > 0, "num_threads must be greater than 0");

    // Raw pointer access below assumes dense int64 storage.
    rowptr = rowptr.contiguous();
    col = col.contiguous();
    const int64_t* rowptr_data = rowptr.data_ptr<int64_t>();
    const int64_t* col_data = col.data_ptr<int64_t>();
    const int64_t num_vertices = rowptr.size(0) - 1;
    const int64_t total_edges = col.size(0);
    const std::size_t parts = static_cast<std::size_t>(num_parts);

    // Capacity per partition, in edges. It is computed in double because float
    // loses integer precision past 2^24 edges. It is clamped to >= 1 so the
    // penalty below never divides by zero on tiny graphs.
    const int64_t max_C = std::max<int64_t>(
        1, static_cast<int64_t>(static_cast<double>(capacity) * total_edges / num_parts));

    /* Random Init of Partitions */
    std::random_device rd;
    std::mt19937 mersenne_engine{rd()};
    std::uniform_int_distribution<int16_t> dist{0, static_cast<int16_t>(num_parts - 1)};

    // The returned tensor owns the label storage. A raw new[] buffer passed to
    // from_blob is never freed.
    torch::Tensor result = torch::empty({num_vertices}, torch::kInt16);
    int16_t* labels = result.data_ptr<int16_t>();
    std::generate(labels, labels + num_vertices,
                  [&dist, &mersenne_engine]() { return dist(mersenne_engine); });

    if (log) {
        std::cout << "|| >>>> Sync Spinner <<<<" << std::endl;
        std::cout << "|| Initial Partition: ";
        for (int64_t i = 0; i < std::min<int64_t>(num_vertices, 3); i++) {
            std::cout << (int)labels[i] << " ";
        }
        std::cout << std::endl;
    }

    // Scratch state, heap-allocated once and reset every iteration.
    std::vector<int16_t> labeled_col(total_edges);        // partition of each edge's target
    std::vector<double> load_per_thread(static_cast<std::size_t>(num_threads) * parts, 0.0);
    std::vector<double> agg_load_per_partition(parts, 0.0);
    std::vector<double> score_per_partition(parts, 0.0);
    std::vector<double> penalty_term(parts, 0.0);
    std::vector<int64_t> remaining_C(parts, 0);
    std::vector<std::vector<int64_t>> interested(parts);   // vertices queued for each partition
    std::vector<double> interested_sum(parts, 0.0);        // total degree of those vertices
    std::vector<double> cur_v_scores(parts, 0.0);          // per-vertex score buffer

    bool halt = false;
    int iter = 0;
    int window = 0;
    double prev_score = 0.0;

    while (iter < max_iter && !halt) {
        const auto start = std::chrono::steady_clock::now();

        std::fill(load_per_thread.begin(), load_per_thread.end(), 0.0);
        std::fill(agg_load_per_partition.begin(), agg_load_per_partition.end(), 0.0);
        std::fill(score_per_partition.begin(), score_per_partition.end(), 0.0);
        std::fill(interested_sum.begin(), interested_sum.end(), 0.0);
        for (auto& bucket : interested) {
            bucket.clear();
        }

        // Cache the current partition of every edge's target vertex.
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t e = 0; e < total_edges; e++) {
            labeled_col[e] = labels[col_data[e]];
        }

        // Load per partition = sum of member degrees. Per-thread partial sums
        // avoid concurrent updates to the same counter.
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t i = 0; i < num_vertices; i++) {
            const std::size_t slot =
                static_cast<std::size_t>(omp_get_thread_num()) * parts + labels[i];
            load_per_thread[slot] += static_cast<double>(rowptr_data[i + 1] - rowptr_data[i]);
        }
        for (std::size_t t = 0; t < static_cast<std::size_t>(num_threads); t++) {
            for (std::size_t p = 0; p < parts; p++) {
                agg_load_per_partition[p] += load_per_thread[t * parts + p];
            }
        }

        for (std::size_t p = 0; p < parts; p++) {
            penalty_term[p] = agg_load_per_partition[p] / max_C;
            remaining_C[p] = static_cast<int64_t>(max_C - agg_load_per_partition[p]);
        }

        if (log) {
            std::cout << "|| Iter: " << iter << std::endl;
            std::cout << "|| p (max_load / (|E|/k)): " <<
            *std::max_element(agg_load_per_partition.begin(), agg_load_per_partition.end())
                / (static_cast<double>(total_edges) / num_parts)
            << std::endl;
        }

        // Synchronous scoring. labeled_col holds the labels from the start of
        // this iteration; labels change only in the migration step below.
        for (int64_t i = 0; i < num_vertices; i++) {
            const int64_t row_begin = rowptr_data[i];
            const int64_t row_end = rowptr_data[i + 1];
            const int64_t num_edges = row_end - row_begin;
            if (num_edges == 0) {
                continue; // zero skipping
            }
            const int16_t my_part = labels[i];
            std::fill(cur_v_scores.begin(), cur_v_scores.end(), 0.0);
            for (int64_t j = row_begin; j < row_end; j++) {
                cur_v_scores[labeled_col[j]] += 1.0;
            }
            for (std::size_t p = 0; p < parts; p++) {
                cur_v_scores[p] /= num_edges;
                cur_v_scores[p] = beta + cur_v_scores[p] - beta * penalty_term[p];
            }
            score_per_partition[my_part] += cur_v_scores[my_part];

            const std::vector<std::size_t> max_indices = find_max_elements(cur_v_scores).second;
            if (!max_indices.empty() &&
                std::find(max_indices.begin(), max_indices.end(),
                          static_cast<std::size_t>(my_part)) == max_indices.end()) {
                // my_part is not a best partition: queue for a random best one.
                const std::size_t target = *select_randomly(max_indices.begin(), max_indices.end());
                interested[target].push_back(i);
                interested_sum[target] += num_edges;
            }
        }

        // Admit queued vertices according to each partition's remaining capacity.
        for (std::size_t p = 0; p < parts; p++) {
            std::vector<int64_t>& candidates = interested[p];
            double migration_prob = 0.0;
            if (candidates.empty() || remaining_C[p] <= 0) {
                migration_prob = 0.0;   // nobody queued, or no room left
            } else if (remaining_C[p] >= interested_sum[p]) {
                migration_prob = 1.0;   // every queued vertex fits
            } else {
                migration_prob = remaining_C[p] / interested_sum[p];
            }
            const int64_t num_migrate = std::min<int64_t>(
                static_cast<int64_t>(candidates.size() * migration_prob),
                static_cast<int64_t>(candidates.size()));
            if (log) {
                std::cout << "Remaining... " << remaining_C[p] << ",  ";
                std::cout << "Migrate... " << num_migrate << " / ";
            }
            // A random subset of size num_migrate = prefix of a shuffled list.
            std::shuffle(candidates.begin(), candidates.end(), mersenne_engine);
            for (int64_t k = 0; k < num_migrate; k++) {
                labels[candidates[k]] = static_cast<int16_t>(p);
            }
        }
        if (log) {
            std::cout << std::endl;
        }

        const double cur_score =
            std::accumulate(score_per_partition.begin(), score_per_partition.end(), 0.0);

        // No previous score on the first iteration: count it as a full step
        // instead of dividing by zero.
        const double step = (prev_score == 0.0) ? 1.0 : std::abs(1 - cur_score / prev_score);
        if (log) {
            std::cout << "|| Step: " << step << ", Cur Score: " << cur_score \
            << ", Prev Score: " << prev_score << std::endl;
        }
        if (step < halting_eps) {
            window++;
            if (window >= halting_window) {
                halt = true;
            }
        } else {
            window = 0;
        }

        prev_score = cur_score;
        iter++;

        const std::chrono::duration<double> iter_time = std::chrono::steady_clock::now() - start;
        if (log) {
            std::cout << "|| ============== Time: " << iter_time.count() << " (sec) " << "==============" << std::endl;
        }
    }

    return result;
}

// Parallel ("async") Spinner partitioner: label propagation with a capacity
// penalty over a CSR graph, with vertex scoring spread across OpenMP threads.
//
// rowptr:       CPU int64 CSR row pointer (num_vertices + 1 entries).
// col:          CPU int64 CSR neighbour indices (|E| entries).
// num_parts:    number of partitions k, 1..32768.
// capacity:     slack factor; max_C = capacity * |E| / k edges per partition.
// beta:         weight in the score beta + (neighbour fraction) - beta * penalty.
// max_iter:     hard cap on the number of iterations.
// halting_eps / halting_window: stop once the relative change of the total
//               score stays below halting_eps for halting_window consecutive
//               iterations.
// log:          print progress (including average reusability) to std::cout.
// num_threads:  OpenMP threads (must be > 0).
//
// Scoring runs in parallel. Each thread queues migration candidates in its own
// per-partition buckets. Each bucket may fill at most
// remaining_capacity / num_threads of its target partition.
//
// Returns an int16 tensor with the partition id of every vertex.
torch::Tensor
spinner_async_cpu(torch::Tensor rowptr, torch::Tensor col,
            int num_parts, float capacity, float beta, int max_iter,
            float halting_eps, int halting_window,
            bool log, int num_threads) {

    AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
    AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
    AT_ASSERTM(num_parts > 0, "num_parts must be greater than 0");
    AT_ASSERTM(num_parts <= 32768, "num_parts must be less than 32768");
    AT_ASSERTM(num_threads > 0, "num_threads must be greater than 0");

    // Raw pointer access below assumes dense int64 storage.
    rowptr = rowptr.contiguous();
    col = col.contiguous();
    const int64_t* rowptr_data = rowptr.data_ptr<int64_t>();
    const int64_t* col_data = col.data_ptr<int64_t>();
    const int64_t num_vertices = rowptr.size(0) - 1;
    const int64_t total_edges = col.size(0);
    const std::size_t parts = static_cast<std::size_t>(num_parts);
    const std::size_t threads = static_cast<std::size_t>(num_threads);

    // Capacity per partition, in edges. It is computed in double for
    // precision and clamped to >= 1 so the penalty never divides by zero.
    const int64_t max_C = std::max<int64_t>(
        1, static_cast<int64_t>(static_cast<double>(capacity) * total_edges / num_parts));

    /* Random Init of Partitions */
    std::random_device rd;
    std::mt19937 mersenne_engine{rd()};
    std::uniform_int_distribution<int16_t> dist{0, static_cast<int16_t>(num_parts - 1)};

    // The returned tensor owns the label storage. A raw new[] buffer passed to
    // from_blob is never freed.
    torch::Tensor result = torch::empty({num_vertices}, torch::kInt16);
    int16_t* labels = result.data_ptr<int16_t>();
    std::generate(labels, labels + num_vertices,
                  [&dist, &mersenne_engine]() { return dist(mersenne_engine); });

    // One engine per OpenMP thread. std::mt19937 must not be advanced
    // concurrently, so each thread draws only from its own engine.
    std::vector<std::mt19937> thread_rngs;
    thread_rngs.reserve(threads);
    for (std::size_t t = 0; t < threads; t++) {
        thread_rngs.emplace_back(rd());
    }

    if (log) {
        std::cout << "|| >>>> Async Spinner <<<<" << std::endl;
        std::cout << "|| Initial Partition: ";
        for (int64_t i = 0; i < std::min<int64_t>(num_vertices, 3); i++) {
            std::cout << (int)labels[i] << " ";
        }
        std::cout << std::endl;
    }

    // Scratch state, heap-allocated once and reset every iteration.
    // Per-thread tables are flattened as [thread * parts + partition].
    std::vector<int16_t> labeled_col(total_edges);        // partition of each edge's target
    std::vector<double> load_per_thread(threads * parts, 0.0);
    std::vector<double> score_per_thread(threads * parts, 0.0);
    std::vector<double> agg_load_per_partition(parts, 0.0);
    std::vector<double> penalty_term(parts, 0.0);
    std::vector<int64_t> remaining_C(parts, 0);
    // interested[t][p]: vertices scored by thread t that want to move to p.
    // interested_sum[t][p]: their total degree.
    std::vector<std::vector<std::vector<int64_t>>> interested(
        threads, std::vector<std::vector<int64_t>>(parts));
    std::vector<std::vector<double>> interested_sum(threads, std::vector<double>(parts, 0.0));

    // CSR copies for the reusability report. The graph never changes, so it
    // is copied once rather than on every iteration.
    std::vector<int64_t> rowptr_vec;
    std::vector<int64_t> col_vec;
    if (log) {
        rowptr_vec.assign(rowptr_data, rowptr_data + rowptr.size(0));
        col_vec.assign(col_data, col_data + total_edges);
    }

    bool halt = false;
    int iter = 0;
    int window = 0;
    double prev_score = 0.0;

    while (iter < max_iter && !halt) {
        const auto start = std::chrono::steady_clock::now();

        std::fill(load_per_thread.begin(), load_per_thread.end(), 0.0);
        std::fill(score_per_thread.begin(), score_per_thread.end(), 0.0);
        std::fill(agg_load_per_partition.begin(), agg_load_per_partition.end(), 0.0);
        for (std::size_t t = 0; t < threads; t++) {
            for (auto& bucket : interested[t]) {
                bucket.clear();
            }
            std::fill(interested_sum[t].begin(), interested_sum[t].end(), 0.0);
        }

        // Cache the current partition of every edge's target vertex.
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t e = 0; e < total_edges; e++) {
            labeled_col[e] = labels[col_data[e]];
        }

        // Load per partition = sum of member degrees (per-thread partial sums).
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t i = 0; i < num_vertices; i++) {
            const std::size_t slot =
                static_cast<std::size_t>(omp_get_thread_num()) * parts + labels[i];
            load_per_thread[slot] += static_cast<double>(rowptr_data[i + 1] - rowptr_data[i]);
        }
        for (std::size_t t = 0; t < threads; t++) {
            for (std::size_t p = 0; p < parts; p++) {
                agg_load_per_partition[p] += load_per_thread[t * parts + p];
            }
        }

        for (std::size_t p = 0; p < parts; p++) {
            penalty_term[p] = agg_load_per_partition[p] / max_C;
            remaining_C[p] = static_cast<int64_t>(max_C - agg_load_per_partition[p]);
        }

        if (log) {
            std::cout << "|| Iter: " << iter << std::endl;
            std::cout << "|| p (max_load / (|E|/k)): " <<
            *std::max_element(agg_load_per_partition.begin(), agg_load_per_partition.end())
                / (static_cast<double>(total_edges) / num_parts)
            << std::endl;
        }

        // Parallel scoring. Labels are only read here. Every write goes to the
        // executing thread's own score row, buckets and RNG.
        #pragma omp parallel num_threads(num_threads)
        {
            const std::size_t tid = static_cast<std::size_t>(omp_get_thread_num());
            std::mt19937& rng = thread_rngs[tid];
            double* my_scores = score_per_thread.data() + tid * parts;
            std::vector<std::vector<int64_t>>& my_interested = interested[tid];
            std::vector<double>& my_interested_sum = interested_sum[tid];
            std::vector<double> cur_v_scores(parts, 0.0);  // reused per vertex

            #pragma omp for
            for (int64_t i = 0; i < num_vertices; i++) {
                const int64_t row_begin = rowptr_data[i];
                const int64_t row_end = rowptr_data[i + 1];
                const int64_t num_edges = row_end - row_begin;
                if (num_edges == 0) {
                    continue; // zero skipping
                }
                const int16_t my_part = labels[i];
                std::fill(cur_v_scores.begin(), cur_v_scores.end(), 0.0);
                for (int64_t j = row_begin; j < row_end; j++) {
                    cur_v_scores[labeled_col[j]] += 1.0;
                }
                for (std::size_t p = 0; p < parts; p++) {
                    cur_v_scores[p] /= num_edges;
                    cur_v_scores[p] = beta + cur_v_scores[p] - beta * penalty_term[p];
                }
                my_scores[my_part] += cur_v_scores[my_part];

                const std::vector<std::size_t> max_indices = find_max_elements(cur_v_scores).second;
                if (!max_indices.empty() &&
                    std::find(max_indices.begin(), max_indices.end(),
                              static_cast<std::size_t>(my_part)) == max_indices.end()) {
                    // my_part is not a best partition: queue for a random best one.
                    const std::size_t target =
                        *select_randomly(max_indices.begin(), max_indices.end(), rng);
                    my_interested[target].push_back(i);
                    my_interested_sum[target] += num_edges;
                }
            }
        }

        // Migration. Each thread's bucket may use 1/num_threads of a
        // partition's remaining capacity. Every vertex sits in at most one
        // bucket, so the label writes below never target the same vertex.
        #pragma omp parallel for num_threads(num_threads)
        for (int t = 0; t < num_threads; t++) {
            std::mt19937& rng = thread_rngs[omp_get_thread_num()];
            for (std::size_t p = 0; p < parts; p++) {
                std::vector<int64_t>& candidates = interested[t][p];
                const double requested = interested_sum[t][p];
                // int64 share: an int cast would truncate large capacities.
                const int64_t share = remaining_C[p] / num_threads;
                double migration_prob = 0.0;
                if (candidates.empty() || share <= 0) {
                    migration_prob = 0.0;   // nobody queued, or no room left
                } else if (share >= requested) {
                    migration_prob = 1.0;   // every queued vertex fits
                } else {
                    migration_prob = static_cast<double>(share) / requested;
                }
                const int64_t num_migrate = std::min<int64_t>(
                    static_cast<int64_t>(candidates.size() * migration_prob),
                    static_cast<int64_t>(candidates.size()));
                if (num_migrate == 0) {
                    continue;
                }
                // A random subset of size num_migrate = prefix of a shuffled list.
                std::shuffle(candidates.begin(), candidates.end(), rng);
                for (int64_t k = 0; k < num_migrate; k++) {
                    labels[candidates[k]] = static_cast<int16_t>(p);
                }
            }
        }

        const double cur_score =
            std::accumulate(score_per_thread.begin(), score_per_thread.end(), 0.0);

        // No previous score on the first iteration: count it as a full step
        // instead of dividing by zero.
        const double step = (prev_score == 0.0) ? 1.0 : std::abs(1 - cur_score / prev_score);
        if (log) {
            std::cout << "|| Step: " << step << ", Cur Score: " << cur_score \
            << ", Prev Score: " << prev_score << std::endl;
        }
        if (step < halting_eps) {
            window++;
            if (window >= halting_window) {
                halt = true;
            }
        } else {
            window = 0;
        }

        prev_score = cur_score;
        iter++;

        const std::chrono::duration<double> iter_time = std::chrono::steady_clock::now() - start;

        // Report the average reusability of the current labels, then the time.
        if (log) {
            const std::vector<int16_t> partition_vec(labels, labels + num_vertices);
            const double avg_reusability =
                compute_average_reusability_omp(rowptr_vec, col_vec, partition_vec);
            std::cout << "|| Avg Reusability: " << avg_reusability << std::endl;
            std::cout << "|| ============== Time: " << iter_time.count() << " (sec) " << "==============" << std::endl;
        }
    }

    return result;
}

/*
 * Spinner implementation using the GAS paradigm.
 *
 * Parameters are the same as before.
 */
torch::Tensor spinner_gas_cpu(torch::Tensor rowptr, torch::Tensor col,
                              int num_parts, float capacity, float beta, int max_iter,
                              float halting_eps, int halting_window,
                              bool log, int num_threads) {

    AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
    AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
    AT_ASSERTM(num_parts > 0, "num_parts must be greater than 0");
    AT_ASSERTM(num_parts <= 32768, "num_parts must be less than 32768");

    // max_C = capacity * (|E| / num_parts) or, if you prefer, based on total vertex degree.
    int64_t max_C = static_cast<int64_t>(capacity * col.size(0) / num_parts);
    std::vector<int64_t> remaining_C(num_parts, 0);

    // Randomly initialize partitions.
    std::random_device rd;
    std::mt19937 mersenne_engine{rd()};
    std::uniform_int_distribution<int16_t> dist{0, num_parts - 1};
    auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
    int64_t num_vertices = rowptr.size(0) - 1;
    int16_t* labels = new int16_t[num_vertices];
    std::generate(labels, labels + num_vertices, gen);

    if (log) {
        std::cout << "|| >>>> GAS Spinner <<<<" << std::endl;
        std::cout << "|| Initial Partitions: ";
        for (int i = 0; i < std::min((int)num_vertices, 3); i++) {
            std::cout << (int)labels[i] << " ";
        }
        std::cout << std::endl;
    }

    // Prepare arrays for the GAS phases.
    double load_per_partition[num_threads][num_parts] = {0.0}; // [thread][part]; max threads assumed 256
    double agg_load_per_partition[num_parts] = {0.0};
    double penalty_term[num_parts] = {0.0};

    // We'll reuse a vector for neighbor lookup: for each edge, store partition label of neighbor.
    std::vector<int16_t> labeled_col(col.size(0));

    int iter = 0, window = 0;
    double prev_score = 0.0;
    bool halt = false;

    // Vectors to store each vertex’s candidate migration (for the Apply phase)
    std::vector<int16_t> candidate_partition(num_vertices, -1);
    std::vector<double> candidate_weight(num_vertices, 0.0);

    // Main loop: each iteration performs Gather, Apply, and Scatter.
    while (iter < max_iter && !halt) {
        auto t_start = std::chrono::system_clock::now();

        // ======= GATHER PHASE =======
        // Reset per–partition accumulators.
        for (int i = 0; i < num_parts; i++) {
            agg_load_per_partition[i] = 0.0;
            penalty_term[i] = 0.0;
            for (int t = 0; t < num_threads; t++) {
                load_per_partition[t][i] = 0.0;
            }
        }

        // Update labeled_col: fast lookup for neighbor partitions.
        int64_t num_edges = col.size(0);
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t i = 0; i < num_edges; i++) {
            // col_data holds indices of vertices; use labels to get neighbor partition.
            int64_t idx = col.data_ptr<int64_t>()[i];
            labeled_col[i] = labels[idx];
        }

        // Compute per–partition load using the degree (difference in rowptr values).
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t i = 0; i < num_vertices; i++) {
            int my_part = labels[i];
            // Here we assume rowptr contains integral degrees (or can be cast to double).
            double deg = static_cast<double>(rowptr.data_ptr<int64_t>()[i+1] - rowptr.data_ptr<int64_t>()[i]);
            load_per_partition[omp_get_thread_num()][my_part] += deg;
        }
        // Aggregate the load from all threads.
        for (int p = 0; p < num_parts; p++) {
            for (int t = 0; t < num_threads; t++) {
                agg_load_per_partition[p] += load_per_partition[t][p];
            }
        }
        // Compute penalty term and remaining capacity.
        for (int p = 0; p < num_parts; p++) {
            penalty_term[p] = agg_load_per_partition[p] / max_C;
            remaining_C[p] = max_C - agg_load_per_partition[p];
        }

        // For each vertex, compute its score vector based on neighbor partitions.
        double global_score = 0.0;
        #pragma omp parallel for reduction(+:global_score) num_threads(num_threads)
        for (int64_t i = 0; i < num_vertices; i++) {
            int64_t start_edge = rowptr.data_ptr<int64_t>()[i];
            int64_t end_edge   = rowptr.data_ptr<int64_t>()[i+1];
            int64_t deg = end_edge - start_edge;
            if (deg == 0) {
                // If vertex has no edges, do nothing.
                candidate_partition[i] = labels[i];
                continue;
            }
            int current_part = labels[i];
            std::vector<double> cur_v_scores(num_parts, 0.0);
            // Gather: accumulate neighbor contributions.
            for (int64_t j = start_edge; j < end_edge; j++) {
                int neighbor_part = labeled_col[j];
                cur_v_scores[neighbor_part] += 1.0;
            }
            // Normalize and adjust by beta and penalty term.
            for (int p = 0; p < num_parts; p++) {
                cur_v_scores[p] /= deg;
                cur_v_scores[p] = beta + cur_v_scores[p] - beta * penalty_term[p];
            }
            // Accumulate the score for the current partition.
            global_score += cur_v_scores[current_part];

            // APPLY PHASE: Decide if the vertex should change partition.
            // If the current partition is not among those with the highest score,
            // then choose a candidate target (break ties randomly).
            auto max_info = find_max_elements(cur_v_scores);
            double max_score = max_info.first;
            std::vector<std::size_t> max_indices = max_info.second;
            if (std::find(max_indices.begin(), max_indices.end(), current_part) == max_indices.end()) {
                // Candidate target is chosen randomly among the max–scoring partitions.
                int chosen = *select_randomly(max_indices.begin(), max_indices.end());
                candidate_partition[i] = chosen;
                candidate_weight[i] = static_cast<double>(deg); // weight can be the vertex degree.
            } else {
                candidate_partition[i] = current_part; // remain in same partition.
                candidate_weight[i] = 0.0;
            }
        } // End Gather–Apply phase

        // ======= APPLY PHASE (Capacity Check) =======
        // For each partition, aggregate the total weight (request) from vertices wishing to migrate.
        std::vector<double> total_request(num_parts, 0.0);
        for (int64_t i = 0; i < num_vertices; i++) {
            if (candidate_partition[i] != labels[i]) {
                total_request[candidate_partition[i]] += candidate_weight[i];
            }
        }
        // Compute migration probability per partition based on remaining capacity.
        std::vector<double> migration_prob(num_parts, 0.0);
        for (int p = 0; p < num_parts; p++) {
            if (total_request[p] > 0.0)
                migration_prob[p] = std::min(1.0, remaining_C[p] / static_cast<double>(total_request[p]));
            else
                migration_prob[p] = 0.0;
        }

        // ======= SCATTER PHASE =======
        // Each vertex with a candidate different from its current partition migrates with the computed probability.
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t i = 0; i < num_vertices; i++) {
            int cand = candidate_partition[i];
            if (cand != labels[i]) {
                double prob = migration_prob[cand];
                double r = static_cast<double>(rand()) / RAND_MAX;
                if (r < prob) {
                    labels[i] = cand;
                }
            }
        }

        // Check for convergence based on the change in global score.
        double step = (prev_score == 0.0) ? 1.0 : std::abs(1.0 - global_score / prev_score);
        if (log) {
            std::cout << "|| Iter: " << iter << ", Global Score: " << global_score
                      << ", Step: " << step << std::endl;
        }
        if (step < halting_eps) {
            window++;
            if (window >= halting_window)
                halt = true;
        } else {
            window = 0;
        }
        prev_score = global_score;
        iter++;

        std::chrono::duration<double> t_elapsed = std::chrono::system_clock::now() - t_start;
        if (log) {
    
            auto rowptr_data = rowptr.data_ptr<int64_t>();
            auto col_data = col.data_ptr<int64_t>();
            
            // Assume you have sizes:
            size_t rowptr_size = rowptr.size(0);
            size_t col_size = col.size(0);
            size_t num_vertices = rowptr_size - 1;

            // Wrap raw pointers into std::vector objects.
            std::vector<int64_t> rowptr_vec(rowptr_data, rowptr_data + rowptr_size);
            std::vector<int64_t> col_vec(col_data, col_data + col_size);
            std::vector<int16_t> partition_vec(labels, labels + num_vertices);

            // Now call the function.
            double avg_reusability = compute_average_reusability_omp(rowptr_vec, col_vec, partition_vec);
            std::cout << "|| Avg Reusability: " << avg_reusability << std::endl;

            std::cout << "|| Iteration time: " << t_elapsed.count() << " sec" << std::endl;
        }
    }

    // Return a PyTorch tensor containing the partition labels.
    auto options = torch::TensorOptions().dtype(torch::kInt16);
    torch::Tensor result = torch::from_blob(labels, {num_vertices}, options).clone();
    // Note: The allocated memory (labels) will be managed by PyTorch after cloning.
    return result;
}

/*
 * Asynchronous GriNNder partitioner (CPU, OpenMP).
 *
 * Capacity-constrained, label-propagation style partitioning. Every iteration:
 *   1. looks up the current label of every `col` entry once (`labeled_col`);
 *   2. counts vertices per partition and derives, for each partition p,
 *      penalty_term[p] = load[p] / max_C and remaining_C[p] = max_C - load[p],
 *      with max_C = capacity * |V| / num_parts;
 *   3. scores every vertex with at least one edge against every partition p as
 *      beta + (fraction of its edges whose endpoint is in p) - beta * penalty_term[p],
 *      adds the score of its current partition to the global score, and records
 *      it as "interested" in the partition picked by grinnder_scoring() when
 *      that differs from its current partition;
 *   4. lets each thread migrate a random subset of its interested vertices,
 *      bounded by that thread's share (remaining_C[p] / num_threads) of the
 *      target partition's remaining capacity. With reuse_aware, the vertices
 *      selected by grinnder_priority() from their recorded second preference
 *      are migrated before the rest.
 * It stops after max_iter iterations, or once the relative change of the global
 * score stays below halting_eps for halting_window consecutive iterations.
 *
 * @param rowptr         int64 CPU tensor; CSR row pointer with |V|+1 entries.
 * @param col            int64 CPU tensor; CSR column indices (vertex ids).
 * @param num_parts      number of partitions, 1..32768 (labels are int16).
 * @param capacity       per-partition capacity as a multiple of |V| / num_parts.
 * @param beta           weight of the balance term in the vertex score.
 * @param max_iter       maximum number of iterations.
 * @param halting_eps    relative score change regarded as converged.
 * @param halting_window consecutive converged iterations required to stop.
 * @param reuse_aware    migrate vertices chosen by grinnder_priority() first.
 * @param refine         start from orig_labels instead of random labels.
 * @param log            print per-iteration diagnostics to stdout.
 * @param num_threads    OpenMP thread count; also the number of per-thread
 *                       buckets for load, score and migration bookkeeping.
 * @param orig_labels    int16 CPU tensor with one initial label per vertex;
 *                       required when refine is true.
 * @return int16 CPU tensor with one partition label per vertex. The tensor
 *         owns its storage.
 */
torch::Tensor
grinnder_async_cpu(torch::Tensor rowptr, torch::Tensor col,
            int num_parts, float capacity, float beta, int max_iter,
            float halting_eps, int halting_window, bool reuse_aware, bool refine,
            bool log, int num_threads, torch::optional<torch::Tensor> orig_labels) {

    AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
    AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
    AT_ASSERTM(num_parts > 0, "num_parts must be greater than 0");
    AT_ASSERTM(num_parts <= 32768, "num_parts must be less than 32768");
    AT_ASSERTM(num_threads > 0, "num_threads must be greater than 0");
    AT_ASSERTM(rowptr.size(0) > 0, "rowptr must contain at least one entry");

    const int64_t* rowptr_data = rowptr.data_ptr<int64_t>();
    const int64_t* col_data = col.data_ptr<int64_t>();
    const int64_t num_vertices = rowptr.size(0) - 1;
    const int64_t num_col = col.size(0);

    // max_capacity = capacity * (|V| / k)
    const int64_t max_C = static_cast<int64_t>(capacity * num_vertices / num_parts);

    // Labels are written straight into the returned tensor, which owns them.
    // (torch::from_blob without a deleter never frees its buffer, so a
    // separately new[]-ed label array would leak.)
    torch::Tensor result = torch::empty({num_vertices}, torch::kInt16);
    int16_t* labels = result.data_ptr<int16_t>();

    if (!refine) {
        /* Random Init of Partitions */
        std::random_device rd;
        std::mt19937 mersenne_engine{rd()};
        std::uniform_int_distribution<int16_t> dist{0, static_cast<int16_t>(num_parts - 1)};
        std::generate(labels, labels + num_vertices,
                      [&dist, &mersenne_engine]() { return dist(mersenne_engine); });
    } else {
        AT_ASSERTM(orig_labels.has_value(), "orig_labels must be provided when refine is true");
        const torch::Tensor init = orig_labels.value().contiguous();
        AT_ASSERTM(!init.is_cuda(), "orig_labels tensor must be a CPU tensor");
        AT_ASSERTM(init.numel() >= num_vertices, "orig_labels must hold one label per vertex");
        const int16_t* init_data = init.data_ptr<int16_t>();
        std::copy(init_data, init_data + num_vertices, labels);
    }

    if (log) {
        std::cout << "|| >>>> Async GriNNder <<<<" << std::endl;
        if (reuse_aware) {
            std::cout << "|| Reuse Aware: True" << std::endl;
        }
        std::cout << "|| Initial Partition: ";
        // Print at most three labels; the graph may have fewer vertices.
        for (int64_t v = 0; v < std::min<int64_t>(3, num_vertices); v++) {
            std::cout << static_cast<int>(labels[v]) << " ";
        }
        std::cout << std::endl;
    }

    // Heap-allocated bookkeeping (replaces stack VLAs sized num_threads * num_parts).
    using PerThread = std::vector<std::vector<double>>;  // [thread][partition]
    std::vector<int16_t> labeled_col(num_col);
    std::vector<int64_t> remaining_C(num_parts, 0);
    std::vector<double> agg_load_per_partition(num_parts, 0.0);
    std::vector<double> penalty_term(num_parts, 0.0);
    PerThread load_per_partition(num_threads, std::vector<double>(num_parts, 0.0));
    PerThread score_per_partition(num_threads, std::vector<double>(num_parts, 0.0));
    PerThread interested_sum(num_threads, std::vector<double>(num_parts, 0.0));
    PerThread migration_prob(num_threads, std::vector<double>(num_parts, 0.0));

    // interested[t][p]: vertices (seen by thread t) that want to move to p.
    std::vector<std::vector<std::vector<int64_t>>> interested(
        num_threads, std::vector<std::vector<int64_t>>(num_parts));
    // next_prefer[t][p][k]: second-preference partition of interested[t][p][k]
    // (only filled when reuse_aware).
    std::vector<std::vector<std::vector<int64_t>>> next_prefer(
        num_threads, std::vector<std::vector<int64_t>>(num_parts));

    bool halt = false;
    int iter = 0;
    int window = 0;
    double prev_score = 0.0;

    while (iter < max_iter && !halt) {
        const std::chrono::system_clock::time_point start = std::chrono::system_clock::now();

        // Reset the per-iteration accumulators.
        #pragma omp parallel for num_threads(num_threads)
        for (int p = 0; p < num_parts; p++) {
            agg_load_per_partition[p] = 0.0;
            penalty_term[p] = 0.0;
            for (int t = 0; t < num_threads; t++) {
                score_per_partition[t][p] = 0.0;
                load_per_partition[t][p] = 0.0;
                interested[t][p].clear();
                interested_sum[t][p] = 0.0;
                next_prefer[t][p].clear();
            }
        }

        // Look up each col entry's label once for contiguous access while scoring.
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t e = 0; e < num_col; e++) {
            labeled_col[e] = labels[col_data[e]];
        }

        // Per-thread partition loads; each thread writes only its own row.
        #pragma omp parallel for num_threads(num_threads)
        for (int64_t v = 0; v < num_vertices; v++) {
            load_per_partition[omp_get_thread_num()][labels[v]] += 1.0;
        }

        // Total load, balance penalty and remaining capacity per partition.
        #pragma omp parallel for num_threads(num_threads)
        for (int p = 0; p < num_parts; p++) {
            for (int t = 0; t < num_threads; t++) {
                agg_load_per_partition[p] += load_per_partition[t][p];
            }
            penalty_term[p] = agg_load_per_partition[p] / max_C;
            remaining_C[p] = static_cast<int64_t>(max_C - agg_load_per_partition[p]);
        }

        if (log) {
            const double avg_load = static_cast<double>(num_vertices) / num_parts;
            std::cout << "|| Iter: " << iter << std::endl;
            std::cout << "|| p (max_load / (|V|/k)): "
                      << *std::max_element(agg_load_per_partition.begin(),
                                           agg_load_per_partition.end()) / avg_load
                      << std::endl;
        }

        // Score every vertex and collect the vertices that want to migrate.
        #pragma omp parallel num_threads(num_threads)
        {
            const int tid = omp_get_thread_num();
            std::vector<double> cur_v_scores(num_parts);  // reused across vertices
            #pragma omp for
            for (int64_t v = 0; v < num_vertices; v++) {
                const int64_t edge_begin = rowptr_data[v];
                const int64_t edge_end = rowptr_data[v + 1];
                const int64_t degree = edge_end - edge_begin;
                if (degree == 0) {
                    continue;  // vertices without edges neither score nor migrate
                }
                const int64_t my_part = labels[v];
                std::fill(cur_v_scores.begin(), cur_v_scores.end(), 0.0);
                for (int64_t e = edge_begin; e < edge_end; e++) {
                    cur_v_scores[labeled_col[e]] += 1.0;
                }
                for (int p = 0; p < num_parts; p++) {
                    cur_v_scores[p] = beta + cur_v_scores[p] / degree - beta * penalty_term[p];
                }
                score_per_partition[tid][my_part] += cur_v_scores[my_part];

                const std::pair<int16_t, int16_t> best = grinnder_scoring(cur_v_scores, my_part);
                if (best.first != my_part) {
                    interested[tid][best.first].push_back(v);
                    if (reuse_aware) {
                        next_prefer[tid][best.first].push_back(best.second);
                    }
                    interested_sum[tid][best.first] += 1;  // unit vertex weight
                }
            }
        }

        // Migration probability per (thread, target): each thread may fill
        // remaining_C[p] / num_threads slots of partition p.
        #pragma omp parallel for num_threads(num_threads)
        for (int t = 0; t < num_threads; t++) {
            for (int p = 0; p < num_parts; p++) {
                const int64_t share = remaining_C[p] / num_threads;
                double prob = 0.0;  // nobody interested, or no capacity left
                if (!interested[t][p].empty() && share > 0) {
                    prob = (share >= interested_sum[t][p]) ? 1.0 : share / interested_sum[t][p];
                }
                migration_prob[t][p] = prob;
            }
        }

        // Migrate. Every vertex sits in exactly one interested[t][p] list, so
        // threads write disjoint label entries.
        #pragma omp parallel num_threads(num_threads)
        {
            std::random_device rd;
            std::default_random_engine rng{rd()};
            #pragma omp for
            for (int t = 0; t < num_threads; t++) {
                for (int p = 0; p < num_parts; p++) {
                    std::vector<int64_t>& candidates = interested[t][p];
                    int64_t num_migrate = static_cast<int64_t>(candidates.size() * migration_prob[t][p]);
                    if (num_migrate <= 0) {
                        continue;
                    }
                    const int16_t target = static_cast<int16_t>(p);

                    if (reuse_aware) {
                        // Indices into `candidates` that grinnder_priority() favours.
                        std::vector<int64_t> reusable = grinnder_priority(next_prefer[t][p], num_parts).second;
                        std::shuffle(reusable.begin(), reusable.end(), rng);
                        if (static_cast<int64_t>(reusable.size()) >= num_migrate) {
                            for (int64_t k = 0; k < num_migrate; k++) {
                                labels[candidates[reusable[k]]] = target;
                            }
                            continue;
                        }
                        // Move all favoured vertices, then fill the rest randomly.
                        for (const int64_t idx : reusable) {
                            labels[candidates[idx]] = target;
                            candidates[idx] = -1;  // tombstone, removed below
                        }
                        num_migrate -= static_cast<int64_t>(reusable.size());
                        candidates.erase(std::remove(candidates.begin(), candidates.end(), int64_t{-1}),
                                         candidates.end());
                    }

                    // Random selection: shuffle, then take the first num_migrate.
                    std::shuffle(candidates.begin(), candidates.end(), rng);
                    const int64_t count = std::min<int64_t>(num_migrate, static_cast<int64_t>(candidates.size()));
                    for (int64_t k = 0; k < count; k++) {
                        labels[candidates[k]] = target;
                    }
                }
            }
        }

        // Global score (summed in partition-major order).
        double cur_score = 0.0;
        for (int p = 0; p < num_parts; p++) {
            for (int t = 0; t < num_threads; t++) {
                cur_score += score_per_partition[t][p];
            }
        }

        // Relative score change; the first iteration has no previous score and
        // counts as a full step instead of dividing by zero.
        const double step = (prev_score == 0.0) ? 1.0 : std::abs(1.0 - cur_score / prev_score);
        if (log) {
            std::cout << "|| Step: " << step << ", Cur Score: " << cur_score
                      << ", Prev Score: " << prev_score << std::endl;
        }
        if (step < halting_eps) {
            window++;
            if (window >= halting_window) {
                halt = true;
            }
        } else {
            window = 0;
        }

        prev_score = cur_score;
        iter++;

        const std::chrono::duration<double> iter_time = std::chrono::system_clock::now() - start;
        if (log) {
            std::vector<int64_t> rowptr_vec(rowptr_data, rowptr_data + num_vertices + 1);
            std::vector<int64_t> col_vec(col_data, col_data + num_col);
            std::vector<int16_t> partition_vec(labels, labels + num_vertices);
            const double avg_reusability = compute_average_reusability_omp(rowptr_vec, col_vec, partition_vec);
            std::cout << "|| Avg Reusability: " << avg_reusability << std::endl;
            std::cout << "|| ============== Time: " << iter_time.count() << " (sec) " << "==============" << std::endl;
        }
    }

    return result;
}


// Fixed version of grinnder that is aware of cache reuse.
// Note: the refine option was removed; `refine` and `orig_labels` are accepted but ignored.
/*
 * Progressive, reuse-aware GriNNder partitioner (CPU, OpenMP).
 *
 * Runs a capacity-constrained scoring / migration iteration over the first
 * `cur_num_parts` partitions:
 *   - penalty_term[p] = load[p] / max_C, remaining_C[p] = max_C - load[p],
 *     with max_C = capacity * |V| / cur_num_parts;
 *   - each vertex with edges scores partition p as
 *     beta + (fraction of its edges into p) - beta * penalty_term[p] and, when
 *     grinnder_scoring() picks another partition, becomes "interested" in it;
 *   - each thread migrates a random subset of its interested vertices, bounded
 *     by remaining_C[p] / num_threads; with reuse_aware, vertices selected by
 *     grinnder_priority() from their second preference go first.
 *
 * Progressive mode (num_parts > num_threads): start with num_threads
 * partitions and, every progressive_window iterations or as soon as the
 * current stage converges, grow the active count to min(2 * cur, num_parts)
 * by randomly splitting partition p into p and p + cur. If max_iter runs out
 * first, the remaining splits are applied before returning, so every label is
 * in [0, num_parts).
 * NOTE(review): the original declared `progressive` / `progressive_window` but
 * never grew cur_num_parts. Its capacity also used num_parts, so no vertex
 * could ever migrate. The schedule above is a reconstruction of the intended
 * method; confirm it against the algorithm description.
 *
 * @param rowptr             int64 CPU tensor; CSR row pointer with |V|+1 entries.
 * @param col                int64 CPU tensor; CSR column indices (vertex ids).
 * @param num_parts          final number of partitions; a power of two <= 32768.
 * @param capacity           per-partition capacity as a multiple of |V| / cur_num_parts.
 * @param beta               weight of the balance term in the vertex score.
 * @param max_iter           maximum number of iterations.
 * @param progressive_window iterations between progressive expansions (> 0).
 * @param halting_eps        relative score change regarded as converged.
 * @param halting_window     consecutive converged iterations required to stop.
 * @param reuse_aware        migrate vertices chosen by grinnder_priority() first.
 * @param refine             accepted for interface compatibility; ignored.
 * @param log                print per-iteration diagnostics to stdout.
 * @param num_threads        OpenMP thread count and number of per-thread buckets.
 * @param orig_labels        accepted for interface compatibility; ignored.
 * @return int16 CPU tensor with one partition label per vertex. The tensor
 *         owns its storage.
 */
torch::Tensor
grinnder_fast_async_cpu(torch::Tensor rowptr, torch::Tensor col,
            int num_parts, float capacity, float beta, int max_iter, int progressive_window,
            float halting_eps, int halting_window, bool reuse_aware, bool refine,
            bool log, int num_threads, torch::optional<torch::Tensor> orig_labels) {

    AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
    AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
    AT_ASSERTM(num_parts > 0, "num_parts must be greater than 0");
    // we need to check whether the num_parts is the power of two
    AT_ASSERTM((num_parts & (num_parts - 1)) == 0, "num_parts must be the power of two");
    AT_ASSERTM(progressive_window > 0, "progressive_window must be greater than 0");
    AT_ASSERTM(num_parts <= 32768, "num_parts must be less than 32768");
    AT_ASSERTM(num_threads > 0, "num_threads must be greater than 0");
    AT_ASSERTM(rowptr.size(0) > 0, "rowptr must contain at least one entry");
    (void)refine;       // refine option removed in this variant
    (void)orig_labels;

    const int64_t* rowptr_data = rowptr.data_ptr<int64_t>();
    const int64_t* col_data = col.data_ptr<int64_t>();
    const int64_t num_vertices = rowptr.size(0) - 1;
    const int64_t num_col = col.size(0);

    // Progressive method: start with #threads partitions (empirical choice).
    const bool progressive = num_parts > num_threads;
    int cur_num_parts = progressive ? num_threads : num_parts;

    // Labels are written straight into the returned tensor, which owns them.
    // (torch::from_blob without a deleter never frees its buffer.)
    torch::Tensor result = torch::empty({num_vertices}, torch::kInt16);
    int16_t* labels = result.data_ptr<int16_t>();

    /* Random Init of Partitions */
    std::random_device rd;
    std::mt19937 mersenne_engine{rd()};
    {
        std::uniform_int_distribution<int16_t> dist{0, static_cast<int16_t>(cur_num_parts - 1)};
        std::generate(labels, labels + num_vertices,
                      [&dist, &mersenne_engine]() { return dist(mersenne_engine); });
    }

    if (log) {
        std::cout << "|| >>>> Progressive / Change-Only Access GriNNder <<<<" << std::endl;
        if (reuse_aware) {
            std::cout << "|| Reuse Aware: True" << std::endl;
        }
        std::cout << "|| Initial Partition: ";
        // Print at most three labels; the graph may have fewer vertices.
        for (int64_t v = 0; v < std::min<int64_t>(3, num_vertices); v++) {
            std::cout << static_cast<int>(labels[v]) << " ";
        }
        std::cout << std::endl;
    }

    // Capacity relative to the partitions currently in use:
    // max_capacity = capacity * (|V| / cur_k)
    int64_t max_C = static_cast<int64_t>(capacity * num_vertices / cur_num_parts);

    // Grows the active partition count to min(2 * cur, num_parts): each vertex
    // in a partition p < (next - cur) moves to p + cur with probability 1/2.
    auto expand_partitions = [&]() {
        const int next_num_parts = std::min(2 * cur_num_parts, num_parts);
        const int num_split = next_num_parts - cur_num_parts;
        std::bernoulli_distribution coin(0.5);
        for (int64_t v = 0; v < num_vertices; v++) {
            if (labels[v] < num_split && coin(mersenne_engine)) {
                labels[v] = static_cast<int16_t>(labels[v] + cur_num_parts);
            }
        }
        cur_num_parts = next_num_parts;
        max_C = static_cast<int64_t>(capacity * num_vertices / cur_num_parts);
    };

    // Bookkeeping is sized for num_parts; only the first cur_num_parts
    // entries are used at any time. Heap-allocated (no stack VLAs).
    using PerThread = std::vector<std::vector<double>>;  // [thread][partition]
    std::vector<int16_t> labeled_col(num_col);
    std::vector<int64_t> remaining_C(num_parts, 0);
    std::vector<double> agg_load_per_partition(num_parts, 0.0);
    std::vector<double> penalty_term(num_parts, 0.0);
    PerThread load_per_partition(num_threads, std::vector<double>(num_parts, 0.0));
    PerThread score_per_partition(num_threads, std::vector<double>(num_parts, 0.0));
    PerThread interested_sum(num_threads, std::vector<double>(num_parts, 0.0));
    PerThread migration_prob(num_threads, std::vector<double>(num_parts, 0.0));

    // interested[t][p]: vertices (seen by thread t) that want to move to p.
    std::vector<std::vector<std::vector<int64_t>>> interested(
        num_threads, std::vector<std::vector<int64_t>>(num_parts));
    // next_prefer[t][p][k]: second-preference partition of interested[t][p][k]
    // (only filled when reuse_aware).
    std::vector<std::vector<std::vector<int64_t>>> next_prefer(
        num_threads, std::vector<std::vector<int64_t>>(num_parts));

    bool halt = false;
    int iter = 0;
    int window = 0;
    double prev_score = 0.0;

    while (iter < max_iter && !halt) {
        const std::chrono::system_clock::time_point start = std::chrono::system_clock::now();

        // Reset the per-iteration accumulators of the active partitions.
        for (int p = 0; p < cur_num_parts; p++) {
            agg_load_per_partition[p] = 0.0;
            penalty_term[p] = 0.0;
            for (int t = 0; t < num_threads; t++) {
                score_per_partition[t][p] = 0.0;
                load_per_partition[t][p] = 0.0;
                interested[t][p].clear();
                interested_sum[t][p] = 0.0;
                next_prefer[t][p].clear();
            }
        }

        // Look up each col entry's label once for contiguous access while scoring.
        #pragma omp parallel for num_threads(num_threads) schedule(auto)
        for (int64_t e = 0; e < num_col; e++) {
            labeled_col[e] = labels[col_data[e]];
        }

        // Per-thread partition loads; each thread writes only its own row.
        #pragma omp parallel for num_threads(num_threads) schedule(auto)
        for (int64_t v = 0; v < num_vertices; v++) {
            load_per_partition[omp_get_thread_num()][labels[v]] += 1.0;
        }

        // Total load, balance penalty and remaining capacity per partition.
        #pragma omp parallel for num_threads(num_threads)
        for (int p = 0; p < cur_num_parts; p++) {
            for (int t = 0; t < num_threads; t++) {
                agg_load_per_partition[p] += load_per_partition[t][p];
            }
            penalty_term[p] = agg_load_per_partition[p] / max_C;
            remaining_C[p] = static_cast<int64_t>(max_C - agg_load_per_partition[p]);
        }

        if (log) {
            const double avg_load = static_cast<double>(num_vertices) / cur_num_parts;
            std::cout << "|| Iter: " << iter << std::endl;
            std::cout << "|| p (max_load / (|V|/k)): "
                      << *std::max_element(agg_load_per_partition.begin(),
                                           agg_load_per_partition.begin() + cur_num_parts) / avg_load
                      << std::endl;
        }

        // Score every vertex and collect the vertices that want to migrate.
        #pragma omp parallel num_threads(num_threads)
        {
            const int tid = omp_get_thread_num();
            std::vector<double> cur_v_scores(cur_num_parts);  // reused across vertices
            #pragma omp for schedule(auto)
            for (int64_t v = 0; v < num_vertices; v++) {
                const int64_t edge_begin = rowptr_data[v];
                const int64_t edge_end = rowptr_data[v + 1];
                const int64_t degree = edge_end - edge_begin;
                if (degree == 0) {
                    continue;  // vertices without edges neither score nor migrate
                }
                const int64_t my_part = labels[v];
                std::fill(cur_v_scores.begin(), cur_v_scores.end(), 0.0);
                for (int64_t e = edge_begin; e < edge_end; e++) {
                    cur_v_scores[labeled_col[e]] += 1.0;
                }
                for (int p = 0; p < cur_num_parts; p++) {
                    cur_v_scores[p] = beta + cur_v_scores[p] / degree - beta * penalty_term[p];
                }
                score_per_partition[tid][my_part] += cur_v_scores[my_part];

                const std::pair<int16_t, int16_t> best = grinnder_scoring(cur_v_scores, my_part);
                if (best.first != my_part) {
                    interested[tid][best.first].push_back(v);
                    if (reuse_aware) {
                        next_prefer[tid][best.first].push_back(best.second);
                    }
                    interested_sum[tid][best.first] += 1;  // unit vertex weight
                }
            }
        }

        // Migration probability per (thread, target): each thread may fill
        // remaining_C[p] / num_threads slots of partition p.
        #pragma omp parallel for num_threads(num_threads) schedule(auto)
        for (int t = 0; t < num_threads; t++) {
            for (int p = 0; p < cur_num_parts; p++) {
                const int64_t share = remaining_C[p] / num_threads;
                double prob = 0.0;  // nobody interested, or no capacity left
                if (!interested[t][p].empty() && share > 0) {
                    prob = (share >= interested_sum[t][p]) ? 1.0 : share / interested_sum[t][p];
                }
                migration_prob[t][p] = prob;
            }
        }

        // Migrate. Every vertex sits in exactly one interested[t][p] list, so
        // threads write disjoint label entries.
        #pragma omp parallel num_threads(num_threads)
        {
            std::random_device thread_rd;
            std::default_random_engine rng{thread_rd()};
            #pragma omp for schedule(auto)
            for (int t = 0; t < num_threads; t++) {
                for (int p = 0; p < cur_num_parts; p++) {
                    std::vector<int64_t>& candidates = interested[t][p];
                    int64_t num_migrate = static_cast<int64_t>(candidates.size() * migration_prob[t][p]);
                    if (num_migrate <= 0) {
                        continue;
                    }
                    const int16_t target = static_cast<int16_t>(p);

                    if (reuse_aware) {
                        // Indices into `candidates` that grinnder_priority() favours.
                        std::vector<int64_t> reusable = grinnder_priority(next_prefer[t][p], cur_num_parts).second;
                        std::shuffle(reusable.begin(), reusable.end(), rng);
                        if (static_cast<int64_t>(reusable.size()) >= num_migrate) {
                            for (int64_t k = 0; k < num_migrate; k++) {
                                labels[candidates[reusable[k]]] = target;
                            }
                            continue;
                        }
                        // Move all favoured vertices, then fill the rest randomly.
                        for (const int64_t idx : reusable) {
                            labels[candidates[idx]] = target;
                            candidates[idx] = -1;  // tombstone, removed below
                        }
                        num_migrate -= static_cast<int64_t>(reusable.size());
                        candidates.erase(std::remove(candidates.begin(), candidates.end(), int64_t{-1}),
                                         candidates.end());
                    }

                    // Random selection: shuffle, then take the first num_migrate.
                    std::shuffle(candidates.begin(), candidates.end(), rng);
                    const int64_t count = std::min<int64_t>(num_migrate, static_cast<int64_t>(candidates.size()));
                    for (int64_t k = 0; k < count; k++) {
                        labels[candidates[k]] = target;
                    }
                }
            }
        }

        // Global score (summed in partition-major order).
        double cur_score = 0.0;
        for (int p = 0; p < cur_num_parts; p++) {
            for (int t = 0; t < num_threads; t++) {
                cur_score += score_per_partition[t][p];
            }
        }

        // Relative score change; with no previous score (first iteration or
        // right after an expansion) it counts as a full step.
        const double step = (prev_score == 0.0) ? 1.0 : std::abs(1.0 - cur_score / prev_score);
        if (log) {
            std::cout << "|| Step: " << step << ", Cur Score: " << cur_score
                      << ", Prev Score: " << prev_score << std::endl;
        }
        if (step < halting_eps) {
            window++;
            if (window >= halting_window) {
                halt = true;
            }
        } else {
            window = 0;
        }

        prev_score = cur_score;
        iter++;

        const std::chrono::duration<double> iter_time = std::chrono::system_clock::now() - start;
        if (log) {
            std::vector<int64_t> rowptr_vec(rowptr_data, rowptr_data + num_vertices + 1);
            std::vector<int64_t> col_vec(col_data, col_data + num_col);
            std::vector<int16_t> partition_vec(labels, labels + num_vertices);
            const double avg_reusability = compute_average_reusability_omp(rowptr_vec, col_vec, partition_vec);
            std::cout << "|| Avg Reusability: " << avg_reusability << std::endl;
            std::cout << "|| ============== Time: " << iter_time.count() << " (sec) " << "==============" << std::endl;
        }

        // Progressive mode: grow the active partition count every
        // progressive_window iterations, or instead of halting when the
        // current stage has converged.
        if (progressive && cur_num_parts < num_parts &&
            (halt || iter % progressive_window == 0)) {
            expand_partitions();
            halt = false;
            window = 0;
            prev_score = 0.0;  // scores of different stages are not comparable
            if (log) {
                std::cout << "|| Progressive: now using " << cur_num_parts << " partitions" << std::endl;
            }
        }
    }

    // max_iter may run out before all stages were reached; finish splitting so
    // every returned label lies in [0, num_parts).
    while (cur_num_parts < num_parts) {
        expand_partitions();
    }

    return result;
}

/*
 * Uniformly random partitioning baseline.
 *
 * Assigns every vertex an independent label drawn uniformly from
 * [0, num_parts). The graph structure (col) is not read; rowptr only
 * determines the vertex count (|V| = rowptr.size(0) - 1).
 *
 * @param rowptr    CPU tensor; CSR row pointer with |V|+1 entries.
 * @param col       CPU tensor; CSR column indices (validated only).
 * @param num_parts number of partitions, 1..32768 (labels are int16).
 * @param log       accepted for interface parity; nothing is printed.
 * @return int16 CPU tensor with one label per vertex. The tensor owns its
 *         storage (the previous from_blob call without a deleter leaked it).
 */
torch::Tensor
random_cpu(torch::Tensor rowptr, torch::Tensor col,
            int num_parts, bool log) {

    AT_ASSERTM(!rowptr.is_cuda(), "Rowptr tensor must be a CPU tensor");
    AT_ASSERTM(!col.is_cuda(), "Col tensor must be a CPU tensor");
    AT_ASSERTM(num_parts > 0, "num_parts must be greater than 0");
    AT_ASSERTM(num_parts <= 32768, "num_parts must be less than 32768");
    AT_ASSERTM(rowptr.size(0) > 0, "rowptr must contain at least one entry");
    (void)log;

    const int64_t num_vertices = rowptr.size(0) - 1;
    torch::Tensor result = torch::empty({num_vertices}, torch::kInt16);
    int16_t* labels = result.data_ptr<int16_t>();

    /* Random Init of Partitions */
    std::random_device rd;
    std::mt19937 mersenne_engine{rd()};
    std::uniform_int_distribution<int16_t> dist{0, static_cast<int16_t>(num_parts - 1)};
    std::generate(labels, labels + num_vertices,
                  [&dist, &mersenne_engine]() { return dist(mersenne_engine); });

    return result;
}