#ifndef FASTVAR_H
#define FASTVAR_H

/*
 * Row-major element accessors: NAME(i, j) expands to element (i, j) of a
 * row-major matrix, i.e. base[(i) * stride + (j)]. The macros do NOT take the
 * array or the stride as arguments -- they refer to identifiers that must be
 * in scope (local variables or parameters) wherever the macro is used:
 *
 *   D_CC      -> d_cc      with row stride ldd
 *   COS_OCICJ -> COS_OCiCj with row stride ldo
 *   SIN_OCICJ -> SIN_OCiCj with row stride ldo
 *   MID       -> mask_ind  with row stride ldm
 *   D_NC      -> d_nc      with row stride clu
 *   CEN       -> cen       with row stride dim
 *   CEN_ALL   -> cen_all   with row stride dim
 *
 * The parameter names of the prototypes declared later in this header use
 * these same identifiers so the macros work inside their implementations.
 */
/* Disabled accessors, kept for reference: */
// #define A(i, j) a[(i) * lda + (j)]
// #define B(i, j) b[(i) * ldb + (j)]
// #define C(i, j) c[(i) * ldc + (j)]
#define D_CC(i, j) d_cc[(i) * ldd + (j)]
#define COS_OCICJ(i, j) COS_OCiCj[(i) * ldo + (j)]
#define SIN_OCICJ(i, j) SIN_OCiCj[(i) * ldo + (j)]
// #define Mask(i, j) mask[(i) * ldm + (j)]
#define MID(i, j) mask_ind[(i) * ldm + (j)]
#define D_NC(i, j) d_nc[(i) * clu + (j)]
#define CEN(i, j) cen[(i) * dim + (j)]
#define CEN_ALL(i, j) cen_all[(i) * dim + (j)]

/* Disabled helper (note: evaluates its arguments twice if re-enabled): */
// #define min(i, j) ((i) < (j) ? (i) : (j))

#include <immintrin.h> // AVX/AVX2
#include <stdlib.h>
#include <math.h>
#include <stdbool.h>
#include <omp.h>
#include <cblas.h>

/*
 * Bundle of work buffers, allocated with 32-byte alignment by init_fast_var()
 * and released by free_fast_var(). Element counts below are the ones
 * init_fast_var() requests, where num is its `num` argument and clu its `clu`
 * argument. The meanings given are inferred from the field names and from the
 * prototypes in this header -- NOTE(review): confirm against the
 * implementations.
 */
typedef struct {
    float *Dcc;        // clu x clu floats; presumably center-to-center data indexed like D_CC (row-major)
    float *COS_OCiCj;  // clu x clu floats; presumably the COS_OCiCj matrix passed to compute_COS_OCiCj()/compute_mask()
    float *SIN_OCiCj;  // clu x clu floats; presumably the SIN_OCiCj matrix passed to compute_COS_OCiCj()/compute_mask()
    float *cos_ocx;    // num floats; presumably the per-point cos_ocx array of compute_COS_OCX()/compute_mask()
    float *sin_ocx;    // num floats; presumably the per-point sin_ocx array of compute_COS_OCX()/compute_mask()
    float *r2;         // num floats; presumably the r2 output of compute_r2()/compute_r2_omp()
    float *Dnc;        // num x clu floats; presumably point-to-center data indexed like D_NC (row stride clu)
    int *Mask;         // num x clu ints; presumably the mask_ind matrix used by compute_mask()/compute_y_mask()
    int *count;        // num ints; presumably the count array filled by compute_mask()
} Fast_var;

/*
 * Allocate a rows x cols array of elem_size-byte elements aligned to 32 bytes.
 *
 * The byte count is computed in size_t with explicit overflow checks, so a
 * product such as num * clu can no longer overflow `int` before widening. It
 * is then rounded up to a multiple of 32, because C11 aligned_alloc() requires
 * the size to be an integral multiple of the alignment. A zero-byte request is
 * bumped to one 32-byte chunk so it yields a usable pointer instead of an
 * implementation-defined NULL.
 *
 * Returns NULL on negative dimensions, size_t overflow, or allocation failure.
 * `static inline` because it is defined in a header.
 */
static inline void *fast_var_alloc_aligned32(int rows, int cols, size_t elem_size) {
    const size_t align = 32;
    const size_t limit = (size_t)-1 - (align - 1);  // headroom for rounding up

    if (rows < 0 || cols < 0)
        return NULL;

    size_t n = (size_t)rows;
    if (cols != 0 && n > limit / (size_t)cols)
        return NULL;
    n *= (size_t)cols;

    if (n != 0 && elem_size > limit / n)
        return NULL;
    size_t bytes = n * elem_size;

    bytes = (bytes + align - 1) / align * align;
    if (bytes == 0)
        bytes = align;
    return aligned_alloc(align, bytes);
}

