#include "fastvar.h"
#include "utils.h"
    

/*
 * Element-wise accumulate: target[k] += value[k] for k in [0, size).
 * Every single-element update is an OpenMP atomic, so concurrent callers may
 * fold their own arrays into the same target. Does nothing when size <= 0.
 */
void atomic_add_array(float* target, float* value, int size) {
    int k = 0;
    while (k < size) {
        #pragma omp atomic
        target[k] += value[k];
        ++k;
    }
}

/*
 * Build the candidate-cluster list of every sample and store it in row i of
 * mask_ind (through the MID macro).
 *
 * Let g = y[i] be the sample's current cluster. Cluster j in [0, max_c) is
 * kept when j == g, or when
 *
 *     4 * r2[i] * (COS(g, j) * cos_ocx[i] + SIN(g, j) * sin_ocx[i])^2 > d_cc[g * ldd + j]
 *
 * where COS/SIN are COS_OCiCj / SIN_OCiCj read through COS_OCICJ / SIN_OCICJ.
 *
 * Output:
 *   - Kept ids are written in ascending order to MID(i, 0 .. k-1).
 *   - count[i] = k.
 *   - Up to 8 following slots (never at or past column ldo) are set to -1, so
 *     readers can stop at the first -1.
 *
 * A NaN score never passes the test: the vector path uses an ordered compare
 * and the scalar tail uses '>'. A cluster whose COS entry is NaN is therefore
 * kept only through j == g.
 *
 * The 8-wide path loads COS_OCiCj / SIN_OCiCj rows with aligned loads. Row g
 * must therefore be 32-byte aligned at column 0. NOTE(review): presumably ldo
 * is a multiple of 8; confirm against the allocator.
 *
 * No scratch memory is allocated. The per-lane results are taken from a
 * compare mask instead of being stored to a heap buffer.
 */
void compute_mask(float *COS_OCiCj, float *SIN_OCiCj, int ldo,
                  float *cos_ocx, float *sin_ocx, int num,
                  float *d_cc, int ldd, float *r2, int *y,
                  int *mask_ind, int ldm,    // mask_ind: expected to be pre-filled with -1
                  int max_c,
                  int *count
                  ){

    int clu = ldo;   // bound for the -1 padding written after the candidates
    int i, j, k, g;
    float frr, maskij;
    const float *dg;
    int hits;
    __m256 cos_ocxi_vec, sin_ocxi_vec, frr_vec;
    __m256 cos_occ_vec, sin_occ_vec, score_vec;

    for (i = 0; i < num; i++){
        k = 0;
        g = y[i];
        frr = 4 * r2[i];
        cos_ocxi_vec = _mm256_broadcast_ss(cos_ocx + i);
        sin_ocxi_vec = _mm256_broadcast_ss(sin_ocx + i);
        frr_vec = _mm256_broadcast_ss(&frr);

        dg = d_cc + g * ldd;   // row g of d_cc
        for (j = 0; j + 8 <= max_c; j += 8){
            cos_occ_vec = _mm256_load_ps(&COS_OCICJ(g, j));
            sin_occ_vec = _mm256_load_ps(&SIN_OCICJ(g, j));

            // score = frr * (cos_occ * cos_ocx + sin_occ * sin_ocx)^2
            score_vec = _mm256_mul_ps(cos_occ_vec, cos_ocxi_vec);
            score_vec = _mm256_fmadd_ps(sin_occ_vec, sin_ocxi_vec, score_vec);
            score_vec = _mm256_mul_ps(score_vec, score_vec);
            score_vec = _mm256_mul_ps(score_vec, frr_vec);

            // Bit l is set iff score[l] > dg[j + l]. The compare is ordered, so NaN never passes.
            hits = _mm256_movemask_ps(
                _mm256_cmp_ps(score_vec, _mm256_loadu_ps(dg + j), _CMP_GT_OQ));

            for (int l = 0; l < 8; l++){
                if ((j + l == g) || ((hits >> l) & 1)){
                    MID(i, k) = j + l;
                    k++;
                }
            }
        }

        // Scalar tail for the last max_c % 8 clusters.
        for (; j < max_c; j++){
            maskij = COS_OCICJ(g, j) * cos_ocx[i] + SIN_OCICJ(g, j) * sin_ocx[i];
            maskij = maskij * maskij * frr;
            if ((j == g) || maskij > dg[j]){
                MID(i, k) = j;
                k++;
            }
        }

        count[i] = k;

        // Terminate the list so readers can stop at the first -1.
        if (k < clu){
            for (int tmpi = k; tmpi < min(k + 8, clu); tmpi++){
                MID(i, tmpi) = -1;
            }
        }
    }
}


