/*
  Name:     aldr.c
  Purpose:  Fast sampling of random integers.
  Author:   CMU Probabilistic Computing Systems Lab
            And An Anonymous Lab
  Copyright (C) 2025 CMU Probabilistic Computing Systems Lab, All Rights Reserved.

  Released under Apache 2.0; refer to LICENSE.txt
*/

#include <stdint.h>
#include <stdlib.h>

#include "aldr.h"
#include "flip.h"

/*
 * Release a table returned by the flat preprocessing functions: both arrays
 * it owns and then the struct itself.
 *
 * Passing NULL is a no-op, mirroring free(), so callers can release the
 * result of a failed allocation without a separate check.
 */
void free_aldr_flat_s (struct aldr_flat_s *x) {
    if (x == NULL) {
        return;
    }
    free(x->breadths);
    free(x->leaves_flat);
    free(x);
}

/*
 * Release an array_s: the element buffer it owns and then the struct itself.
 *
 * Passing NULL is a no-op, mirroring free().
 */
void free_array_s (struct array_s *x) {
    if (x == NULL) {
        return;
    }
    free(x->a);
    free(x);
}

#ifndef __SIZEOF_INT128__
#    error "I need a native 128bit type"
#endif

/*
 * Count the set bits of a 128-bit unsigned value.
 *
 * The type-generic builtin is selected only when the compiler reports itself
 * as GCC 14 or newer; every other compiler takes the fallback, which sums the
 * population counts of the low and high 64-bit halves.
 */
static int popcount_u128(__uint128_t x) {
#if (defined(__GNUC__) && __GNUC__ >= 14)
    return __builtin_popcountg(x);
#else
    const uint64_t low_half  = (uint64_t)x;
    const uint64_t high_half = (uint64_t)(x >> 64);
    return __builtin_popcountll(low_half) + __builtin_popcountll(high_half);
#endif
}

/*
 * Build the flat (per-level breadths + concatenated leaves) sampling table
 * for the integer weights a[0..n-1].
 *
 * Let m be the sum of the weights and k = ceil(log2(m)).  The table depth is
 * K = k * kmul.  Writing 2^K = c*m + r, outcome i owns the set bits of
 * c*a[i] and the reject outcome owns the set bits of r: a set bit at
 * position K-j becomes one leaf at depth j.  Leaves are stored level by
 * level in leaves_flat (0 = reject, i+1 = outcome i) and breadths[j] holds
 * the number of leaves at depth j.
 *
 * Returns a heap-allocated table (release it with free_aldr_flat_s), or NULL
 * when the input cannot be encoded: the weights sum to zero or overflow
 * 64 bits, K would exceed 127 (the 128-bit arithmetic below would shift out
 * of range), 2^K < m (no outcome would ever be reachable), or an allocation
 * fails.
 */
struct aldr_flat_s *preprocess_aldr_flat_k(uint64_t* a, size_t n, uint64_t kmul) {
    uint64_t m = 0;
    for (size_t i = 0; i < n; ++i) {
        if (a[i] > UINT64_MAX - m) {
            return NULL;                 // sum would wrap around
        }
        m += a[i];
    }
    if (m == 0) {
        return NULL;                     // clz(0) and division by m are undefined
    }

    // k = ceil(log2(m)): bit length of m, minus one when m is a power of two.
    // The constant 64 replaces __builtin_clzll(0), whose result is undefined.
    uint64_t k = 64u - (uint64_t)__builtin_clzll(m) - (1 == __builtin_popcountll(m));
    if (kmul != 0 && k > 127u / kmul) {
        return NULL;                     // 1 << K must stay inside 128 bits
    }
    uint64_t K = k * kmul;               // depth

    __uint128_t two_pow_K = (__uint128_t)1 << K;
    __uint128_t c = two_pow_K / m;
    __uint128_t r = two_pow_K % m;
    if (c == 0) {
        return NULL;                     // only reject leaves: sampling would never finish
    }

    // One leaf per set bit of r and of every c*a[i].
    uint64_t num_leaves = popcount_u128(r);
    for (size_t i = 0; i < n; ++i) {
        num_leaves += popcount_u128(c * (__uint128_t)a[i]);
    }

    uint64_t *breadths = calloc(K + 1, sizeof(*breadths));
    uint64_t *leaves_flat = calloc(num_leaves, sizeof(*leaves_flat));
    struct aldr_flat_s *s = malloc(sizeof(*s));
    if (breadths == NULL || leaves_flat == NULL || s == NULL) {
        free(breadths);
        free(leaves_flat);
        free(s);
        return NULL;
    }