/*
 * Allocate every work buffer in *Fv with 32-byte alignment.
 *
 *   Fv  - structure to fill. Previous contents are overwritten, not freed.
 *   num - element count of cos_ocx, sin_ocx, r2 and count; row count of the
 *         num x clu buffers Dnc and Mask.
 *   dim - currently unused; kept for API compatibility.
 *   clu - side length of the clu x clu buffers Dcc, COS_OCiCj and SIN_OCiCj;
 *         column count of Dnc and Mask.
 *
 * All-or-nothing: if any allocation fails, or a size is negative or
 * overflows, every buffer already obtained is released and every pointer in
 * *Fv is set to NULL. A caller can therefore detect failure with a single
 * check such as `Fv->Dcc == NULL`. Release a successful result with
 * free_fast_var().
 *
 * Defined `static inline` because this definition lives in a header. With
 * external linkage, including the header from two .c files would produce a
 * duplicate-symbol link error.
 */
static inline void init_fast_var(Fast_var *Fv, int num, int dim, int clu) {
    (void)dim;
    if (Fv == NULL)
        return;

    Fv->Dcc       = fast_var_alloc_aligned32(clu, clu, sizeof *Fv->Dcc);
    Fv->COS_OCiCj = fast_var_alloc_aligned32(clu, clu, sizeof *Fv->COS_OCiCj);
    Fv->SIN_OCiCj = fast_var_alloc_aligned32(clu, clu, sizeof *Fv->SIN_OCiCj);
    Fv->cos_ocx   = fast_var_alloc_aligned32(num, 1, sizeof *Fv->cos_ocx);
    Fv->sin_ocx   = fast_var_alloc_aligned32(num, 1, sizeof *Fv->sin_ocx);
    Fv->r2        = fast_var_alloc_aligned32(num, 1, sizeof *Fv->r2);
    Fv->Dnc       = fast_var_alloc_aligned32(num, clu, sizeof *Fv->Dnc);
    Fv->Mask      = fast_var_alloc_aligned32(num, clu, sizeof *Fv->Mask);
    Fv->count     = fast_var_alloc_aligned32(num, 1, sizeof *Fv->count);

    if (Fv->Dcc && Fv->COS_OCiCj && Fv->SIN_OCiCj && Fv->cos_ocx &&
        Fv->sin_ocx && Fv->r2 && Fv->Dnc && Fv->Mask && Fv->count)
        return;

    /* At least one allocation failed: release the others so nothing leaks and
     * leave the structure uniformly NULL (free(NULL) is a no-op). */
    free(Fv->Dcc);
    Fv->Dcc = NULL;
    free(Fv->COS_OCiCj);
    Fv->COS_OCiCj = NULL;
    free(Fv->SIN_OCiCj);
    Fv->SIN_OCiCj = NULL;
    free(Fv->cos_ocx);
    Fv->cos_ocx = NULL;
    free(Fv->sin_ocx);
    Fv->sin_ocx = NULL;
    free(Fv->r2);
    Fv->r2 = NULL;
    free(Fv->Dnc);
    Fv->Dnc = NULL;
    free(Fv->Mask);
    Fv->Mask = NULL;
    free(Fv->count);
    Fv->count = NULL;
}

/*
 * Release every buffer owned by *Fv (as allocated by init_fast_var()) and
 * reset each pointer to NULL. Calling this twice on the same structure, or on
 * one whose initialisation failed, is therefore harmless; free(NULL) is a
 * no-op. A NULL Fv is ignored.
 *
 * Defined `static inline` because this definition lives in a header. With
 * external linkage, including the header from two .c files would produce a
 * duplicate-symbol link error.
 */
static inline void free_fast_var(Fast_var *Fv) {
    if (Fv == NULL)
        return;

    free(Fv->Dcc);
    Fv->Dcc = NULL;
    free(Fv->COS_OCiCj);
    Fv->COS_OCiCj = NULL;
    free(Fv->SIN_OCiCj);
    Fv->SIN_OCiCj = NULL;
    free(Fv->cos_ocx);
    Fv->cos_ocx = NULL;
    free(Fv->sin_ocx);
    Fv->sin_ocx = NULL;
    free(Fv->r2);
    Fv->r2 = NULL;
    free(Fv->Dnc);
    Fv->Dnc = NULL;
    free(Fv->Mask);
    Fv->Mask = NULL;
    free(Fv->count);
    Fv->count = NULL;
}

/* The functions below are only declared here; their implementations are not
 * in this header. Descriptions marked "presumably" are inferred from names and
 * parameters -- NOTE(review): confirm against the implementations. */

/* Presumably adds value[0..size) into target[0..size) element-wise, atomically
 * with respect to other threads (per the name). */
void atomic_add_array(float* target, float* value, int size);

