#include "utils.h"

/*
 * Describe how to split `num` items into chunks of at most CHUNK_SIZE items.
 *
 * Fills:
 *   C->n_samples_chunk - items per chunk (min(num, CHUNK_SIZE)); the last
 *                        chunk may hold fewer.
 *   C->n_chunks        - ceil(num / n_samples_chunk).
 *
 * Edge cases (previously a division by zero):
 *   num <= 0        -> 0 chunks of size 0.
 *   CHUNK_SIZE <= 0 -> treated as "no limit": a single chunk of `num` items.
 */
void init_chunk_par(OMP_CHUNK *C, int num, int CHUNK_SIZE){
    if (num <= 0){
        C->n_samples_chunk = 0;
        C->n_chunks = 0;
        return;
    }

    int per_chunk = (CHUNK_SIZE > 0 && num > CHUNK_SIZE) ? CHUNK_SIZE : num;
    C->n_samples_chunk = per_chunk;

    // Ceiling division written without (num + per_chunk - 1) so it cannot overflow near INT_MAX.
    C->n_chunks = num / per_chunk + (num % per_chunk != 0);
}

/*
 * Print the top-left corner of a row-major float matrix with `dim` columns.
 * At most min(num, pi) rows and min(dim, pj) columns are shown, each
 * element as "A[row, col] = value, ", with one matrix row per output line.
 */
void print_arr(float *a, int num, int dim, int pi, int pj){
    int rows = min(num, pi);
    int cols = min(dim, pj);

    int r = 0;
    while (r < rows){
        for (int c = 0; c < cols; ++c){
            float value = a[r * dim + c];
            printf("A[%d, %d] = %f, ", r, c, value);
        }
        printf("\n");
        ++r;
    }
}

/*
 * Double-precision counterpart of print_arr: prints up to min(num, pi) rows
 * and min(dim, pj) columns of a row-major matrix with `dim` columns,
 * one matrix row per output line.
 */
void print_arr_double(double *a, int num, int dim, int pi, int pj){
    const int row_limit = min(num, pi);
    const int col_limit = min(dim, pj);

    for (int row = 0; row < row_limit; row++){
        int base = row * dim;
        for (int col = 0; col < col_limit; col++){
            printf("A[%d, %d] = %lf, ", row, col, a[base + col]);
        }
        printf("\n");
    }
}

/*
 * Integer counterpart of print_arr: prints up to min(num, pi) rows and
 * min(dim, pj) columns of a row-major matrix with `dim` columns,
 * one matrix row per output line.
 */
void print_arr_int(int *a, int num, int dim, int pi, int pj){
    int n_rows = min(num, pi);
    int n_cols = min(dim, pj);

    for (int r = 0; r != n_rows && r < n_rows; ++r){
        int c = 0;
        while (c < n_cols){
            printf("A[%d, %d] = %d, ", r, c, a[r * dim + c]);
            c++;
        }
        printf("\n");
    }
}

/*
 * Element-wise sum c[i] = a[i] + b[i] for i in [0, n), vectorised with AVX.
 *
 * Pointers need NOT be 32-byte aligned: unaligned load/store intrinsics are
 * used. On aligned data they cost the same as the aligned forms, but they do
 * not fault on e.g. `a + 1` or plain malloc() memory. c may be the same
 * buffer as a or b (in-place add).
 *
 * Work is done in three phases:
 *   1. blocks of 16 vectors (128 floats): all 16 sums are computed before any
 *      is stored,
 *   2. single 8-float vectors,
 *   3. a scalar tail for the last n % 8 elements.
 * n <= 0 is a no-op.
 */
void array_add_m256(float *a, float *b, float *c, int n) {
    enum { ADD_LANES = 8, ADD_UNROLL = 16, ADD_BLOCK = ADD_LANES * ADD_UNROLL };
    int i = 0;

    for (; i + ADD_BLOCK <= n; i += ADD_BLOCK) {
        __m256 v[ADD_UNROLL];
        // Load and add the whole block first, then store it, keeping the
        // original "all loads before any store" ordering within a block.
        for (int k = 0; k < ADD_UNROLL; k++) {
            v[k] = _mm256_add_ps(_mm256_loadu_ps(&a[i + ADD_LANES * k]),
                                 _mm256_loadu_ps(&b[i + ADD_LANES * k]));
        }
        for (int k = 0; k < ADD_UNROLL; k++) {
            _mm256_storeu_ps(&c[i + ADD_LANES * k], v[k]);
        }
    }

    for (; i + ADD_LANES <= n; i += ADD_LANES) {
        __m256 v = _mm256_add_ps(_mm256_loadu_ps(&a[i]), _mm256_loadu_ps(&b[i]));
        _mm256_storeu_ps(&c[i], v);
    }

    // Scalar tail: fewer than 8 elements remain.
    for (; i < n; i++) {
        c[i] = a[i] + b[i];
    }
}

/*
 * Greedily split count[0..n) into at most m contiguous segments whose sums
 * are each close to total / m.
 *
 * Segment k covers indices [start_arr[k], end_arr[k]) and sum_arr[k] is the
 * sum of count over that range. A segment is closed at index i when:
 *   - its running sum is within 0.01 of the target (or exceeds it), or
 *   - i is the last index, or
 *   - adding count[i+1] would overshoot the target by more than the current
 *     shortfall (count[i+1] > 2 * shortfall).
 * If m segments are formed before the end, the last segment is extended to n,
 * and its sum includes that absorbed tail.
 *
 * start_arr, end_arr and sum_arr must each hold at least m elements.
 * Returns the number of segments written (0 when n <= 0 or m <= 0).
 */