    // Emit leaves level by level; within a level the reject leaf comes first,
    // followed by the outcomes in index order.
    uint64_t location = 0;
    for (uint64_t j = 0; j <= K; j++) {
        __uint128_t bit = (__uint128_t)1 << (K - j);
        if (r & bit) {
            leaves_flat[location] = 0;
            ++breadths[j];
            ++location;
        }
        for (uint64_t i = 0; i < n; ++i) {
            __uint128_t Qi = c * (__uint128_t)a[i];
            if (Qi & bit) {
                leaves_flat[location] = i + 1;
                ++breadths[j];
                ++location;
            }
        }
    }

    s->length_breadths = K + 1;
    s->length_leaves_flat = num_leaves;
    s->breadths = breadths;
    s->leaves_flat = leaves_flat;

    return s;
}

/*
 * Build the flat table for weights a[0..n-1] with a depth multiplier of 2,
 * by forwarding to preprocess_aldr_flat_k.
 */
struct aldr_flat_s *preprocess_aldr_flat(uint64_t* a, size_t n) {
    const uint64_t depth_multiplier = 2;
    return preprocess_aldr_flat_k(a, n, depth_multiplier);
}

/*
 * Build the flat table for weights a[0..n-1] with a depth multiplier of 1,
 * by forwarding to preprocess_aldr_flat_k.
 */
struct aldr_flat_s *preprocess_fldr_flat(uint64_t* a, size_t n) {
    const uint64_t depth_multiplier = 1;
    return preprocess_aldr_flat_k(a, n, depth_multiplier);
}

/*
 * Build the pointer-encoded sampling tree for the integer weights a[0..n-1].
 *
 * Let m be the sum of the weights, k = ceil(log2(m)) and K = k * kmul.
 * Writing 2^K = c*m + r, a set bit at position K-j of c*a[i] (or of r)
 * becomes a leaf at depth j for outcome i (or for the reject outcome).
 *
 * Layout of the returned array (one entry per tree node, 2*leaves - 1 total):
 *   - enc[0] is the root and holds the index of its first child;
 *   - an inner node holds the index of the first of its two adjacent children;
 *   - an outcome leaf holds ~(i+1);
 *   - a reject leaf holds 1, i.e. it points back at the root's children so
 *     that sampling restarts.
 *
 * Returns a heap-allocated array_s (release it with free_array_s), or NULL
 * when the input cannot be encoded: the weights sum to zero or overflow
 * 64 bits, K would exceed 127, 2^K < m (no outcome would ever be reachable),
 * or an allocation fails.
 */
struct array_s *preprocess_aldr_enc_k(uint64_t* a, size_t n, uint64_t kmul) {
    uint64_t m = 0;
    for (size_t i = 0; i < n; ++i) {
        if (a[i] > UINT64_MAX - m) {
            return NULL;                 // sum would wrap around
        }
        m += a[i];
    }
    if (m == 0) {
        return NULL;                     // clz(0) and division by m are undefined
    }

    // k = ceil(log2(m)): bit length of m, minus one when m is a power of two.
    // The constant 64 replaces __builtin_clzll(0), whose result is undefined.
    uint64_t k = 64u - (uint64_t)__builtin_clzll(m) - (1 == __builtin_popcountll(m));
    if (kmul != 0 && k > 127u / kmul) {
        return NULL;                     // 1 << K must stay inside 128 bits
    }
    uint64_t K = k * kmul;

    // 128-bit arithmetic, matching the flat variant: c*a[i] can reach 2^K,
    // which does not fit in 64 bits once K >= 64.
    __uint128_t two_pow_K = (__uint128_t)1 << K;
    __uint128_t c = two_pow_K / m;
    __uint128_t r = two_pow_K % m;
    if (c == 0) {
        return NULL;                     // only reject leaves: sampling would never finish
    }

    // flattened but 50% sparse encoding
    uint64_t num_leaves = popcount_u128(r);
    for (size_t i = 0; i < n; ++i) {
        num_leaves += popcount_u128(c * (__uint128_t)a[i]);
    }

    uint64_t length = (num_leaves << 1) - 1;
    uint64_t *enc = calloc(length, sizeof(*enc));
    struct array_s *s = malloc(sizeof(*s));
    if (enc == NULL || s == NULL) {
        free(enc);
        free(s);
        return NULL;
    }

