#include "akmc.h"
#include "utils.h"
#include "fastvar.h"
#include "gemm_mask_float.h"

// Exchange the two float pointers that `a` and `b` point to.
// Only the pointer values are swapped; the arrays they address are untouched.
void swap_pointers(float** a, float** b) {
    float *const held = *b;
    *b = *a;
    *a = held;
}

// Allocate room for `count` floats on a 32-byte boundary.
// aligned_alloc requires the byte size to be a multiple of the alignment, so
// the size is rounded up to the next multiple of 32 (and to at least 32).
// Returns NULL on overflow or allocation failure; release with free().
static float *akm_alloc_floats(size_t count) {
    if (count > ((size_t)-1 - 31u) / sizeof(float)) {
        return NULL;
    }
    size_t bytes = (count * sizeof(float) + 31u) & ~(size_t)31u;
    if (bytes == 0) {
        bytes = 32;
    }
    return aligned_alloc(32, bytes);
}

// Entry point: pads the inputs, allocates all work buffers, runs
// update_chunk_omp and releases everything again.
//
//   num, dim_o, c_true : number of samples, original feature count, number of
//                        real clusters.
//   X_o                : num x dim_o samples, row-major, contiguous.
//   Cen_o              : c_true x dim_o initial centers, row-major. Nothing in
//                        this function writes the final centers back to it.
//   labels             : num ints; reset to -1 here, then filled by the solver.
//   count_all          : per-iteration counters written by update_chunk_omp
//                        (one entry per iteration it runs).
//   time               : at least 4 entries. [0..2] are written by
//                        update_chunk_omp; [3] is the wall time of this call
//                        in microseconds.
//   max_iter, verbose, tol : accepted but currently NOT forwarded;
//                        update_chunk_omp has no such parameters and uses its
//                        own iteration cap and tolerance.
//   n_threads          : number of per-thread accumulator slots requested.
//
// Returns the iteration count reported by update_chunk_omp, or -1 when a work
// buffer could not be allocated (time[0..2] are then left untouched).
int myakm(int num, int dim_o, int c_true, float *X_o, float *Cen_o, int *labels, long long *count_all, long long *time, int max_iter, bool verbose, float tol, int n_threads){
    // update_chunk_omp zeroes, and update_cen reduces, exactly this many
    // accumulator slots regardless of n_threads, so never allocate fewer.
    enum { MIN_ACC_SLOTS = 20 };

    (void)max_iter;
    (void)verbose;
    (void)tol;

    struct timeval time_s, time_e;
    gettimeofday(&time_s, NULL);  // start timing

    // align and padding: both dimensions are rounded up to a multiple of 8
    int dim = (dim_o + 7) & ~7;
    int clu = (c_true + 7) & ~7;
    float *X = align_padding(X_o, num, dim_o, false);
    float *Cen = align_padding(Cen_o, c_true, dim_o, true);

    // allocation (all sizes computed in size_t to avoid int overflow)
    const size_t slots = (n_threads > MIN_ACC_SLOTS) ? (size_t)n_threads : (size_t)MIN_ACC_SLOTS;
    const size_t cen_elems = (size_t)clu * (size_t)dim;

    float *Cen_new = akm_alloc_floats(cen_elems);
    float *Cen_threads_all = akm_alloc_floats(slots * cen_elems);
    float *wc = akm_alloc_floats((size_t)clu);
    float *wc_threads_all = akm_alloc_floats(slots * (size_t)clu);
    float *center_shift = akm_alloc_floats((size_t)clu);
    float *x_norm = akm_alloc_floats((size_t)num);
    float *c_norm = akm_alloc_floats((size_t)clu);

    if (num > 0) {
        memset(labels, -1, (size_t)num * sizeof(int));
    }

    int iters = -1;
    if (X && Cen && Cen_new && Cen_threads_all && wc && wc_threads_all
            && center_shift && x_norm && c_norm) {
        // squared Euclidean norm of every (padded) sample row
        for (int i = 0; i < num; i++) {
            const float *row = X + (size_t)i * (size_t)dim;
            x_norm[i] = cblas_sdot(dim, row, 1, row, 1);
        }

        // squared Euclidean norm of every (padded) center row
        for (int i = 0; i < clu; i++) {
            const float *row = Cen + (size_t)i * (size_t)dim;
            c_norm[i] = cblas_sdot(dim, row, 1, row, 1);
        }

        Fast_var Fv;
        init_fast_var(&Fv, num, dim, clu);

        iters = update_chunk_omp(num, dim, clu, c_true, X, x_norm, Cen, c_norm, &Fv,
            labels, count_all, wc, wc_threads_all, Cen_new, Cen_threads_all, center_shift, time);

        free_fast_var(&Fv);
    }

    // free(NULL) is a no-op, so this is safe on the failure path too
    free(X);
    free(Cen);
    free(Cen_new);
    free(Cen_threads_all);
    free(wc);
    free(wc_threads_all);
    free(center_shift);
    free(x_norm);
    free(c_norm);

    gettimeofday(&time_e, NULL);  // stop timing
    time[3] = (time_e.tv_sec - time_s.tv_sec) * 1000000LL + (time_e.tv_usec - time_s.tv_usec);
    return iters;
}