void compute_COS_OCX(float *c_norm, int clu, float *r2, int *y, float *x_norm, float *cos_ocx, float *sin_ocx, int num){
    int i;
    float *cy = aligned_alloc(32, 8 * sizeof(float));

    __m256 cy_vec, r2_vec, x_norm_vec, res, one_vec, two_vec;

    float two = 2.0;
    two_vec = _mm256_broadcast_ss(&two);

    float one = 1.0;
    one_vec = _mm256_broadcast_ss(&one);

    for (i = 0; i + 8 <= num; i+=8){
        cy[0] = c_norm[y[i]];
        cy[1] = c_norm[y[i+1]];
        cy[2] = c_norm[y[i+2]];
        cy[3] = c_norm[y[i+3]];
        cy[4] = c_norm[y[i+4]];
        cy[5] = c_norm[y[i+5]];
        cy[6] = c_norm[y[i+6]];
        cy[7] = c_norm[y[i+7]];

        cy_vec = _mm256_load_ps(cy);   // d_cc[i, 0].... d_cc[i, 7]
        r2_vec = _mm256_load_ps(r2 + i);
        x_norm_vec = _mm256_load_ps(x_norm + i);

        x_norm_vec = _mm256_sub_ps(cy_vec, x_norm_vec);
        x_norm_vec = _mm256_add_ps(x_norm_vec, r2_vec);

        res = _mm256_mul_ps(cy_vec, r2_vec);
        res = _mm256_sqrt_ps(res);
        res = _mm256_mul_ps(res, two_vec);

        res = _mm256_div_ps(x_norm_vec, res);

        _mm256_store_ps(cos_ocx + i,   res);

        // compute sin ocx = sqrt( 1 - cos ocx^2 )
        res = _mm256_mul_ps(res, res);
        res = _mm256_sub_ps(one_vec, res);
        res = _mm256_sqrt_ps(res);
        _mm256_store_ps(sin_ocx + i, res);

        // ocx[i]   = (cy0 + r2[i]   - x_norm[i])   / (2 * sqrtf(cy0 * r2[i]  ));
        // ocx[i+1] = (cy1 + r2[i+1] - x_norm[i+1]) / (2 * sqrtf(cy1 * r2[i+1]));
        // ocx[i+2] = (cy2 + r2[i+2] - x_norm[i+2]) / (2 * sqrtf(cy2 * r2[i+2]));
        // ocx[i+3] = (cy3 + r2[i+3] - x_norm[i+3]) / (2 * sqrtf(cy3 * r2[i+3]));
        // ocx[i+4] = (cy4 + r2[i+4] - x_norm[i+4]) / (2 * sqrtf(cy4 * r2[i+4]));
        // ocx[i+5] = (cy5 + r2[i+5] - x_norm[i+5]) / (2 * sqrtf(cy5 * r2[i+5]));
        // ocx[i+6] = (cy6 + r2[i+6] - x_norm[i+6]) / (2 * sqrtf(cy6 * r2[i+6]));
        // ocx[i+7] = (cy7 + r2[i+7] - x_norm[i+7]) / (2 * sqrtf(cy7 * r2[i+7]));
    }

    int g;
    for (; i < num; i++){
        g = y[i];
        float cy0 = c_norm[g];
        cos_ocx[i] = (cy0 + r2[i]   - x_norm[i])   / (2 * sqrtf(cy0 * r2[i]  ));
        sin_ocx[i] = sqrtf(1 - cos_ocx[i] * cos_ocx[i]);
    }
    free(cy);
    // printf("cos end\n");
}


/*
 * OpenMP driver for compute_r2.
 *
 * The num rows of X are split into chunks with init_chunk_par(&NC, num, 256).
 * NOTE(review): the meaning of 256 is defined by init_chunk_par. The chunks
 * are distributed statically over at most min(NC.n_chunks,
 * omp_get_max_threads()) threads. Each chunk writes only its own slice of r2.
 *
 * Guards added:
 *   - Returns immediately for num <= 0.
 *   - Never requests num_threads(0), which OpenMP does not allow.
 *   - The row offset into X is computed in size_t, so start_row * dim may
 *     exceed INT_MAX.
 */
void compute_r2_omp(int num, int dim, float *X, float *Cen, int *labels, float *r2){
    if (num <= 0){
        return;
    }

    OMP_CHUNK NC;
    init_chunk_par(&NC, num, 256);

    int max_threads = omp_get_max_threads();
    int effective_threads = (NC.n_chunks < max_threads) ? NC.n_chunks : max_threads;
    if (effective_threads < 1){
        effective_threads = 1;   // num_threads() requires a positive value
    }

    #pragma omp parallel for num_threads(effective_threads) schedule(static)
    for (int chunk_idx = 0; chunk_idx < NC.n_chunks; chunk_idx++){
        int start_row = chunk_idx * NC.n_samples_chunk;
        int end_row = start_row + NC.n_samples_chunk;
        if (end_row > num) end_row = num;
        int rows = end_row - start_row;
        if (rows > 0) {
            compute_r2(X + (size_t)start_row * dim, rows, dim, Cen,
                       labels + start_row, r2 + start_row);
        }
    }
}

/*
 * D = alpha * X * C^T, one dot product per entry:
 *
 *   D[i * clu + j] = alpha * dot(X[i, 0:dim], C[j, 0:dim])
 *
 * Matrix layouts (all row-major):
 *   - X is num x dim.
 *   - C is clu x dim.
 *   - D is num x clu.
 *
 * Every column j in [0, clu) is handled individually, so clu need not be a
 * multiple of 8; the earlier 8-way unroll wrote past row i otherwise. Row
 * offsets are computed in size_t to avoid int overflow on large inputs.
 */