/* Builds the per-point candidate mask. mask_ind is a matrix with row stride
 * ldm (see the MID accessor); COS_OCiCj/SIN_OCiCj use row stride ldo and
 * d_cc uses row stride ldd. max_c presumably bounds the candidates kept per
 * point, and count presumably receives how many were kept. */
void compute_mask(float *COS_OCiCj, float *SIN_OCiCj, int ldo, float *cos_ocx, float *sin_ocx, int num, float *d_cc, int ldd, float *r, int *y, 
                  int *mask_ind, int ldm, int max_c, int *count   // mask_ind: must be initialised to -1 by the caller
                  );

/* Presumably fills the per-point cos_ocx / sin_ocx arrays (length num) from
 * the center norms c_norm (length clu), radii r, labels y and point norms
 * x_norm. */
void compute_COS_OCX(float *c_norm, int clu, float *r, int *y, float *x_norm, float *cos_ocx, float *sin_ocx, int num);

/* Hand-written matrix multiply scaled by alpha. From the dimension names,
 * presumably X is num x dim, C is clu x dim and D is num x clu -- TODO
 * confirm the exact operation and transposition in the implementation. */
void gemm_myblas(int num, int clu, int dim, float alpha, float *X, float *C, float *D);

/* Presumably the row-range variant of compute_COS_OCiCj(): processes `rows`
 * rows starting at row true_i. */
void compute_COS_OCiCj_part(float *c_norm, int true_i, int rows, int clu, float *d_cc, int ldd, float *COS_OCiCj, float *SIN_OCiCj, int ldo);

/* Computes the clu x clu COS_OCiCj / SIN_OCiCj matrices (row stride ldo).
 * The parameter annotations below are the original author's IN/OUT notes. */
void compute_COS_OCiCj(
    float *c_norm,       // OUT (per original note); length clu
    int clu,             // IN, number of clusters = length of c_norm
    float *d_cc,         // OUT (per original note), row-major, clu x clu
    int ldd,             // row stride of d_cc (see D_CC)
    float *COS_OCiCj, float *SIN_OCiCj,  // clu x clu, row stride ldo (see COS_OCICJ / SIN_OCICJ)
    int ldo
    );

/* Presumably post-processes the clu x clu d_cc matrix (row stride ldd) using
 * the center norms c_norm. */
void refine_D_CC(float *d_cc, int ldd, float *c_norm, int clu);
/* Presumably fills the num x clu mask_ind matrix so that every center is a
 * candidate for every point. */
void mask_full(int *mask_ind, int num, int clu);
/* Presumably computes r2[k] for each of the num points x (dimension dim)
 * relative to its assigned center cen[y[k]]. */
void compute_r2(float *x, int num, int dim, float *cen, int *y, float *r2);
/* OpenMP variant of compute_r2 (per the name). NOTE: the argument order
 * differs from compute_r2 (num, dim come before x) -- double-check call sites. */
void compute_r2_omp(int num, int dim, float *x, float *cen, int *y, float *r2);
/* Computes c_norm and/or d_cc; `way` selects a variant whose meaning is not
 * visible from this header. */
void compute_c_norm_DCC(float *c_norm, int clu, float *d_cc, int ldd, int way);

/* Presumably initialises the num x clu d_nc matrix (row stride clu, see D_NC)
 * from the center norms c_norm. */
void D_NC_initial(float *d_nc, int num, int clu, float *c_norm);

/* Presumably assigns each of the num points a label in y from the num x clu
 * d_nc matrix. When update_centers is true, presumably also accumulates
 * per-cluster weights wc and center sums cen (row stride dim) from the points
 * x. c_true's role is not visible from this header. */
void compute_y(float *d_nc, float *c_norm, int num, int clu, int c_true, int *y, bool update_centers, float *wc, int dim, float *x, float *cen);
/* Same as compute_y, but presumably only considers the candidate centers
 * listed in mask_ind (see compute_mask / MID). */
void compute_y_mask(int *mask_ind, float *d_nc, float *c_norm, int num, int clu, int c_true, int *y, bool update_centers, float *wc, int dim, float *x, float *cen);

/* Presumably reduces a partial (e.g. per-thread) weight array wc of length
 * clu into wc_all. */
void merge_wc(float *wc_all, float *wc, int clu);
/* Presumably reduces partial center sums cen (n_clusters x dim, row-major)
 * into cen_all (see CEN / CEN_ALL). */
void merge_centers(float *cen_all, float *cen, int n_clusters, int dim);

/* Presumably writes into center_shift how far each of the n_clu centers moved
 * between cen_old and cen_new (both n_clu x dim). The output layout is not
 * visible here. */
void my_center_shift(const float * cen_old, const float *cen_new, int n_clu, int dim, float *center_shift);

#endif