// Zero the merged accumulators (wc, cen_sum) and the `slots` per-thread
// partial accumulators ahead of an assignment pass.
static void akm_reset_accumulators(int clu, int dim, size_t slots,
    float *wc, float *wc_slots, float *cen_sum, float *cen_slots){
    const size_t cen_elems = (size_t)clu * (size_t)dim;
    memset(cen_sum, 0, cen_elems * sizeof(float));
    memset(wc, 0, (size_t)clu * sizeof(float));
    memset(cen_slots, 0, slots * cen_elems * sizeof(float));
    memset(wc_slots, 0, slots * (size_t)clu * sizeof(float));
}

// Main iteration loop.
//
// Iteration 0 assigns labels with update_y_omp; every later iteration uses
// update_mdy. Each iteration then recomputes the centers (update_cen), swaps
// the Cen / Cen_new pointers and refreshes the cached quantities in Fv
// (update_fv).
//
// The loop stops when the labels are identical to the previous iteration's,
// when the sum of the first c_true entries of center_shift is <= tol, or
// after max_iter iterations (max_iter = 150 and tol = 1e-6 are fixed here).
// Unless it stopped on identical labels, one extra update_mdy pass is run.
//
// Outputs:
//   labels       : final assignment.
//   count_all[i] : num * clu for i == 0, afterwards the sum of Fv->count over
//                  all samples; needs room for max_iter entries.
//   time[0..2]   : accumulated microseconds returned by the label, center and
//                  Fv update calls respectively.
//
// wc_new_chunk_all and Cen_new_chunk_all must provide at least 20 slots of
// clu and clu * dim floats; that many slots are zeroed on every iteration.
//
// Returns iter + 1, where iter is the loop counter on exit (max_iter if the
// loop ran to the cap), or -1 if the label backup buffer cannot be allocated
// (time[0..2] are then set to 0).
int update_chunk_omp(
    int num, int dim, int clu, int c_true,
    // data
    const float *X,
    float *x_norm,
    float *Cen,
    float *c_norm, 
    Fast_var *Fv,
    // output
    int *labels, 
    long long *count_all,
    float *wc,
    float *wc_new_chunk_all,
    float *Cen_new,
    float *Cen_new_chunk_all,
    float *center_shift,
    long long *time
    ){
    enum { ACC_SLOTS = 20 };

    int max_iter = 150;
    float tol = 1e-6;

    // update_y_omp, update_mdy and update_fv are defined in this file with a
    // non-const `float *X` parameter; make the qualifier drop explicit.
    float *X_rw = (float *)X;

    // Allocate at least one element so a NULL result always means failure
    // and the pointer handed to memcpy is valid even when num == 0.
    const size_t label_bytes = (size_t)(num > 0 ? num : 0) * sizeof(int);
    int *labels_old = malloc(label_bytes > 0 ? label_bytes : sizeof(int));
    if (labels_old == NULL){
        time[0] = 0;
        time[1] = 0;
        time[2] = 0;
        return -1;
    }

    long long time_y = 0, time_c = 0, time_v = 0;

    bool strict_convergence = false;
    int iter = 0;

    akm_reset_accumulators(clu, dim, ACC_SLOTS, wc, wc_new_chunk_all, Cen_new, Cen_new_chunk_all);

    // update y
    time_y += update_y_omp(num, dim, clu, c_true, X_rw, Cen, c_norm, true, Fv,
        labels, wc_new_chunk_all, Cen_new_chunk_all);

    // update Cen
    time_c += update_cen(clu, dim,
        wc, wc_new_chunk_all, Cen_new, Cen_new_chunk_all, Cen, center_shift);

    // swap: the freshly computed centers become the current ones. Only the
    // local pointer copies are exchanged; the caller's pointers are unchanged.
    swap_pointers(&Cen, &Cen_new);

    // Fast var (update_fv also rewrites c_norm)
    time_v += update_fv(num, dim, clu, X_rw, x_norm, Cen, c_norm, labels, Fv);

    count_all[iter] = (long long)num * clu;

    // backup
    memcpy(labels_old, labels, label_bytes);

    for (iter = 1; iter < max_iter; iter++){
        // init
        akm_reset_accumulators(clu, dim, ACC_SLOTS, wc, wc_new_chunk_all, Cen_new, Cen_new_chunk_all);

        time_y += update_mdy(num, dim, clu, c_true, X_rw, Cen, c_norm, true, Fv, labels, wc_new_chunk_all, Cen_new_chunk_all);
        count_all[iter] = 0;
        for (int i = 0; i < num; i++){
            count_all[iter] += Fv->count[i];
        }

        // update Cen
        time_c += update_cen(clu, dim,
            wc, wc_new_chunk_all, Cen_new, Cen_new_chunk_all, Cen, center_shift);

        // swap
        swap_pointers(&Cen, &Cen_new);

        // fast var
        time_v += update_fv(num, dim, clu, X_rw, x_norm, Cen, c_norm, labels, Fv);

        strict_convergence = arr_equal(labels, labels_old, num);
        if (strict_convergence){
            break;
        }else{
            // only the c_true real centers take part in the tolerance test
            float center_shift_tot = cblas_ssum(c_true, center_shift, 1);
            if (center_shift_tot <= tol){
                break;
            }
        }

        memcpy(labels_old, labels, label_bytes);
    }

    if (! strict_convergence){
        // one more labelling pass against the final centers
        time_y += update_mdy(num, dim, clu, c_true, X_rw, Cen, c_norm, true, Fv, labels, wc_new_chunk_all, Cen_new_chunk_all);
    }

    time[0] = time_y;
    time[1] = time_c;
    time[2] = time_v;

    free(labels_old);
    return iter + 1;
}