void gemm_myblas(int num, int clu, int dim, float alpha, float *X, float *C, float *D){
    for (int i = 0; i < num; i++){
        const float *xi = X + (size_t)i * dim;
        float *di = D + (size_t)i * clu;
        for (int j = 0; j < clu; j++){
            di[j] = alpha * cblas_sdot(dim, xi, 1, C + (size_t)j * dim, 1);
        }
    }
}


/*
 * Cosine and sine of the angle at center C_i between the origin and center
 * C_j. Computed for rows i in [true_i, true_i + rows) and all columns
 * j in [0, clu):
 *
 *   COS(i, j) = ((c_norm[i] + D(i, j)) - c_norm[j]) / (sqrt(D(i, j)) * sqrt(c_norm[i])) * 0.5
 *   SIN(i, j) = sqrt(1 - COS(i, j)^2)
 *
 * D is read through D_CC. When D holds squared center distances and c_norm
 * holds squared center norms, this is the law of cosines.
 *
 * Values that are not filtered here:
 *   - A zero distance (e.g. the diagonal) gives 0/0 = NaN.
 *   - |COS| > 1 gives a NaN sine.
 *
 * Columns are processed 8 at a time with aligned loads/stores. clu is
 * therefore expected to be a multiple of 8, and the rows of D_CC, COS_OCiCj,
 * SIN_OCiCj and c_norm to be 32-byte aligned (assumed, not checked).
 */
void compute_COS_OCiCj_part(float *c_norm, int true_i, int rows, int clu, float *d_cc, int ldd, float *COS_OCiCj, float *SIN_OCiCj, int ldo){
    const __m256 half_vec = _mm256_set1_ps(0.5f);
    const __m256 one_vec  = _mm256_set1_ps(1.0f);
    const int row_end = rows + true_i;

    for (int row = true_i; row < row_end; ++row){
        const __m256 ci_vec      = _mm256_broadcast_ss(c_norm + row);
        const __m256 sqrt_ci_vec = _mm256_sqrt_ps(ci_vec);

        for (int col = 0; col < clu; col += 8){
            const __m256 dist_vec = _mm256_load_ps(&D_CC(row, col));
            const __m256 cj_vec   = _mm256_load_ps(c_norm + col);

            // numerator: (c_norm[i] + D(i, j)) - c_norm[j]
            __m256 numer = _mm256_sub_ps(_mm256_add_ps(ci_vec, dist_vec), cj_vec);
            // denominator: sqrt(D(i, j)) * sqrt(c_norm[i])
            __m256 denom = _mm256_mul_ps(_mm256_sqrt_ps(dist_vec), sqrt_ci_vec);

            __m256 cos_vec = _mm256_mul_ps(_mm256_div_ps(numer, denom), half_vec);
            _mm256_store_ps(&COS_OCICJ(row, col), cos_vec);

            __m256 sin_vec = _mm256_sqrt_ps(
                _mm256_sub_ps(one_vec, _mm256_mul_ps(cos_vec, cos_vec)));
            _mm256_store_ps(&SIN_OCICJ(row, col), sin_vec);
        }
    }
}
// This function must not be called directly; call it through compute_COS_OCiCj_py.
/*
 * Prepare all center-to-center geometry in one pass:
 *   1. c_norm[i] = -0.5 * D_CC(i, i). Since d_cc holds -2 C C^T on entry,
 *      this is ||c_i||^2.
 *   2. refine_D_CC turns d_cc in place into squared center distances, with
 *      the diagonal set to 0.
 *   3. compute_COS_OCiCj_part fills COS_OCiCj / SIN_OCiCj for all clu rows.
 *
 * Step 3 reuses compute_COS_OCiCj_part instead of keeping a second,
 * operation-for-operation identical 4-row copy of that kernel here.
 *
 * clu must be a multiple of 8 (the diagonal loop and refine_D_CC work in
 * groups of 8). d_cc and c_norm must be 32-byte aligned.
 */
void compute_COS_OCiCj(
    float *c_norm,       // OUT, ||c_i||^2, length clu
    int clu,             // IN, num of clusters, len of c_norm
    float *d_cc,         // IN/OUT, row-major, clu x clu: -2 C * C.T on entry, ||c_i - c_j||^2 on exit
    int ldd,             // leading dimension of d_cc
    float *COS_OCiCj,    // OUT, may be nan (e.g. on the diagonal)
    float *SIN_OCiCj,    // OUT
    int ldo              // leading dimension of COS_OCiCj / SIN_OCiCj
    ){
    // ||c_i||^2 read off the diagonal of -2 C * C.T
    for (int base = 0; base < clu; base += 8){
        for (int r = base; r < base + 8; r++){
            c_norm[r] = -0.5 * D_CC(r, r);
        }
    }

    refine_D_CC(d_cc, ldd, c_norm, clu);

    compute_COS_OCiCj_part(c_norm, 0, clu, clu, d_cc, ldd, COS_OCiCj, SIN_OCiCj, ldo);
}


