#ifndef AKMC_H
#define AKMC_H

#include <immintrin.h> // AVX/AVX2
#include <stdlib.h>
#include <math.h>
#include <stdbool.h>
#include <omp.h>
#include <sys/time.h>
#include <cblas.h>

// Prototype only; the definition is not part of this header.
// Receives the addresses of two `float *` variables.
// NOTE(review): judging by the name, this presumably exchanges *a and *b
// (e.g. to swap two buffers without copying) -- confirm against the definition.
void swap_pointers(float** a, float** b);

// Prototype only; the definition is not part of this header.
// NOTE(review): the descriptions below are inferred from parameter names and
// must be confirmed against the implementation.
//   num, dim_o, c_true : presumably sample count, input dimensionality and
//                        requested number of clusters.
//   X_o, Cen_o         : float buffers, presumably the input data and the
//                        cluster centers (layout/size not visible here).
//   labels             : int buffer, presumably one cluster index per sample.
//   count_all, time    : long long buffers written by the callee; presumably
//                        operation counters and timing figures -- TODO confirm
//                        element count and units.
//   max_iter, tol      : presumably iteration cap and convergence tolerance.
//   verbose            : presumably enables diagnostic output.
//   n_threads          : presumably the number of threads to use.
// Returns an int whose meaning is not visible from this header.
int myakm(int num, int dim_o, int c_true, 
    float *X_o, float *Cen_o, int *labels, long long *count_all, long long *time, int max_iter, bool verbose, float tol, int n_threads);

// Prototype only; the definition is not part of this header.
// NOTE(review): `Fast_var` is not declared anywhere in this header, so the
// including file must make that type visible before including it.
// NOTE(review): the descriptions below are inferred from parameter names and
// the existing group comments; confirm against the implementation.
//   num, dim, clu, c_true : int sizes; presumably sample count, dimensionality,
//                           a cluster count and the requested cluster count
//                           (the relation between clu and c_true is not
//                           visible here).
//   X is the only pointer declared const; every other buffer may be written
//   by the callee as far as the signature is concerned.
// Returns an int whose meaning is not visible from this header.
int update_chunk_omp(
    int num, int dim, int clu, int c_true,
    // data
    const float *X,
    float *x_norm,
    float *Cen,
    float *c_norm, 
    // fast var
    Fast_var *Fv,
    // output
    int *labels, 
    long long *count_all,
    float *wc,
    float *wc_new_chunk_all,
    float *Cen_new,
    float *Cen_new_chunk_all,
    float *center_shift,
    long long *time
    );

// void update_chunk(
//     int num, int dim, int clu, int c_true,
//     // data
//     const float *X,
//     const float *Cen,
//     const float *c_norm, 
//     const bool up_cen,
//     // fast var
//     const float *d_cc,
//     const float *COS_OCiCj, const float *SIN_OCiCj,
//     const float *cos_ocx, const float *sin_ocx,
//     const float *r2, 
//     float *d_nc,
//     int *mask_ind,
//     // output
//     int *labels, float *wc,
//     float *Cen_new,
//     int *count);

// Prototype only; the definition is not part of this header.
// `cen` is writable, `wc` is read-only (const); `clu` and `dim` are int sizes.
// NOTE(review): presumably divides each of the clu centers (dim floats each)
// in `cen` by its per-cluster weight/count from `wc` -- confirm against the
// definition, including how a zero weight is handled.
void average_centers(float *cen, const float *wc, int clu, int dim);

// Prototype only; the definition is not part of this header.
// NOTE(review): presumably a debugging helper that prints part of the float
// array `a`, treated as num x dim; pi and pj presumably limit how many
// rows/columns are printed -- confirm against the definition.
void print_arr(float *a, int num, int dim, int pi, int pj);
// Prototype only; the definition is not part of this header.
// NOTE(review): presumably returns true when the first n ints of `a` and `b`
// are element-wise equal -- confirm against the definition.
bool arr_equal(int *a, int *b, int n);

// Prototype only; the definition is not part of this header.
// Requires the `Fast_var` type, which this header does not declare.
// NOTE(review): parameter roles are inferred from names only (sizes, data X,
// centers Cen with norms c_norm, a flag up_cen, per-sample labels, and two
// "chunk" accumulation buffers); confirm against the implementation.
// Returns a long long; presumably an elapsed-time or operation-count figure
// -- TODO confirm.
long long update_mdy(int num, int dim, int clu, int c_true, 
    float *X, float *Cen, float *c_norm, bool up_cen, Fast_var *Fv, int *labels,
    float *wc_new_chunk_all, float *Cen_new_chunk_all);
// long long update_mask_omp(int num, int dim, int clu, int c_true, Fast_var *Fv, int *labels, int *count);
// long long update_Dnc_omp(int num, int dim, int clu, int c_true, float *X, float *Cen, Fast_var *Fv, int *labels);
// Prototype only; the definition is not part of this header.
// Takes the same parameter list as update_mdy declared above.
// Requires the `Fast_var` type, which this header does not declare.
// NOTE(review): from the name this presumably updates the label assignment
// using OpenMP; parameter roles are inferred from names only -- confirm
// against the implementation.
// Returns a long long; presumably an elapsed-time or operation-count figure
// -- TODO confirm.
long long update_y_omp(int num, int dim, int clu, int c_true, 
    float *X, 
    float *Cen, float *c_norm, bool up_cen, Fast_var *Fv,
    int *labels,
    float *wc_new_chunk_all, float *Cen_new_chunk_all
    );

// Prototype only; the definition is not part of this header.
// All buffers are non-const, so any of them may be written by the callee.
// NOTE(review): presumably combines the per-chunk accumulators
// (wc_new_chunk_all, Cen_new_chunk_all) into wc / Cen_new, updates Cen and
// records per-center movement in center_shift -- inferred from names only,
// confirm against the implementation.
// Returns a long long; presumably an elapsed-time or operation-count figure
// -- TODO confirm.
long long update_cen(int clu, int dim,
    float *wc, float *wc_new_chunk_all,
    float *Cen_new, float *Cen_new_chunk_all,
    float *Cen, float *center_shift);

// Prototype only; the definition is not part of this header.
// Requires the `Fast_var` type, which this header does not declare.
// NOTE(review): presumably refreshes the contents of *Fv from the data,
// centers, their norms and the current labels -- inferred from names only,
// confirm against the implementation.
// Returns a long long; presumably an elapsed-time or operation-count figure
// -- TODO confirm.
long long update_fv(int num, int dim, int clu, float *X, float *x_norm, float *Cen, float *c_norm, int *labels, Fast_var *Fv);

#endif