/*
 * If you are not too familiar with C, feel free to check out
 * pylut.py in the same folder for reference.
 * 
 */

#include "cLUT.h"

#include <stdlib.h>
#include <strings.h>
#include <stdbool.h>
#include <stdint.h>
#include <assert.h>

#include <flip.h>

#define MAX_BIT_COUNT 64

/*
 * Walk the levels exp_bits, exp_bits - 1, ..., 1 and call flip() once per
 * level; the first level at which flip() yields 1 is returned.
 * Returns 0 when no level produced a 1 (including exp_bits == 0).
 */
static uint64_t sample_exp(uint64_t exp_bits) {
    uint64_t level = exp_bits;

    while (level >= 1) {
        if (flip() == 1) {
            return level;
        }
        level--;
    }

    return 0;
}

/*
 * Number of bits needed to represent `number`, i.e. the one-based position
 * of its highest set bit. Returns 0 for 0, 1 for 1, 8 for 255, 64 for
 * UINT64_MAX.
 */
static uint64_t count_bits(uint64_t number)
{
    // __builtin_clzll() has an undefined result for 0, so that case must be
    // answered without calling it (and 0 cannot be used to obtain the width).
    if (number == 0)
        return 0;

    const uint64_t width = 8 * sizeof(unsigned long long);
    return width - (uint64_t)__builtin_clzll(number);
}

/*
 * Add up the first num_numbers entries of `numbers` and, in the same pass,
 * record the largest entry in *max_val.
 *
 * *max_val is reset to 0 first, so it is 0 for an empty array.
 * Returns the sum (unsigned arithmetic, wraps on overflow).
 */
static uint64_t sum_max_array(uint64_t *numbers, size_t num_numbers, uint64_t *max_val)
{
    uint64_t total = 0;
    size_t idx = 0;

    *max_val = 0;
    while (idx < num_numbers) {
        const uint64_t value = numbers[idx++];

        if (value > *max_val) {
            *max_val = value;
        }
        total += value;
    }

    return total;
}

/*
 * Sum of the first num_numbers entries of `numbers`.
 * Returns 0 for an empty array; unsigned arithmetic, wraps on overflow.
 */
static uint64_t sum_array(uint64_t *numbers, size_t num_numbers)
{
    uint64_t total = 0;
    size_t idx = 0;

    while (idx < num_numbers) {
        total += numbers[idx];
        idx++;
    }

    return total;
}

/*
 * For each index i in [0, N), append the one-based value i + 1 to `target`
 * exactly repeats[i] times, in index order.
 *
 * No bounds are checked: `target` must have room for the sum of all
 * repeats, otherwise memory past it is overwritten.
 */
static void np_repeat_plain(uint64_t *repeats, size_t N, uint64_t *target)
{
    uint64_t *cursor = target;

    for (size_t idx = 0; idx < N; idx++) {
        const uint64_t label = idx + 1; // labels are one-based
        uint64_t written = 0;

        while (written < repeats[idx]) {
            *cursor = label;
            cursor++;
            written++;
        }
    }
}

/*
 * Count, per bit position, how many entries of `numbers` have that bit set.
 *
 * numbers / num_numbers: input array and its length.
 * max_val:               value whose bit length (count_bits) decides how many
 *                        bit positions are examined; higher bits are ignored.
 * result_count:          receives the number of counters in the result.
 *
 * Returns a heap array of *result_count counters (index = bit position)
 * that the caller must free(), or NULL if the allocation failed.
 */
static uint64_t *bitwise_count(uint64_t *numbers, size_t num_numbers, uint64_t max_val, size_t *result_count)
{
    // maxbits
    *result_count = count_bits(max_val);

    uint64_t *result = calloc(*result_count, sizeof(*result));
    if (result == NULL)
        return NULL; // do not write through a failed allocation

    for (size_t n = 0; n < num_numbers; n++) {
        for (size_t j = 0; j < *result_count; j++) {
            if (numbers[n] & ((uint64_t)1 << j))
                result[j]++;
        }
    }
    return result;
}