// INPUT:  D_CC holds -2 C * C.T
// OUTPUT: D_CC(i, j) = (c_norm[j] + D_CC(i, j)) + c_norm[i], then D_CC(i, i) = 0,
//         i.e. ||c_i - c_j||^2 when c_norm[k] = ||c_k||^2
/*
 * Rows are visited in groups of 8 and columns in aligned 8-wide vectors.
 * clu is therefore expected to be a multiple of 8, and the rows of d_cc and
 * c_norm to be 32-byte aligned. The diagonal of each group of 8 rows is
 * zeroed after that group has been updated.
 */
void refine_D_CC(float *d_cc, int ldd, float *c_norm, int clu){
    for (int blk = 0; blk < clu; blk += 8){
        for (int row = blk; row < blk + 8; row++){
            const __m256 ci_vec = _mm256_broadcast_ss(c_norm + row);
            for (int col = 0; col < clu; col += 8){
                __m256 v = _mm256_load_ps(&D_CC(row, col));
                v = _mm256_add_ps(_mm256_load_ps(c_norm + col), v);   // += c_norm[j]
                v = _mm256_add_ps(ci_vec, v);                          // += c_norm[i]
                _mm256_store_ps(&D_CC(row, col), v);
            }
        }
        for (int row = blk; row < blk + 8; row++){
            D_CC(row, row) = 0;
        }
    }
}


// mask_ind: row-major, num x clu. Any clu works; callers in this file use multiples of 8.
/*
 * "No pruning" mask: every row lists every cluster, mask_ind[i][j] = j.
 * Rows are addressed by advancing a pointer rather than computing i * clu in
 * int, which overflowed for num * clu > INT_MAX. Does nothing when num <= 0
 * or clu <= 0.
 */
void mask_full(int *mask_ind, int num, int clu){
    if (clu <= 0){
        return;
    }
    int *row = mask_ind;
    for (int i = 0; i < num; i++, row += clu){
        for (int j = 0; j < clu; j++){
            row[j] = j;
        }
    }
}