    uint64_t prev_length = 1;            // end index of the level being filled
    uint64_t location = 0;
    for (uint64_t j = 0; j <= K; j++) {
        __uint128_t bit = (__uint128_t)1 << (K - j);
        uint64_t next_length = prev_length;
        if (r & bit) {
            // reject outcome: flip and go to child of root.
            // Exactly one entry per reject bit; the array is sized for one
            // entry per leaf, so writing more would overrun it.
            enc[location++] = 1;
        }
        for (size_t i = 0; i < n; ++i) {
            if ((c * (__uint128_t)a[i]) & bit) {
                enc[location++] = ~(uint64_t)(i + 1);
            }
        }
        // Remaining slots on this level are inner nodes; each reserves two
        // consecutive slots on the next level.
        for ( ; location < prev_length; ++location) {
            enc[location] = next_length;
            next_length += 2;
        }
        prev_length = next_length;
    }

    s->length = length;
    s->a = enc;
    return s;
}

/*
 * Build the pointer-encoded tree for weights a[0..n-1] with a depth
 * multiplier of 2, by forwarding to preprocess_aldr_enc_k.
 */
struct array_s *preprocess_aldr_enc(uint64_t* a, size_t n) {
    const uint64_t depth_multiplier = 2;
    return preprocess_aldr_enc_k(a, n, depth_multiplier);
}

/*
 * Build the pointer-encoded tree for weights a[0..n-1] with a depth
 * multiplier of 1, by forwarding to preprocess_aldr_enc_k.
 */
struct array_s *preprocess_fldr_enc(uint64_t* a, size_t n) {
    const uint64_t depth_multiplier = 1;
    return preprocess_aldr_enc_k(a, n, depth_multiplier);
}

/*
 * Draw one sample from a flat table.
 *
 * Walks down the levels keeping the position of the current node within its
 * level.  Positions below breadths[level] are leaves; otherwise the node is
 * an inner node and one flip() selects which of its two children to visit.
 * A leaf value of 0 means reject, in which case the walk restarts from the
 * top; any other leaf value v yields the result v - 1.
 */
uint64_t sample_aldr_flat(struct aldr_flat_s* f) {
    for (;;) {
        uint64_t level = 0;
        uint64_t level_start = 0;    // index in leaves_flat where this level begins
        uint64_t pos = 0;            // position of the current node within its level

        // Descend while the current node is an inner node.
        while (pos >= f->breadths[level]) {
            const uint64_t leaf_count = f->breadths[level];
            level_start += leaf_count;
            pos = ((pos - leaf_count) << 1) | flip();
            ++level;
        }

        const uint64_t label = f->leaves_flat[level_start + pos];
        if (label) {
            return label - 1;
        }
        // label == 0: rejected, start over from the root.
    }
}

/*
 * Draw one sample from a pointer-encoded tree.
 *
 * Starting from the entry at index 0, each entry is either a leaf (top bit
 * set, holding ~(i+1) for outcome i) or the index of the first of two
 * adjacent children, one of which is picked with flip().
 *
 * Entries are decoded at their full 64-bit width: narrowing them to int
 * would misread any child index or outcome label at or above 2^31.
 */
uint64_t sample_aldr_enc(struct array_s* x) {
    uint64_t node = x->a[0];
    for (;;) {
        if (node >> 63) {
            // note that this implementation uses all pointers
            // instead of "n" with special reject logic
            return (~node) - 1;
        }
        node = x->a[node + flip()];
    }
}

/*
 * Size in bytes of the two arrays held by a flat table.
 *
 * The length fields themselves are deliberately left out of the total: they
 * are not needed by the sampler and exist only to make this byte count easy.
 */
uint64_t bytes_sample_aldr_flat(struct aldr_flat_s *x) {
    const uint64_t breadths_bytes = x->length_breadths * sizeof(x->breadths[0]);
    const uint64_t leaves_bytes = x->length_leaves_flat * sizeof(x->leaves_flat[0]);
    return breadths_bytes + leaves_bytes;
}

/*
 * Size in bytes of an array_s: its element buffer plus the length field.
 */
uint64_t bytes_array(struct array_s *x) {
    const uint64_t elements_bytes = x->length * sizeof(x->a[0]);
    return elements_bytes + sizeof(x->length);
}
