/*
  Name:     flip.c
  Purpose:  Generate pseudo-random bits and samples derived from them (uniform integers, Bernoulli draws).
  Author:   CMU Probabilistic Computing Systems Lab
            And An Anonymous Lab
  Copyright (C) 2025 CMU Probabilistic Computing Systems Lab, All Rights Reserved.

  Released under Apache 2.0; refer to LICENSE.txt
*/

#include "flip.h"

#include <errno.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#if __linux__
#  include <sys/random.h>
#else

/*
 * Stand-in for getrandom(2) on systems without <sys/random.h>.
 * Reads up to buflen bytes from /dev/urandom into buf and returns the byte
 * count reported by fread(). The device is opened on the first call and the
 * stream is kept open and reused for the rest of the process (never closed).
 * If the device cannot be opened, a diagnostic is printed to stderr and the
 * process exits with status 1. flags exists only for signature compatibility
 * with getrandom() and is ignored.
 */
static ssize_t get_dev_urandom(void *buf, size_t buflen, unsigned int flags /* ignored */)
{
    static FILE *urandom_stream;
    (void)flags;

    if (!urandom_stream) {
        urandom_stream = fopen("/dev/urandom", "rb");
        if (!urandom_stream) {
            fprintf(stderr, "%s:%d: Cannot get a source of randomness on your system. Aborting.\n", __FILE__, __LINE__);
            exit(1);
        }
    }
    return fread(buf, 1, buflen, urandom_stream);
}

// Route every getrandom() call in this file to the /dev/urandom fallback.
#define getrandom get_dev_urandom

#endif

/*
 * flip_k: number of random bits supplied by each refill of flip_word.
 * With rand() as the source it defaults to the bit length of RAND_MAX. That
 * is only correct when RAND_MAX is a Mersenne number (2^k - 1), so that every
 * rand() result can be treated as k random bits; the check below enforces it.
 * Setting flip_k to 32 makes check_refill() use getrandom() instead of rand().
 */
#if ((RAND_MAX + 1ULL) & RAND_MAX) != 0
#  error "flip.c assumes RAND_MAX is a Mersenne number (2^k - 1)"
#endif
// __builtin_clzll counts leading zeros within an unsigned long long, so the
// width it is subtracted from must be that of unsigned long long, not size_t
// (the two differ on 32-bit targets).
size_t flip_k = (CHAR_BIT * sizeof(unsigned long long)) - __builtin_clzll(RAND_MAX);
flip_t flip_word = 0;        // buffered random bits; bit positions [0, flip_pos) are unconsumed
size_t flip_pos = 0;         // count of unconsumed bits in flip_word (0 means "refill needed")
uint64_t NUM_RNG_CALLS = 0;  // number of refills performed by check_refill()

/*
 * Copy exactly len bytes of system randomness into buf via getrandom().
 * Retries after short reads and after EINTR. Any other failure, or a call
 * that returns no data, is fatal: print a diagnostic and exit(1), the same
 * policy the /dev/urandom fallback uses when the device cannot be opened.
 */
static void fill_random_bytes(void *buf, size_t len)
{
    unsigned char *dst = buf;

    while (len > 0) {
        ssize_t got = getrandom(dst, len, 0);
        if (got < 0 && errno == EINTR) {
            continue;  // interrupted by a signal before any data: just retry
        }
        if (got <= 0) {
            fprintf(stderr, "%s:%d: Cannot get a source of randomness on your system. Aborting.\n", __FILE__, __LINE__);
            exit(1);
        }
        dst += got;
        len -= (size_t)got;
    }
}

/*
 * Refill flip_word once every buffered bit has been consumed (flip_pos == 0),
 * bumping NUM_RNG_CALLS and resetting flip_pos to flip_k. Does nothing while
 * unconsumed bits remain.
 */
void check_refill(void) {
    if (flip_pos == 0) {
        ++NUM_RNG_CALLS;
        // flip_k == 32 is the switch for using the system RNG (getrandom)
        // instead of rand(). The whole word is filled, but callers only read
        // the low flip_k bit positions.
        if (flip_k == 32) {
            fill_random_bytes(&flip_word, sizeof(flip_word));
        } else {
            // rand() returns a value in [0, RAND_MAX]; with the default
            // flip_k that is exactly flip_k bits.
            flip_word = rand();
        }
        flip_pos = flip_k;
    }
}