void compute_r2(float *x, int num, int dim, float *cen, int *y, float *r2){
    int i, j, g0, g1, g2, g3, g4, g5, g6, g7;
    float s, tmp;
    __m256 x0_vec, x1_vec, x2_vec, x3_vec, x4_vec, x5_vec, x6_vec, x7_vec;
    __m256 c0_vec, c1_vec, c2_vec, c3_vec, c4_vec, c5_vec, c6_vec, c7_vec;
    __m256 r0_vec, r1_vec, r2_vec, r3_vec, r4_vec, r5_vec, r6_vec, r7_vec;

    float *x0, *x1, *x2, *x3, *x4, *x5, *x6, *x7; 
    float *c0, *c1, *c2, *c3, *c4, *c5, *c6, *c7;

    for (i = 0; i + 8 <= num; i+=8){

        g0 = y[i + 0];
        g1 = y[i + 1];
        g2 = y[i + 2];
        g3 = y[i + 3];
        g4 = y[i + 4];
        g5 = y[i + 5];
        g6 = y[i + 6];
        g7 = y[i + 7];

        r0_vec = _mm256_setzero_ps();
        r1_vec = _mm256_setzero_ps();
        r2_vec = _mm256_setzero_ps();
        r3_vec = _mm256_setzero_ps();
        r4_vec = _mm256_setzero_ps();
        r5_vec = _mm256_setzero_ps();
        r6_vec = _mm256_setzero_ps();
        r7_vec = _mm256_setzero_ps();

        // x[i, :], cen[g, :]
        x0 = x + (i + 0) * dim;
        x1 = x + (i + 1) * dim;
        x2 = x + (i + 2) * dim;
        x3 = x + (i + 3) * dim;
        x4 = x + (i + 4) * dim;
        x5 = x + (i + 5) * dim;
        x6 = x + (i + 6) * dim;
        x7 = x + (i + 7) * dim;

        c0 = cen + g0 * dim;
        c1 = cen + g1 * dim;
        c2 = cen + g2 * dim;
        c3 = cen + g3 * dim;
        c4 = cen + g4 * dim;
        c5 = cen + g5 * dim;
        c6 = cen + g6 * dim;
        c7 = cen + g7 * dim;

        for (j = 0; j < dim; j+=8){
            x0_vec = _mm256_load_ps(x0);   // x[i, 0].... x[i, 7]
            x1_vec = _mm256_load_ps(x1);   // x[i, 0].... x[i, 7]
            x2_vec = _mm256_load_ps(x2);   // x[i, 0].... x[i, 7]
            x3_vec = _mm256_load_ps(x3);   // x[i, 0].... x[i, 7]
            x4_vec = _mm256_load_ps(x4);   // x[i, 0].... x[i, 7]
            x5_vec = _mm256_load_ps(x5);   // x[i, 0].... x[i, 7]
            x6_vec = _mm256_load_ps(x6);   // x[i, 0].... x[i, 7]
            x7_vec = _mm256_load_ps(x7);   // x[i, 0].... x[i, 7]

            c0_vec = _mm256_load_ps(c0);   // x[i, 0].... x[i, 7]
            c1_vec = _mm256_load_ps(c1);   // x[i, 0].... x[i, 7]
            c2_vec = _mm256_load_ps(c2);   // x[i, 0].... x[i, 7]
            c3_vec = _mm256_load_ps(c3);   // x[i, 0].... x[i, 7]
            c4_vec = _mm256_load_ps(c4);   // x[i, 0].... x[i, 7]
            c5_vec = _mm256_load_ps(c5);   // x[i, 0].... x[i, 7]
            c6_vec = _mm256_load_ps(c6);   // x[i, 0].... x[i, 7]
            c7_vec = _mm256_load_ps(c7);   // x[i, 0].... x[i, 7]

            x0_vec = _mm256_sub_ps(x0_vec, c0_vec); // d_cc[i, j] += c_norm[j]
            x1_vec = _mm256_sub_ps(x1_vec, c1_vec); // d_cc[i, j] += c_norm[j]
            x2_vec = _mm256_sub_ps(x2_vec, c2_vec); // d_cc[i, j] += c_norm[j]
            x3_vec = _mm256_sub_ps(x3_vec, c3_vec); // d_cc[i, j] += c_norm[j]
            x4_vec = _mm256_sub_ps(x4_vec, c4_vec); // d_cc[i, j] += c_norm[j]
            x5_vec = _mm256_sub_ps(x5_vec, c5_vec); // d_cc[i, j] += c_norm[j]
            x6_vec = _mm256_sub_ps(x6_vec, c6_vec); // d_cc[i, j] += c_norm[j]
            x7_vec = _mm256_sub_ps(x7_vec, c7_vec); // d_cc[i, j] += c_norm[j]

            r0_vec = _mm256_fmadd_ps(x0_vec, x0_vec, r0_vec);
            r1_vec = _mm256_fmadd_ps(x1_vec, x1_vec, r1_vec);
            r2_vec = _mm256_fmadd_ps(x2_vec, x2_vec, r2_vec);
            r3_vec = _mm256_fmadd_ps(x3_vec, x3_vec, r3_vec);
            r4_vec = _mm256_fmadd_ps(x4_vec, x4_vec, r4_vec);
            r5_vec = _mm256_fmadd_ps(x5_vec, x5_vec, r5_vec);
            r6_vec = _mm256_fmadd_ps(x6_vec, x6_vec, r6_vec);
            r7_vec = _mm256_fmadd_ps(x7_vec, x7_vec, r7_vec);

            x0 += 8;
            x1 += 8;
            x2 += 8;
            x3 += 8;
            x4 += 8;
            x5 += 8;
            x6 += 8;
            x7 += 8;
            
            c0 += 8;
            c1 += 8;
            c2 += 8;
            c3 += 8;
            c4 += 8;
            c5 += 8;
            c6 += 8;
            c7 += 8;
        }


        // x0_vec = _mm256_unpacklo_ps(r0_vec, r1_vec); // [r00, r10, r01, r11, ...]
        // x1_vec = _mm256_unpackhi_ps(r0_vec, r1_vec); // [r02, r12, r03, r13, ...]
        // x2_vec = _mm256_unpacklo_ps(r2_vec, r3_vec);
        // x3_vec = _mm256_unpackhi_ps(r2_vec, r3_vec);
        // x4_vec = _mm256_unpacklo_ps(r4_vec, r5_vec);
        // x5_vec = _mm256_unpackhi_ps(r4_vec, r5_vec);
        // x6_vec = _mm256_unpacklo_ps(r6_vec, r7_vec);
        // x7_vec = _mm256_unpackhi_ps(r6_vec, r7_vec);

        // // 步骤 2: 合并 128-bit 块
        // r0_vec = _mm256_shuffle_ps(x0_vec, x2_vec, _MM_SHUFFLE(1, 0, 1, 0)); // [r00, r10, r20, r30, ...]
        // r1_vec = _mm256_shuffle_ps(x0_vec, x2_vec, _MM_SHUFFLE(3, 2, 3, 2));
        // r2_vec = _mm256_shuffle_ps(x1_vec, x3_vec, _MM_SHUFFLE(1, 0, 1, 0));
        // r3_vec = _mm256_shuffle_ps(x1_vec, x3_vec, _MM_SHUFFLE(3, 2, 3, 2));
        // r4_vec = _mm256_shuffle_ps(x4_vec, x6_vec, _MM_SHUFFLE(1, 0, 1, 0));
        // r5_vec = _mm256_shuffle_ps(x4_vec, x6_vec, _MM_SHUFFLE(3, 2, 3, 2));
        // r6_vec = _mm256_shuffle_ps(x5_vec, x7_vec, _MM_SHUFFLE(1, 0, 1, 0));
        // r7_vec = _mm256_shuffle_ps(x5_vec, x7_vec, _MM_SHUFFLE(3, 2, 3, 2));

        // // 步骤 3: 交换 128-bit 跨通道
        // x0_vec = _mm256_permute2f128_ps(r0_vec, r4_vec, 0x20); // [r00, r10, r20, r30, r40, r50, r60, r70]
        // x1_vec = _mm256_permute2f128_ps(r1_vec, r5_vec, 0x20);
        // x2_vec = _mm256_permute2f128_ps(r2_vec, r6_vec, 0x20);
        // x3_vec = _mm256_permute2f128_ps(r3_vec, r7_vec, 0x20);
        // x4_vec = _mm256_permute2f128_ps(r0_vec, r4_vec, 0x31);
        // x5_vec = _mm256_permute2f128_ps(r1_vec, r5_vec, 0x31);
        // x6_vec = _mm256_permute2f128_ps(r2_vec, r6_vec, 0x31);
        // x7_vec = _mm256_permute2f128_ps(r3_vec, r7_vec, 0x31);

        // x0_vec = _mm256_add_ps(x1_vec, x0_vec);
        // x0_vec = _mm256_add_ps(x2_vec, x0_vec);
        // x0_vec = _mm256_add_ps(x3_vec, x0_vec);
        // x0_vec = _mm256_add_ps(x4_vec, x0_vec);
        // x0_vec = _mm256_add_ps(x5_vec, x0_vec);
        // x0_vec = _mm256_add_ps(x6_vec, x0_vec);
        // x0_vec = _mm256_add_ps(x7_vec, x0_vec);
        // _mm256_storeu_ps(r2 + i, x0_vec);

        r2[i+0] = r0_vec[0] + r0_vec[1] + r0_vec[2] + r0_vec[3] + r0_vec[4] + r0_vec[5] + r0_vec[6] + r0_vec[7];
        r2[i+1] = r1_vec[0] + r1_vec[1] + r1_vec[2] + r1_vec[3] + r1_vec[4] + r1_vec[5] + r1_vec[6] + r1_vec[7];
        r2[i+2] = r2_vec[0] + r2_vec[1] + r2_vec[2] + r2_vec[3] + r2_vec[4] + r2_vec[5] + r2_vec[6] + r2_vec[7];
        r2[i+3] = r3_vec[0] + r3_vec[1] + r3_vec[2] + r3_vec[3] + r3_vec[4] + r3_vec[5] + r3_vec[6] + r3_vec[7];
        r2[i+4] = r4_vec[0] + r4_vec[1] + r4_vec[2] + r4_vec[3] + r4_vec[4] + r4_vec[5] + r4_vec[6] + r4_vec[7];
        r2[i+5] = r5_vec[0] + r5_vec[1] + r5_vec[2] + r5_vec[3] + r5_vec[4] + r5_vec[5] + r5_vec[6] + r5_vec[7];
        r2[i+6] = r6_vec[0] + r6_vec[1] + r6_vec[2] + r6_vec[3] + r6_vec[4] + r6_vec[5] + r6_vec[6] + r6_vec[7];
        r2[i+7] = r7_vec[0] + r7_vec[1] + r7_vec[2] + r7_vec[3] + r7_vec[4] + r7_vec[5] + r7_vec[6] + r7_vec[7];
    }

    float *xi, *ci;
    for (; i < num; i++){
        g0 = y[i];
        r0_vec = _mm256_setzero_ps();
        xi = x + i * dim;
        ci = cen + g0 * dim;
        for (j = 0; j < dim; j+=8){
            x0_vec = _mm256_load_ps(xi);   // x[i, 0].... x[i, 7]
            c0_vec = _mm256_load_ps(ci);   // x[i, 0].... x[i, 7]
            x0_vec = _mm256_sub_ps(x0_vec, c0_vec); // d_cc[i, j] += c_norm[j]
            r0_vec = _mm256_fmadd_ps(x0_vec, x0_vec, r0_vec);
            xi += 8;
            ci += 8;
        }
        r2[i] = r0_vec[0] + r0_vec[1] + r0_vec[2] + r0_vec[3] + r0_vec[4] + r0_vec[5] + r0_vec[6] + r0_vec[7];
    }
}


