import os
from itertools import product

import numpy as np
from scipy.stats import entropy
from tqdm import tqdm

def is_power_of_2(n: int) -> bool:
    """Return True when ``n`` is a positive power of two (1, 2, 4, 8, ...)."""
    # A power of two has exactly one set bit, so clearing the lowest set bit
    # with ``n & (n - 1)`` leaves zero. Zero passes that bit test as well and
    # is therefore excluded explicitly.
    return (n & (n - 1)) == 0 and n != 0


def cascade_rounding(values: list | np.ndarray[float]) -> list[int]:
    """https://stackoverflow.com/questions/792460/how-to-round-floats-to-integers-while-preserving-their-sum/792476#792476"""
    float_total = 0
    int_total = 0

    out = []
    for v in values:
        float_total += v
        v_rounded = round(float_total) - int_total
        int_total += v_rounded
        out.append(v_rounded)

    assert np.allclose(float_total, int_total)
    return out


def generate(n: int, alpha: float, z: int) -> np.ndarray[float]:
    return np.random.default_rng().dirichlet(np.zeros(n) + alpha) * z


def write_distribution(
    folder: str, probs: list[int], alpha: float, entropy: float, n: int, z: int
) -> None:
    """Write one distribution to ``{folder}/{n}_{z}_{alpha}_{entropy:.5f}.dist``.

    The file consists of four newline-terminated lines:

    1. ``z``
    2. ``n`` followed by the space-separated entries of ``probs``
    3. ``entropy`` (full precision)
    4. ``alpha``

    Args:
        folder: Output directory; created (including parents) if missing.
        probs: Integer frequencies to store.
        alpha: Concentration parameter the distribution was drawn with.
        entropy: Entropy of the distribution. Inside this function the name
            shadows the module-level ``scipy.stats.entropy`` import.
        n: Number of entries, written in front of ``probs``.
        z: Value written on the first line (the caller passes the total that
            the frequencies were scaled to).
    """
    if folder:
        # Avoid failing with FileNotFoundError after the work is already done.
        os.makedirs(folder, exist_ok=True)

    frequencies = " ".join(map(str, probs))
    path = f"{folder}/{n}_{z}_{alpha}_{entropy:.5f}.dist"
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"{z}\n{n} {frequencies}\n{entropy}\n{alpha}\n")


def main() -> None:
    """Generate rounded Dirichlet distributions and write them to disk.

    For every combination of size ``n`` and total ``z`` and for every
    concentration ``alpha``, a Dirichlet sample is drawn, rounded to integers
    that keep the total, and stored under ``distributions/`` together with
    its base-2 entropy.

    Each distribution is written as soon as it is produced, so only one
    distribution is held in memory at a time.
    """
    # NOTE(review): exactly one sample is drawn per alpha. A previously unused
    # constant N_DISTS = 30 suggests more samples may have been intended;
    # confirm before raising this.
    samples_per_alpha = 1

    # Ten sizes spaced logarithmically between 1e6 and 1e9 (truncated to int).
    sizes = np.logspace(start=6, stop=9, num=10, base=10).astype(int)
    totals = [1 << 16]

    # Five alphas geometrically spaced in [1e-5, 1] plus five powers of two in
    # [1, 256]. alpha == 1 is the endpoint of both ranges and so occurs twice,
    # yielding two independent samples for it.
    alphas = np.append(np.geomspace(1e-5, 1, 5), np.logspace(0, 8, base=2, num=5))

    combos = product(sizes, totals)
    for n, z in tqdm(combos, desc="Generating", total=len(sizes) * len(totals)):
        for alpha in alphas:
            for _ in range(samples_per_alpha):
                p = cascade_rounding(generate(n, alpha, z))
                # Rounding must preserve the total, which is a power of two.
                assert is_power_of_2(sum(p))
                h = entropy(p, base=2)
                write_distribution("distributions", p, alpha, h, n, z)


if __name__ == "__main__":
    main()