/*
 * Derive the table parameters r and c from the counts.
 *
 * With b = count_bits(sum of counts) - 1, every candidate e in [1, b] is
 * rejected as soon as, for some bit position `bit` < e, the partial sum
 *     cumsum(bit) = sum over j <= bit of 2^j * (number of counts with bit j set)
 * exceeds 2^(b - e + bit + 1). Candidate 0 is never rejected.
 *
 * counts / counts_size: input counts (a.k.a. probs); their sum must be non-zero.
 * r_out:                receives the largest candidate that was not rejected.
 * c_out:                receives b - r.
 * max_bits:             receives b + 1.
 *
 * Aborts the process if the temporary bit-count array cannot be allocated,
 * because this function has no way to report an error to its caller.
 */
static void compute_r_c_maxBits(uint64_t *counts /* a.k.a. probs */, size_t counts_size, size_t *r_out, size_t *c_out, uint64_t *max_bits)
{
    size_t bit_counts_len; // len(bit_counts)
    uint64_t max_val;
    uint64_t counts_SUM = sum_max_array(counts, counts_size, &max_val);
    assert(counts_SUM != 0);
    uint64_t *bit_counts = bitwise_count(counts, counts_size, max_val, &bit_counts_len);
    if (bit_counts == NULL)
        abort();

    *max_bits = count_bits(counts_SUM); // b + 1
    uint64_t b = *max_bits - 1;

    /* Find R the naive way, optimization needed */
    // Bit e of `rejected` set means e is no candidate any more. b is at most
    // 63, so one 64-bit mask covers every candidate in (0,1,...,b).
    uint64_t rejected = 0;
    uint64_t cumsum = 0;
    for (uint64_t bit = 0; bit <= b; bit++) {
        // bit_counts only covers the bits of the largest count, which can be
        // fewer than the bits of the sum; positions beyond it contribute 0.
        if (bit < bit_counts_len)
            cumsum += ((uint64_t)1 << bit) * bit_counts[bit];

        // Only candidates above the current bit can be rejected here, which
        // also keeps the shift amount b - e + bit + 1 within [1, b].
        for (uint64_t e = bit + 1; e <= b; e++) {
            if (cumsum > ((uint64_t)1 << (b - e + bit + 1)))
                rejected |= (uint64_t)1 << e;
        }
    }

    // Highest surviving candidate wins; 0 is the fallback.
    size_t r = 0;
    for (uint64_t e = b; e > 0; e--) {
        if (!(rejected & ((uint64_t)1 << e))) {
            r = (size_t)e;
            break;
        }
    }

    *r_out = r;
    *c_out = b - r;

    free(bit_counts);
}

/*
 * Decompose each count into per-bit rows and rebalance the rows.
 *
 * The result is a max_bits x counts_size matrix (row = bit position,
 * column = index into counts), built in three steps:
 *   1. cell(i, j) = bit j of counts[i];
 *   2. every row p >= r is folded into row r - 1, weighted by 2^(p - r + 1),
 *      and then cleared;
 *   3. for rows r - 1 down to 1, whatever makes the running row total exceed
 *      2^c is moved one row down, doubled.
 *
 * Requires 1 <= r <= max_bits and c < 64.
 * Returns the matrix (caller must free() it), or NULL if r is out of range
 * or the allocation failed.
 */