/*
 * Prepare the center-to-center distance matrix.
 *
 * When way != 1, c_norm[i] is first derived from the diagonal of d_cc:
 * c_norm[i] = -0.5 * D_CC(i, i), which is ||c_i||^2 when d_cc holds
 * -2 C * C.T. When way == 1, c_norm is not written and must already be
 * filled by the caller.
 *
 * In both cases refine_D_CC then rewrites d_cc in place to squared center
 * distances. The diagonal loop works in groups of 8, so clu is expected to be
 * a multiple of 8.
 */
void compute_c_norm_DCC(
    float *c_norm,       // IN/OUT, length clu; written unless way == 1
    int clu,             // IN, num of clusters, len of c_norm
    float *d_cc,         // IN/OUT, row-major, clu x clu
    int ldd,             // leading dimension of d_cc
    int way              // 1: c_norm supplied by caller; otherwise derived from d_cc
    ){
    const int derive_norms = (way != 1);

    if (derive_norms){
        int base = 0;
        while (base < clu){
            for (int r = base; r < base + 8; r++){
                c_norm[r] = -0.5 * D_CC(r, r);
            }
            base += 8;
        }
    }

    refine_D_CC(d_cc, ldd, c_norm, clu);
}


/*
 * Initialise every one of the num rows of d_nc to a copy of c_norm:
 * D_NC(row, col .. col+7) = c_norm[col .. col+7] for col = 0, 8, 16, ... < clu.
 *
 * c_norm is read with aligned 8-wide loads (32-byte alignment assumed), while
 * d_nc is written with unaligned stores. clu is expected to be a multiple of
 * 8; otherwise the last store of each row writes past column clu.
 */