/*
 * Return the next random bit (0 or 1). Refills the buffer first if it is
 * empty, then consumes the highest unconsumed bit position of flip_word.
 */
flip_t flip(void) {
    check_refill();
    flip_pos -= 1;
    return (flip_word >> flip_pos) & 1;
}

/*
 * Draw n random bits and return them packed into one value, with the first
 * bit drawn ending up most significant. Bits are pulled from the buffered word
 * in chunks of at most flip_pos bits, refilling through check_refill()
 * whenever the buffer runs dry. n == 0 returns 0 without touching the buffer.
 * NOTE(review): n is presumably meant to be at most the bit width of flip_t;
 * with a larger n the earliest bits would be shifted out of the result.
 */
flip_t flip_n(flip_t n) {
    const size_t width = sizeof(flip_t) * CHAR_BIT;
    flip_t acc = 0;

    for (flip_t left = n; left > 0; ) {
        check_refill();
        flip_t take = min(left, flip_pos);
        flip_pos -= take;
        // Keep only the `take` bits sitting just below the old read position.
        flip_t chunk = (flip_word >> flip_pos) & (FLIP_T_MAX >> (width - take));
        acc = (acc << take) | chunk;
        left -= take;
    }
    return acc;
}

/*
 * Return a uniformly distributed integer in [0, n).
 *
 * Follows the structure of Lumbroso's Fast Dice Roller: x is uniform over
 * [0, bound). Whenever bound >= n, x is accepted if x < n; otherwise the
 * leftover range [n, bound) is kept (bound -= n, x -= n) and extended one
 * random bit at a time. As a shortcut, the first ceil(log2(n)) bits are drawn
 * with a single flip_n() call.
 *
 * n <= 1 returns 0 without consuming any bits: [0, 1) contains only 0, and
 * n == 0 (an empty range) is treated as a caller error.
 * NOTE(review): assumes n <= 2^(W-1), where W is the bit width of flip_t;
 * larger n would shift 1 by W bits. Also assumes flip_t is at most 64 bits
 * wide, the width __builtin_clzll handles.
 */
flip_t uniform(flip_t n) {
    if (n <= 1) {
        // Also avoids __builtin_clzll(0), whose result is undefined.
        return 0;
    }
    // Bit length of n - 1, i.e. ceil(log2(n)) for n >= 2. __builtin_clzll
    // counts within an unsigned long long, so subtract from that width.
    flip_t num_bits_presample = (CHAR_BIT * sizeof(unsigned long long)) - __builtin_clzll(n - 1);
    flip_t bound = (flip_t)1 << num_bits_presample;
    flip_t x = flip_n(num_bits_presample);
    for (;;) {
        if (bound >= n) {
            if (x < n) {
                return x;
            }
            // Reject, but recycle the leftover uniform range [0, bound - n).
            bound -= n;
            x -= n;
        }
        bound <<= 1;
        x = (x << 1) | flip();
    }
}

/*
 * Return 1 with probability numer/denom, otherwise 0.
 *
 * Walks the binary expansion of numer/denom one digit at a time. After each
 * digit a fair coin decides whether to stop and return that digit, so digit i
 * is returned with probability 2^-i and P(1) = numer/denom. When the remaining
 * fraction is exactly 1/2, a single coin flip finishes the draw.
 *
 * Edge cases: numer == 0 returns 0. numer >= denom (including denom == 0 with
 * numer > 0) returns 1 without consuming bits.
 */
flip_t bernoulli(flip_t numer, flip_t denom) {
    if (numer == 0) {
        return 0;
    }
    if (numer >= denom) {
        return 1;
    }
    // Invariant: 0 < numer < denom. Compare 2*numer against denom as
    // numer against (denom - numer) so that doubling can never overflow,
    // even when denom exceeds half the range of flip_t.
    for (;;) {
        flip_t rest = denom - numer;
        if (numer == rest) {
            // 2*numer == denom: remaining fraction is exactly 1/2.
            return flip();
        }
        flip_t bit = numer > rest;  // next binary digit of numer/denom
        if (bit) {
            numer -= rest;          // numer = 2*numer - denom, still in (0, denom)
        } else {
            numer += numer;         // 2*numer < denom, so no overflow
        }
        if (flip()) {
            return bit;
        }
    }
}