static uint64_t *distribute(uint64_t *counts, size_t counts_size, size_t max_bits, size_t r, size_t c)
{
    // r == 0 would wrap "r - 1" and make "p >= r" always true; r > max_bits
    // would address rows that do not exist.
    if (r == 0 || r > max_bits)
        return NULL;

    uint64_t *counts_p2 = malloc(max_bits * counts_size * sizeof(*counts_p2));
    if (counts_p2 == NULL)
        return NULL;
    #define counts_p2_at(x, y) (counts_p2[((y) * counts_size) + (x)])

    for (size_t i = 0; i < counts_size; i++) {
        for (size_t j = 0; j < max_bits; j++) {
            counts_p2_at(i, j) = (counts[i] >> j) & 1;
        }
    }

    for (size_t p = max_bits - 1; p >= r; p--) {
        // 64-bit shift: a plain int "1 <<" overflows once p - r + 1 >= 31.
        const uint64_t weight = (uint64_t)1 << (p - r + 1);
        for (size_t b = 0; b < counts_size; b++) {
            counts_p2_at(b, r - 1) += weight * counts_p2_at(b, p);
            counts_p2_at(b, p) = 0;
        }
    }

    for (size_t p = r - 1; p > 0; p--) {
        const uint64_t row_capacity = 1llu << c;
        if (sum_array(counts_p2 + p * counts_size, counts_size) > row_capacity) {
            uint64_t cumsum = 0;
            for (size_t b = 0; b < counts_size; b++) {
                cumsum += counts_p2_at(b, p);
                if (cumsum > row_capacity) {
                    // cumsum is not reduced after a move, so once the row is
                    // over capacity every later cell is moved down entirely.
                    uint64_t move = min(cumsum - row_capacity, counts_p2_at(b, p));
                    counts_p2_at(b, p)     -= move;
                    counts_p2_at(b, p - 1) += 2 * move;
                }
            }
        }
    }

    #undef counts_p2_at
    return counts_p2;
}

/*
 * Build a sampling context from an array of counts.
 *
 * counts / counts_size: input counts; their sum must be non-zero.
 *
 * The table has r + 1 lines of 2^c entries each. Row 0 of the distributed
 * counts is written starting at line 0, row p >= 1 starting at line p + 1.
 * Entries are one-based indices into counts; slots that are never written
 * stay 0.
 * NOTE(review): presumably callers of cLUT_sampling() treat 0 as "no
 * outcome" -- confirm.
 *
 * Returns a context to be released with free_cLUT(), or NULL if an
 * allocation failed or the counts could not be distributed.
 */
clut_context_t *build_cLUT(uint64_t *counts, size_t counts_size)
{
    clut_context_t *ctx = malloc(sizeof(*ctx));
    if (ctx == NULL)
        return NULL;

    uint64_t max_bits;
    compute_r_c_maxBits(counts, counts_size, &ctx->r, &ctx->c, &max_bits);

    uint64_t *counts_p2 = distribute(counts, counts_size, max_bits, ctx->r, ctx->c);
    if (counts_p2 == NULL) {
        free(ctx);
        return NULL;
    }

    ctx->cLUT_linelen = (1llu << ctx->c);
    ctx->cLUT_size = (ctx->r + 1) * ctx->cLUT_linelen;
    // Zero-initialise: not every slot is written below, and cLUT_sampling()
    // may read any of them.
    ctx->cLUT = calloc(ctx->cLUT_size, sizeof(*ctx->cLUT));
    if (ctx->cLUT == NULL) {
        free(counts_p2);
        free(ctx);
        return NULL;
    }

    for (size_t p = 0; p < ctx->r; p++) {
        // Row 0 starts at line 0; every other row p starts at line p + 1.
        size_t line = (p == 0) ? 0 : p + 1;
        np_repeat_plain(&counts_p2[p * counts_size], counts_size,
                        &ctx->cLUT[line * ctx->cLUT_linelen]);
    }

    free(counts_p2);
    return ctx;
}

/*
 * Draw one entry from the table in ctx.
 *
 * The line is chosen with sample_exp(ctx->r), the position inside the line
 * with flip_n(ctx->c); the entry stored at line * cLUT_linelen + position
 * is returned.
 */
uint64_t cLUT_sampling(clut_context_t *ctx)
{
    // Keep the two draws as separate statements so they happen in this order.
    const int line = sample_exp(ctx->r);
    const int position = flip_n(ctx->c);

    const size_t offset = (line * ctx->cLUT_linelen) + position;
    return ctx->cLUT[offset];
}

/*
 * Release a context created by build_cLUT(), including its table.
 * Passing NULL is allowed and does nothing, mirroring free().
 */
void free_cLUT(clut_context_t *ctx)
{
    if (ctx == NULL)
        return;

    free(ctx->cLUT);
    free(ctx);
}