void D_NC_initial(float *d_nc, int num, int clu, float *c_norm){
    for (int col = 0; col < clu; col += 8){
        const __m256 norms = _mm256_load_ps(c_norm + col);
        int row = 0;
        while (row < num){
            _mm256_storeu_ps(&D_NC(row, col), norms);
            row++;
        }
    }
}


/*
 * Assign every sample to its closest cluster among the first c_true:
 *
 *   y[i] = argmin_{j < c_true} d_nc[i * clu + j] + c_norm[j]
 *
 * Ties keep the lowest j. As in the original, entry 0 is always read, even
 * when c_true <= 0.
 *
 * When update_centers is true, for the chosen label:
 *   - wc[label] is incremented.
 *   - Row i of x (dim floats) is added into row label of cen.
 * cen therefore accumulates per-cluster sums.
 *
 * The accumulation is an element-wise loop over exactly dim floats. Any dim
 * works, and x needs no particular alignment. The former 8-wide loop required
 * dim % 8 == 0 and otherwise added the next sample into the next center's
 * row. Row offsets are computed in size_t so i * clu and i * dim cannot
 * overflow int.
 */
void compute_y(float *d_nc, float *c_norm,
    int num, int clu, int c_true, int *y, bool update_centers,
    float *wc, int dim, float *x, float *cen){
    for (int i = 0; i < num; i++){
        const float *drow = d_nc + (size_t)i * clu;
        float min_sq_dist = drow[0] + c_norm[0];
        int label = 0;
        for (int j = 1; j < c_true; j++){
            float sq_dist = drow[j] + c_norm[j];
            if (sq_dist < min_sq_dist){
                min_sq_dist = sq_dist;
                label = j;
            }
        }

        y[i] = label;

        if (update_centers){
            const float *xi = x + (size_t)i * dim;
            float *cl = cen + (size_t)label * dim;
            wc[label]++;
            for (int k = 0; k < dim; k++){
                cl[k] += xi[k];
            }
        }
    }
}

/*
 * Masked variant of compute_y. For each sample i, the label is chosen among
 * the cluster ids listed in row i of mask_ind. The row is assumed to be
 * written like compute_mask's output: ascending ids, terminated by -1.
 *
 *   y[i] = argmin over listed cid of d_nc[i * clu + cid] + c_norm[cid]
 *
 * Scan rules:
 *   - At most c_true list positions are examined.
 *   - The scan stops at the first negative id or the first id >= c_true.
 *   - Ties keep the earlier candidate.
 *
 * If the row has no usable candidate (its first entry is negative or
 * >= c_true), the label comes from the unmasked scan over all c_true
 * clusters, exactly as in compute_y. This avoids indexing d_nc, wc and cen
 * with an invalid id.
 *
 * update_centers behaves as in compute_y: wc[label] += 1, and row i of x is
 * added into row label of cen, element by element over exactly dim floats.
 * Row offsets are computed in size_t.
 */
void compute_y_mask(int *mask_ind, float *d_nc, float *c_norm, int num, int clu, int c_true, int *y, bool update_centers, float *wc, int dim, float *x, float *cen){
    for (int i = 0; i < num; i++){
        const int *mrow = mask_ind + (size_t)i * clu;
        const float *drow = d_nc + (size_t)i * clu;
        const int first_cid = mrow[0];
        float min_sq_dist;
        int label;

        if (first_cid < 0 || first_cid >= c_true){
            // No usable candidate: fall back to the unmasked scan.
            label = 0;
            min_sq_dist = drow[0] + c_norm[0];
            for (int j = 1; j < c_true; j++){
                float sq_dist = drow[j] + c_norm[j];
                if (sq_dist < min_sq_dist){
                    min_sq_dist = sq_dist;
                    label = j;
                }
            }
        } else {
            label = first_cid;
            min_sq_dist = drow[first_cid] + c_norm[first_cid];
            for (int j = 1; j < c_true; j++){
                int cid = mrow[j];
                if (cid < 0 || cid >= c_true){
                    break;   // end of the candidate list
                }
                float sq_dist = drow[cid] + c_norm[cid];
                if (sq_dist < min_sq_dist){
                    min_sq_dist = sq_dist;
                    label = cid;
                }
            }
        }

        y[i] = label;

        if (update_centers){
            const float *xi = x + (size_t)i * dim;
            float *cl = cen + (size_t)label * dim;
            wc[label]++;
            for (int k = 0; k < dim; k++){
                cl[k] += xi[k];
            }
        }
    }
}


/*
 * Accumulate one set of per-cluster weights into the global array:
 *   wc_all[j] += wc[j]   for j in [0, clu)
 *
 * Both pointers must be 32-byte aligned, because aligned AVX loads and stores
 * are used for the vector part. Any clu >= 0 is accepted. Full 8-float chunks
 * are added with AVX, and the remaining clu % 8 elements are added with scalar
 * code, so no element at index >= clu is ever read or written.
 */