// Masked label update, processed in chunks of 64 samples.
//
// For every chunk the three kernels run back to back on the same rows:
//   compute_mask    -> fills Fv->Mask / Fv->count for the chunk,
//   gemm_mul_p_mask -> fills Fv->Dnc for the chunk (alpha = -2, X times Cen),
//   compute_y_mask  -> updates labels and, per thread, the partial
//                      accumulators wc_new_chunk_all / Cen_new_chunk_all.
// (The exact semantics of the kernels live outside this file.)
//
// Each thread accumulates into slot `tid` of the per-thread buffers, so the
// team size is capped at the number of slots the rest of this file zeroes and
// reduces (20); a larger team would write outside that range and have its
// contributions ignored by update_cen.
//
// Returns the elapsed wall time in microseconds.
long long update_mdy(int num, int dim, int clu, int c_true, 
    float *X, float *Cen, float *c_norm, bool up_cen, Fast_var *Fv, int *labels,
    float *wc_new_chunk_all, float *Cen_new_chunk_all){
    enum { ACC_SLOTS = 20 };

    struct timeval time_s, time_e;
    gettimeofday(&time_s, NULL);  // start timing

    OMP_CHUNK MC;
    init_chunk_par(&MC, num, 64);

    // No more threads than chunks or accumulator slots, and never fewer than
    // one: num_threads() requires a positive value.
    int effective_threads = omp_get_max_threads();
    if (effective_threads > MC.n_chunks) effective_threads = MC.n_chunks;
    if (effective_threads > ACC_SLOTS) effective_threads = ACC_SLOTS;
    if (effective_threads < 1) effective_threads = 1;

    #pragma omp parallel num_threads(effective_threads)
    {
        const int tid = omp_get_thread_num();
        // this thread's private accumulator slot (offsets in size_t so large
        // clu * dim products cannot overflow int)
        float *wc_slot = wc_new_chunk_all + (size_t)tid * (size_t)clu;
        float *cen_slot = Cen_new_chunk_all + (size_t)tid * (size_t)clu * (size_t)dim;

        #pragma omp for schedule(dynamic, 1) 
        for (int chunk_idx = 0; chunk_idx < MC.n_chunks; chunk_idx++){
            int start_row = chunk_idx * MC.n_samples_chunk;
            int end_row = start_row + MC.n_samples_chunk;
            if (end_row > num) end_row = num;
            int rows = end_row - start_row;
            if (rows > 0) {
                // element offsets of this chunk inside the num x clu and
                // num x dim matrices
                const size_t nc_off = (size_t)start_row * (size_t)clu;
                const size_t x_off = (size_t)start_row * (size_t)dim;

                compute_mask(Fv->COS_OCiCj, Fv->SIN_OCiCj, clu, Fv->cos_ocx + start_row, Fv->sin_ocx + start_row, 
                    rows, Fv->Dcc, clu, Fv->r2 + start_row, labels + start_row, Fv->Mask + nc_off, clu, c_true, 
                    Fv->count + start_row);
                gemm_mul_p_mask(rows, clu, dim, -2.0, X + x_off, dim, 
                    Cen, dim, Fv->Dnc + nc_off, clu, Fv->Mask + nc_off, clu, Fv->count + start_row, 0.0);
                compute_y_mask(Fv->Mask + nc_off, Fv->Dnc + nc_off, c_norm, rows, clu, c_true, 
                    labels + start_row, up_cen, wc_slot, dim, X + x_off, cen_slot);
            }
        }
    }

    gettimeofday(&time_e, NULL);
    return (time_e.tv_sec - time_s.tv_sec) * 1000000LL + (time_e.tv_usec - time_s.tv_usec);
}