int seg_arr(int *count, int n, int m, int *start_arr, int *end_arr, double *sum_arr){
    // m <= 0 previously ran past the output arrays, since seg_i == m was never reached.
    if (n <= 0 || m <= 0){
        return 0;
    }

    double total = 0;
    for (int i = 0; i < n; i++){
        total += (double)count[i];
    }
    double target = total / m;

    double cur_sum = 0;
    int seg_i = 0;
    int start_idx = 0;
    for (int i = 0; i < n; i++){
        cur_sum += (double)count[i];
        double shortfall = target - cur_sum;

        // `i == n-1` is tested before count[i+1] so the read never goes past the end.
        if (shortfall < 0.01 || i == n - 1 || count[i + 1] > 2 * shortfall){
            sum_arr[seg_i] = cur_sum;
            start_arr[seg_i] = start_idx;
            end_arr[seg_i] = i + 1;

            seg_i += 1;
            if (seg_i == m){
                // Out of segments: fold the remaining elements into the last one,
                // keeping its recorded sum consistent with its extended range.
                for (int k = i + 1; k < n; k++){
                    sum_arr[seg_i - 1] += (double)count[k];
                }
                end_arr[seg_i - 1] = n;
                break;
            }

            cur_sum = 0;
            start_idx = i + 1;
        }
    }
    return seg_i;
}

/*
 * Report whether `arr` starts on a 32-byte boundary, the alignment the
 * aligned AVX load/store intrinsics (_mm256_load_ps / _mm256_store_ps) require.
 */
bool aligned(float *arr) {
    const uintptr_t addr = (uintptr_t) arr;
    const uintptr_t low_bits = addr & (uintptr_t) 31;
    return low_bits == 0;
}

/*
 * Allocate room for at least n floats on a 32-byte boundary, suitable for the
 * aligned AVX load/store intrinsics. Release with free() / free_array().
 *
 * aligned_alloc requires the byte size to be an integral multiple of the
 * alignment (undefined in C11, a NULL return after DR 460). The element count
 * is therefore rounded up to a multiple of 8 floats (32 bytes); the extra
 * trailing elements are uninitialised.
 *
 * Returns NULL for negative n or on allocation failure.
 */
float * malloc_align(int n){
    if (n < 0){
        return NULL;  // would otherwise wrap to an enormous size_t request
    }
    size_t padded = ((size_t)n + 7) & ~(size_t)7;
    float *ret = aligned_alloc(32, padded * sizeof(float));
    return ret;
}

/*
 * Copy a row-major num x dim float matrix into a new 32-byte-aligned buffer.
 *
 * Layout of the result:
 *   - each row is zero-padded to new_col = dim rounded up to a multiple of 8;
 *   - if row_padding is true, extra all-zero rows are appended until the row
 *     count is a multiple of 8;
 *   - the result is new_row x new_col.
 *
 * The source `x` need not be aligned. Its row tails are read with a masked
 * load, so nothing past x[num*dim - 1] is read.
 *
 * Returns the new buffer (caller frees with free() / free_array()). Returns
 * NULL on allocation failure or when num or dim is negative.
 */
float * align_padding(float *x, int num, int dim, bool row_padding){
    if (num < 0 || dim < 0){
        return NULL;
    }

    int new_col = (dim + 7) & ~7;
    int new_row = row_padding ? ((num + 7) & ~7) : num;

    // Widen before multiplying so large matrices do not overflow int.
    // new_col is a multiple of 8, so the byte size is a multiple of 32 as aligned_alloc requires.
    size_t n_elems = (size_t)new_row * (size_t)new_col;
    float *ret = aligned_alloc(32, n_elems * sizeof(float));
    if (ret == NULL){
        return NULL;
    }

    float *dst = ret;
    for (int i = 0; i < num; i++){
        const float *src = x + (size_t)i * (size_t)dim;

        int j = 0;
        for (; j + 8 <= dim; j += 8){
            _mm256_store_ps(dst, _mm256_loadu_ps(src + j));
            dst += 8;
        }

        if (j < dim){
            // Load only the remaining rem (< 8) floats; masked-off lanes come back as 0,
            // which provides the column padding.
            int rem = dim - j;
            __m256i mask = _mm256_setr_epi32(
                rem > 0 ? -1 : 0, rem > 1 ? -1 : 0, rem > 2 ? -1 : 0, rem > 3 ? -1 : 0,
                rem > 4 ? -1 : 0, rem > 5 ? -1 : 0, rem > 6 ? -1 : 0, rem > 7 ? -1 : 0
            );
            _mm256_store_ps(dst, _mm256_maskload_ps(src + j, mask));
            dst += 8;
        }
    }

    // Zero the appended padding rows (only present when row_padding is set).
    if (num < new_row){
        memset(dst, 0, (size_t)(new_row - num) * (size_t)new_col * sizeof(float));
    }

    return ret;
}

/*
 * Release a float buffer obtained from malloc_align(), align_padding() or any
 * other malloc-family allocator. Passing NULL is a harmless no-op (as for free()).
 */
void free_array(float *arr){
    void *block = arr;
    free(block);
}