#include "nLUT.h"
#include "cLUT.h"

#include <limits.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <flip.h>

/*
 * Return the number of significant bits in `number`: the position of its
 * highest set bit plus one (count_bits(1) == 1, count_bits(8) == 4).
 * Returns 0 when `number` is 0.
 *
 * __builtin_clzll() has an undefined result for a zero argument, so the
 * word width is computed from the type instead of via __builtin_clzll(0),
 * and a zero input is handled before the builtin is called.
 */
static uint64_t count_bits(uint64_t number)
{
    if (number == 0) {
        return 0;
    }
    const uint64_t width = (uint64_t)(sizeof(unsigned long long) * CHAR_BIT);
    return width - (uint64_t)__builtin_clzll((unsigned long long)number);
}

/*
 * Build a sampling lookup table ("nLUT") from per-symbol repeat counts.
 *
 * Entry i of `repeats` (0-based) causes the value i + 1 (one-indexed symbol)
 * to be written repeats[i] times, consecutively, into ctx->nLUT. Drawing a
 * slot of the table uniformly at random therefore yields symbol i + 1 with
 * probability repeats[i] / ctx->nLUT_size.
 *
 * Sampler selection:
 *   - table length 2^k  -> ctx->flip_func = flip_n,  ctx->n_or_nbits = k
 *   - any other length  -> ctx->flip_func = uniform, ctx->n_or_nbits = length
 * (presumably both return a slot index in [0, length) - TODO confirm
 * against flip.h).
 *
 * @param repeats  Array of N repeat counts.
 * @param N        Number of entries in `repeats`.
 * @return         Newly allocated context; release it with free_nLUT().
 *
 * On allocation failure, on a total length that does not fit in memory
 * sizes (size_t), or when the table would be empty (N == 0 or all counts
 * zero), a message is printed to stderr and the process exits with status 2.
 */
nlut_context_t *build_nLUT(uint64_t *repeats, size_t N)
{
    nlut_context_t *ctx = malloc(sizeof(*ctx));
    if (ctx == NULL) {
        fprintf(stderr, "malloc failed on %zu bytes\n", sizeof(*ctx));
        exit(2);
    }

    /* Total table length; refuse sums that would wrap around size_t. */
    size_t total = 0;
    for (size_t i = 0; i < N; i++) {
        if (repeats[i] > SIZE_MAX - total) {
            fprintf(stderr, "nLUT size overflows size_t\n");
            exit(2);
        }
        total += (size_t)repeats[i];
    }

    /*
     * An empty table cannot be sampled, and 0 would otherwise be classified
     * as a power of two, making count_bits(0) - 1 wrap around.
     */
    if (total == 0) {
        fprintf(stderr, "nLUT is empty: all repeat counts are zero\n");
        exit(2);
    }
    if (total > SIZE_MAX / sizeof(*ctx->nLUT)) {
        fprintf(stderr, "nLUT size overflows size_t\n");
        exit(2);
    }

    ctx->nLUT_size = total;
    const size_t bytes = total * sizeof(*ctx->nLUT);
    ctx->nLUT = malloc(bytes);
    if (ctx->nLUT == NULL) {
        fprintf(stderr, "malloc failed on %zu bytes\n", bytes);
        exit(2);
    }

    size_t insert_pos = 0;
    for (size_t i = 0; i < N; i++) {
        for (uint64_t j = 0; j < repeats[i]; j++) {
            ctx->nLUT[insert_pos++] = i + 1; /* symbols are one-indexed */
        }
    }

    /* total > 0 here, so this test is a genuine power-of-two check. */
    const bool is_power_of_two = (total & (total - 1)) == 0;
    if (is_power_of_two) {
        /* total == 2^k: store k = log2(total), the number of index bits. */
        ctx->flip_func = flip_n;
        ctx->n_or_nbits = count_bits((uint64_t)total) - 1;
    } else {
        ctx->flip_func = uniform;
        ctx->n_or_nbits = total;
    }

    return ctx;
}

/*
 * Draw one symbol from a table built by build_nLUT().
 *
 * The context's sampler is invoked with its stored argument to choose a
 * slot. The value held in that slot, a one-based symbol index as written by
 * build_nLUT(), is returned. The sampler is assumed to return a slot in
 * [0, ctx->nLUT_size); no bounds check is done here (TODO confirm against
 * flip.h).
 */
uint64_t nLUT_sampling(nlut_context_t *ctx)
{
    const size_t slot = ctx->flip_func(ctx->n_or_nbits);
    return ctx->nLUT[slot];
}

/*
 * Release a context returned by build_nLUT(): first its lookup table, then
 * the context itself. Passing NULL is a no-op, mirroring free().
 */
void free_nLUT(nlut_context_t *ctx)
{
    if (ctx == NULL) {
        return;
    }
    free(ctx->nLUT);
    free(ctx);
}