// Refresh the cached quantities in Fv after the centers changed.
//
// Three work-shared phases run inside one parallel region:
//   1. Dcc[i, j] = -2 * <c_i, c_j> via sgemm, and c_norm[i] = <c_i, c_i>
//      taken from the diagonal.
//   2. Dcc[i, j] += c_norm[j] + c_norm[i], diagonal forced to 0, giving
//      ||c_i - c_j||^2; then compute_COS_OCiCj_part fills
//      Fv->COS_OCiCj / Fv->SIN_OCiCj for the same rows.
//   3. Per sample chunk: compute_r2 and compute_COS_OCX fill Fv->r2,
//      Fv->cos_ocx and Fv->sin_ocx.
// Centers are processed in chunks of 256 rows, samples in chunks of 256 rows.
//
// Side effects: c_norm is overwritten, and the OpenBLAS thread count is set
// to 1 and not restored here.
//
// Returns the wall time of the parallel region in microseconds.
long long update_fv(int num, int dim, int clu, float *X, float *x_norm, float *Cen, float *c_norm, int *labels, Fast_var *Fv){

    struct timeval time_s, time_e;

    // Original author's note: "DCC: 4.3, OCICJ 3.4" (units not stated).
    OMP_CHUNK CC;
    init_chunk_par(&CC, clu, 256);

    OMP_CHUNK NC;
    init_chunk_par(&NC, num, 256);

    // Size the team for the phase with the most chunks. Sizing it from the
    // center chunks alone would run the per-sample phase on a single thread
    // whenever clu <= 256. Chunks are independent, so results do not depend
    // on the team size. num_threads() needs a positive value.
    int max_threads = omp_get_max_threads();
    int busiest = (CC.n_chunks > NC.n_chunks) ? CC.n_chunks : NC.n_chunks;
    int effective_threads = (busiest < max_threads) ? busiest : max_threads;
    if (effective_threads < 1) effective_threads = 1;
    openblas_set_num_threads(1);

    // leading dimensions as size_t so row offsets cannot overflow int
    const size_t ldc = (size_t)clu;
    const size_t ldx = (size_t)dim;

    gettimeofday(&time_s, NULL);
    #pragma omp parallel num_threads(effective_threads)
    {
        #pragma omp for schedule(static) 
        for (int chunk_idx = 0; chunk_idx < CC.n_chunks; chunk_idx++){
            int start_row = chunk_idx * CC.n_samples_chunk;
            int end_row = start_row + CC.n_samples_chunk;
            if (end_row > clu) end_row = clu;
            int rows = end_row - start_row;
            if (rows > 0) {
                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, 
                    rows, clu, dim, -2.0, 
                    Cen + (size_t)start_row * ldx, dim, 
                    Cen, dim, 0.0, Fv->Dcc + (size_t)start_row * ldc, clu);
                for (int i = start_row; i < end_row; i++){
                    c_norm[i] = -0.5 * Fv->Dcc[(size_t)i * ldc + (size_t)i];
                }
            }
        }
        // The worksharing loop above ends with an implicit barrier, so every
        // c_norm entry is written before the next phase reads it.
        // c_norm, Dcc[i, j] = -2 ci cj

        // refine Dcc, OCiCj
        #pragma omp for schedule(static) 
        for (int chunk_idx = 0; chunk_idx < CC.n_chunks; chunk_idx++){
            int start_row = chunk_idx * CC.n_samples_chunk;
            int end_row = start_row + CC.n_samples_chunk;
            if (end_row > clu) end_row = clu;
            int rows = end_row - start_row;
            if (rows > 0) {
                if (rows % 4 != 0) printf("rowsssss = %d\n", rows);

                for (int i = start_row; i < end_row; i++){
                    float *dcc_row = Fv->Dcc + (size_t)i * ldc;
                    // add the column term: dcc_row[k] += c_norm[k]
                    cblas_saxpy(clu, 1.0f, c_norm, 1, dcc_row, 1);
                    // add the row term c_norm[i] to the whole row, 8 lanes at
                    // a time; the aligned load/store need a 32-byte aligned
                    // row, and the k += 8 stride needs clu to be a multiple
                    // of 8
                    __m256 s_vec = _mm256_set1_ps(c_norm[i]);
                    for (int k = 0; k < clu; k+=8){
                        __m256 x_vec = _mm256_load_ps(dcc_row + k);  // load 8 elements
                        __m256 result = _mm256_add_ps(x_vec, s_vec);  // vectorised add
                        _mm256_store_ps(dcc_row + k, result);  // store the result
                    }
                    dcc_row[i] = 0;
                }
                // Dcc[start_row: end_row, :]  =  || ci - cj||^2
                compute_COS_OCiCj_part(c_norm, start_row, rows, clu, 
                    Fv->Dcc, clu, Fv->COS_OCiCj, Fv->SIN_OCiCj, clu);
            }
        }

        // r2
        #pragma omp for schedule(static) 
        for (int chunk_idx = 0; chunk_idx < NC.n_chunks; chunk_idx++){
            int start_row = chunk_idx * NC.n_samples_chunk;
            int end_row = start_row + NC.n_samples_chunk;
            if (end_row > num) end_row = num;
            int rows = end_row - start_row;
            if (rows > 0) {
                compute_r2(X + (size_t)start_row * ldx, rows, dim, Cen, labels + start_row, Fv->r2 + start_row);
                compute_COS_OCX(c_norm, clu, Fv->r2 + start_row, labels + start_row, x_norm + start_row, 
                    Fv->cos_ocx + start_row, Fv->sin_ocx + start_row, rows);
            }
        }

    }

    gettimeofday(&time_e, NULL);

    return (time_e.tv_sec - time_s.tv_sec) * 1000000LL + (time_e.tv_usec - time_s.tv_usec);
}