void merge_wc(float *wc_all, float *wc, int clu){
    int j = 0;

    for (; j + 8 <= clu; j += 8){
        __m256 acc = _mm256_load_ps(wc_all + j);
        acc = _mm256_add_ps(acc, _mm256_load_ps(wc + j));
        _mm256_store_ps(wc_all + j, acc);
    }

    /* Scalar remainder: keeps a clu that is not a multiple of 8 inside the buffers. */
    for (; j < clu; j++){
        wc_all[j] += wc[j];
    }
}

/*
 * Accumulate one set of centre sums into the global centre matrix:
 *   CEN_ALL(j, k) += CEN(j, k)   for j in [0, n_clusters), k in [0, dim)
 *
 * CEN and CEN_ALL are row/column accessor macros for cen and cen_all, defined
 * outside this chunk. Columns are consumed 8 at a time with aligned AVX loads
 * and stores, so each row start must be 32-byte aligned and dim must be a
 * multiple of 8 (or the rows padded so reading up to the next multiple of 8
 * is safe).
 * NOTE(review): confirm the padding/stride against the CEN macro definitions.
 *
 * Rows are processed individually, so n_clusters does not need to be a
 * multiple of 8 and no row at or beyond n_clusters is touched. Each element
 * is still computed as cen_all + cen with a single add, exactly as before.
 */
void merge_centers(float *cen_all, float *cen, int n_clusters, int dim){
    for (int j = 0; j < n_clusters; j++){
        for (int k = 0; k < dim; k += 8){
            __m256 acc = _mm256_load_ps(&CEN_ALL(j, k));
            acc = _mm256_add_ps(acc, _mm256_load_ps(&CEN(j, k)));
            _mm256_store_ps(&CEN_ALL(j, k), acc);
        }
    }
}


/* Sum the eight lanes of v strictly left to right (lane 0 first), matching
 * the evaluation order of an explicit v[0] + v[1] + ... + v[7] chain. */
static inline float center_shift_hsum8(__m256 v){
    float lanes[8];
    _mm256_storeu_ps(lanes, v);
    return lanes[0] + lanes[1] + lanes[2] + lanes[3] + lanes[4] + lanes[5] + lanes[6] + lanes[7];
}

/*
 * Squared Euclidean shift of every centre between two iterations:
 *   center_shift[j] = || cen_new[j, :] - cen_old[j, :] ||^2   for j in [0, n_clu)
 *
 * Row j of each matrix starts at offset j * dim. Aligned AVX loads are used
 * and columns are consumed 8 at a time, so both matrices must be 32-byte
 * aligned and dim must be a multiple of 8.
 *
 * Rows are processed in groups of four, giving four independent FMA
 * accumulation chains. The remaining n_clu % 4 rows are processed one at a
 * time with the same per-lane operation sequence, so no row at or beyond
 * n_clu is read and center_shift is written only for j < n_clu.
 */
void my_center_shift(const float * cen_old, const float *cen_new, int n_clu, int dim, float *center_shift){
    int j = 0;

    for (; j + 4 <= n_clu; j += 4){
        __m256 ret0 = _mm256_setzero_ps();
        __m256 ret1 = _mm256_setzero_ps();
        __m256 ret2 = _mm256_setzero_ps();
        __m256 ret3 = _mm256_setzero_ps();

        for (int k = 0; k < dim; k += 8){
            __m256 d0 = _mm256_sub_ps(_mm256_load_ps(cen_new + (j + 0) * dim + k),
                                      _mm256_load_ps(cen_old + (j + 0) * dim + k));
            __m256 d1 = _mm256_sub_ps(_mm256_load_ps(cen_new + (j + 1) * dim + k),
                                      _mm256_load_ps(cen_old + (j + 1) * dim + k));
            __m256 d2 = _mm256_sub_ps(_mm256_load_ps(cen_new + (j + 2) * dim + k),
                                      _mm256_load_ps(cen_old + (j + 2) * dim + k));
            __m256 d3 = _mm256_sub_ps(_mm256_load_ps(cen_new + (j + 3) * dim + k),
                                      _mm256_load_ps(cen_old + (j + 3) * dim + k));

            ret0 = _mm256_fmadd_ps(d0, d0, ret0);
            ret1 = _mm256_fmadd_ps(d1, d1, ret1);
            ret2 = _mm256_fmadd_ps(d2, d2, ret2);
            ret3 = _mm256_fmadd_ps(d3, d3, ret3);
        }

        center_shift[j + 0] = center_shift_hsum8(ret0);
        center_shift[j + 1] = center_shift_hsum8(ret1);
        center_shift[j + 2] = center_shift_hsum8(ret2);
        center_shift[j + 3] = center_shift_hsum8(ret3);
    }

    /* Remaining n_clu % 4 rows, one at a time. */
    for (; j < n_clu; j++){
        const float *old_row = cen_old + j * dim;
        const float *new_row = cen_new + j * dim;
        __m256 acc = _mm256_setzero_ps();

        for (int k = 0; k < dim; k += 8){
            __m256 d = _mm256_sub_ps(_mm256_load_ps(new_row + k), _mm256_load_ps(old_row + k));
            acc = _mm256_fmadd_ps(d, d, acc);
        }
        center_shift[j] = center_shift_hsum8(acc);
    }
}