// Reduce the per-thread partial accumulators into wc / Cen_new, turn the sums
// into means and compute the per-center shift.
//
// The clu center rows are split into blocks of about clu / 20 rows, rounded
// up to a multiple of 8. For each block:
//   - the ACC_SLOTS partial weight and center sums are added into wc and
//     Cen_new (array_add_m256(src, dst, dst, n), presumably dst += src),
//   - every center row is scaled by 1 / wc[j],
//   - my_center_shift is called with the old and new rows of the block and
//     the matching part of center_shift.
//
// wc and Cen_new are expected to be zeroed by the caller beforehand.
//
// NOTE(review): a center with wc[j] == 0 (including padded rows beyond the
// real cluster count) is scaled by 1/0. This is left unchanged because the
// downstream kernels may rely on the current values; confirm before altering.
//
// Returns the elapsed wall time in microseconds.
long long update_cen(int clu, int dim,
    float *wc, float *wc_new_chunk_all,
    float *Cen_new, float *Cen_new_chunk_all,
    float *Cen, float *center_shift){
    enum {
        ACC_SLOTS = 20,      // partial accumulators summed per center
        REDUCE_THREADS = 20  // team size and target number of row blocks
    };

    struct timeval time_s, time_e;
    gettimeofday(&time_s, NULL);

    OMP_CHUNK CC;
    int block_size = clu / REDUCE_THREADS;
    if (block_size == 0) block_size = 1;
    // round up to a multiple of 8 (no-op when already a multiple)
    block_size = (block_size + 7) & ~7;
    init_chunk_par(&CC, clu, block_size);

    // strides in size_t so slot and row offsets cannot overflow int
    const size_t ld = (size_t)dim;
    const size_t wc_stride = (size_t)clu;
    const size_t cen_stride = (size_t)clu * ld;

    #pragma omp parallel num_threads(REDUCE_THREADS)
    {
        // by chunk
        #pragma omp for schedule(static) 
        for (int chunk_idx = 0; chunk_idx < CC.n_chunks; chunk_idx++){
            int start_row = chunk_idx * CC.n_samples_chunk;
            int end_row = start_row + CC.n_samples_chunk;
            if (end_row > clu) end_row = clu;
            int rows = end_row - start_row;
            if (rows > 0) {
                const size_t row0 = (size_t)start_row;
                float *wc_dst = wc + row0;
                float *cen_dst = Cen_new + row0 * ld;

                for (int p = 0; p < ACC_SLOTS; p++) {
                    array_add_m256(wc_new_chunk_all + (size_t)p * wc_stride + row0, wc_dst, wc_dst, rows);
                    array_add_m256(Cen_new_chunk_all + (size_t)p * cen_stride + row0 * ld, cen_dst, cen_dst, rows * dim);
                }

                // average
                for (int j = start_row; j < end_row; j++){
                    double alpha = 1.0 / wc[j];
                    cblas_sscal(dim, alpha, Cen_new + (size_t)j * ld, 1);  // x = alpha * x
                }
                my_center_shift(Cen + row0 * ld, cen_dst, rows, dim, center_shift + row0);
            }
        }
    }

    gettimeofday(&time_e, NULL);
    long long time = (time_e.tv_sec - time_s.tv_sec) * 1000000LL + (time_e.tv_usec - time_s.tv_usec);

    return time;
}


// Dense (unmasked) label update used for the first iteration.
//
// Fills Fv->Dnc (num x clu, row-major) with -2 * X * Cen^T using one sgemm
// call, then runs compute_y over chunks of up to 64 samples in parallel.
// Each thread accumulates into slot `thread_id` of wc_new_chunk_all /
// Cen_new_chunk_all, so the team size is capped at the number of slots the
// rest of this file zeroes and reduces (20).
//
// Side effect: sets the OpenBLAS thread count to 20 before returning.
//
// Returns the elapsed wall time in microseconds.
long long update_y_omp(int num, int dim, int clu, int c_true, 
    float *X, 
    float *Cen, float *c_norm, bool up_cen, Fast_var *Fv,
    int *labels,
    float *wc_new_chunk_all, float *Cen_new_chunk_all
    ){
    enum { ACC_SLOTS = 20 };

    struct timeval time_s, time_e;
    gettimeofday(&time_s, NULL);  // start timing

    // Dnc
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, num, clu, dim, -2.0, X, dim, Cen, dim, 0.0, Fv->Dnc, clu);  

    int CHUNK_SIZE = 64;
    int n_samples_chunk = (num > CHUNK_SIZE) ? CHUNK_SIZE : num;
    // ceil(num / n_samples_chunk); guarded because n_samples_chunk is 0 when
    // num is 0, which would otherwise be a division by zero
    int n_chunks = 0;
    if (n_samples_chunk > 0){
        n_chunks = num / n_samples_chunk;
        if (num % n_samples_chunk != 0){
            n_chunks ++;
        }
    }

    // No more threads than chunks or accumulator slots, and never fewer than
    // one: num_threads() requires a positive value.
    int effective_threads = omp_get_max_threads();
    if (effective_threads > n_chunks) effective_threads = n_chunks;
    if (effective_threads > ACC_SLOTS) effective_threads = ACC_SLOTS;
    if (effective_threads < 1) effective_threads = 1;

    #pragma omp parallel num_threads(effective_threads)
    {
        int thread_id = omp_get_thread_num();
        // this thread's private accumulator slot (offsets in size_t so large
        // clu * dim products cannot overflow int)
        float *wc_slot = wc_new_chunk_all + (size_t)thread_id * (size_t)clu;
        float *cen_slot = Cen_new_chunk_all + (size_t)thread_id * (size_t)clu * (size_t)dim;

        #pragma omp for schedule(dynamic, 1) nowait
        for (int chunk_idx = 0; chunk_idx < n_chunks; chunk_idx++){
            int start = chunk_idx * n_samples_chunk;
            int end   = start + n_samples_chunk;
            if (end > num) end = num;

            compute_y(Fv->Dnc + (size_t)start * (size_t)clu, c_norm, end-start, clu, c_true, labels + start, up_cen, 
                wc_slot, dim, X + (size_t)start * (size_t)dim, cen_slot);
        }
    }
    gettimeofday(&time_e, NULL);
    long long time = (time_e.tv_sec - time_s.tv_sec) * 1000000LL + (time_e.tv_usec - time_s.tv_usec);
    openblas_set_num_threads(20);
    return time;
}

// Return true when the first n ints of a and b match element for element.
// For n <= 0 nothing is compared and the result is true.
bool arr_equal(int *a, int *b, int n){
    for (int idx = 0; idx < n; ++idx){
        if (a[idx] != b[idx]){
            return false;  // first mismatch decides
        }
    }
    return true;
}

// Scale each of the clu rows of `cen` (dim floats per row, rows stored back
// to back) by the reciprocal of the matching entry of wc, in place.
void average_centers(float *cen, const float *wc, int clu, int dim){
    float *row = cen;
    int j = 0;
    while (j < clu){
        double inv_weight = 1.0 / wc[j];
        cblas_sscal(dim, inv_weight, row, 1);  // row = inv_weight * row
        row += dim;
        j++;
    }
}