#!/usr/bin/env python

# THIS FILE IS NOT USED FOR OUR SUBMISSION BUT SERVES AS REFERENCE
# OF THE C CODE IN CASE YOU ARE NOT TOO FAMILIAR WITH THE C LANGUAGE

import numpy as np
import numpy.typing as npt


def l2(x):
	"""Format `x` with thousands separators next to its base-2 logarithm.

	The dtype is appended in parentheses when `x` is an ndarray.
	"""
	text = f'{x:,}=2^{np.log2(x)}'
	if isinstance(x, np.ndarray):
		text += f'({x.dtype})'
	return text


def log2up(x):
	"""Return log2(x) rounded up to the next whole number (as a float)."""
	return np.ceil(np.log2(x))


def log2down(x):
	"""Return log2(x) rounded down to the previous whole number (as a float)."""
	return np.floor(np.log2(x))


def power2(x):
	"""Scale column j of the 2-d array `x` by 2**j."""
	column_weights = 1 << np.arange(x.shape[1], dtype=np.uint64)
	return x * column_weights


def bitwise_counts(nums):
	"""Count, per bit position, how many entries of `nums` have that bit set.

	Args:
		nums: Non-empty array or sequence of non-negative integers.

	Returns:
		A list of Python ints. Element i is the number of entries whose bit i
		(least significant bit first) is 1. The list is as long as the bit
		length of the largest entry, so it is empty when every entry is zero.

	Raises:
		AssertionError: If `nums` is empty.
	"""
	nums = np.asarray(nums)
	assert nums.size > 0, 'nums must not be empty'
	# Find the maximum number of bits needed
	max_bits = int(nums.max()).bit_length()
	# Shift the whole array at once: this avoids a Python-level loop over the
	# elements and also sidesteps scalar `np.uint64 >> int`, which older NumPy
	# releases reject.
	return [int(((nums >> bit) & 1).sum()) for bit in range(max_bits)]


def compute_size(probs: np.ndarray, precision: int) -> tuple[int, int, int, np.ndarray]:
	"""Pick the table split (r, c) and a LUT size for the given counts.

	NOTE(review): despite its name, `probs` is used as an array of integer
	counts on the active code path: it is handed unchanged to bitwise_counts()
	and np.gcd.reduce() and is returned as `counts`. The scale/floor/remainder
	computation at the top only feeds the debug print and two sanity asserts;
	the rounding step that would consume it is disabled (see the bare string
	literal further down).

	Args:
		probs: Array with one entry per value. Presumably non-negative integer
			counts whose sum is a power of two -- TODO confirm against callers.
		precision: Exponent used for `total = 1 << precision`; it only
			influences the diagnostic part described above.

	Returns:
		A tuple (r, c, lut_size, counts):
		r -- the largest candidate that survives the per-bit filter below.
		c -- N - r, with N = int(log2(sum(counts))).
		lut_size -- 2**N integer-divided by the gcd of all counts.
		counts -- the `probs` argument itself (same object, not a copy).

	Raises:
		ValueError: From max() when no candidate for r survives the filter.
	"""
	# assert np.allclose(probs.sum(), np.ones(1)), probs.sum()
	total = 1 << precision	# 2 ** precision
	raw_counts = probs * total
	floored = np.floor(raw_counts)
	remainder = raw_counts - floored

	# Debug output only.
	print(f'{l2(sum(raw_counts))=}; {sum(remainder)=}')

	# Distribute the remaining counts to minimize error
	# (deficit, floored and idxs are only consumed by the disabled code below)
	deficit = total - np.sum(floored)
	assert np.all(np.mod(deficit, 1) == 0)
	assert np.all(np.mod(floored, 1) == 0)
	deficit = deficit.astype(np.uint64)
	floored = floored.astype(np.uint64)
	# Get indices of the largest remainders
	idxs = np.argsort(-remainder)
	# Disabled code kept as a bare string literal (a no-op at runtime): it would
	# hand out the deficit to the entries with the largest remainders and then
	# halve all counts for as long as every one of them is even.
	"""
	for i in range(deficit):
		floored[idxs[i % len(idxs)]] += 1

	while np.all((floored & 1) == 0):
		floored = floored >> 1
	counts = floored"""
	# Active path: the input is taken as the final counts, unmodified.
	counts = probs
	
	# bit_counts[b] = number of counts that have bit b set
	bit_counts = bitwise_counts(counts)

	# int() truncates, so N is floor(log2(sum)) for sums >= 1
	N = int(np.log2(sum(counts)))
	# assert 2**N == sum(counts), f'{l2(sum(counts))=} != {N=}'

	# Candidates for r: 1 .. number of bit positions in use
	r_candidates = list(range(1, len(bit_counts)+1))
	cumsum = 0

	for bit, bit_count in enumerate(bit_counts):
		# cumsum is now the total contributed by bit positions 0..bit
		cumsum += 2**bit * bit_count
		# Keep candidate e if e <= bit, or if that partial total does not exceed
		# 2**(N - e + 1 + bit). The lambda reads cumsum/bit when called, which
		# happens immediately in the list(filter(...)) on the next line.
		condition = lambda e: (cumsum <= 2**(N - e + 1 + bit)) or (e <= bit) 
		r_candidates = list(filter(condition, r_candidates))
	
	# Only reported by the caller; not used to derive r or c
	lut_size = int(2**N) // int(np.gcd.reduce(counts))
	
	r = max(r_candidates)
	c = N - r

	return r, c, lut_size, counts


def split_in_powers_of_two(counts: npt.NDArray[np.uint64], m: int) -> npt.NDArray[np.uint64]:
	"""Expand every count into its `m` lowest binary digits.

	Args:
		counts: 1-d array of non-negative integers. Any integer dtype is
			accepted; the values are converted to uint64 before masking.
		m: Number of bit positions to extract. It must cover the highest set
			bit of every count.

	Returns:
		A uint64 array of shape (len(counts), m). Row i is the binary expansion
		of counts[i], least significant bit first, so each entry is 0 or 1.

	Raises:
		ValueError: If `counts` contains a negative value.
		AssertionError: If `m` is too small to represent every count.
	"""
	counts = np.asarray(counts)
	if counts.dtype.kind != 'u':
		if np.any(counts < 0):
			raise ValueError('counts must be non-negative')
		# Signed & unsigned 64-bit integers have no common integer type in
		# NumPy, so the bitwise AND below would raise without this cast.
		counts = counts.astype(np.uint64)
	# Build the masks directly in uint64. Shifting in the dtype that
	# np.arange(m) happens to pick overflows when m is a narrow NumPy scalar.
	masks = np.uint64(1) << np.arange(int(m), dtype=np.uint64)
	powers = counts[:, None] & masks
	assert counts.sum() == powers.sum(), f'{counts.sum()=} ! {powers.sum()=}'
	# Output is of size len(counts) x m, and is binary representation from least to most significant bit (left-to-right)
	return (powers > 0).astype(np.uint64)


def distribute_counts(counts: npt.NDArray[np.uint64], r: int, c: int) -> npt.NDArray[np.uint64]:
	"""Rebalance per-power multiplicities so that every kept column has a fixed sum.

	Column p of `counts` carries weight 2**p (the weighting power2() applies).
	The weighted total must equal 2**(r + c) on entry and is preserved.

	Args:
		counts: 2-d array, one row per value and one column per power of two.
			It is modified in place; the result is a view of it.
		r: Number of columns to keep.
		c: Exponent of the target column sum.

	Returns:
		The first r columns of `counts`, where columns 1..r-1 each sum to 2**c
		and column 0 sums to 2**(c + 1).

	Raises:
		ValueError: If a column p >= 1 sums to less than 2**c when it is
			processed; the excess only ever flows towards lower powers, so
			such a column cannot be topped up.
		AssertionError: If one of the weighted-total invariants is violated.
	"""
	target = 2**c
	assert 2**(r + c) == power2(counts).sum(), l2(power2(counts).sum())
	
	# Fold every column above r-1 into column r-1: one entry of weight 2**p
	# equals 2**(p-r+1) entries of weight 2**(r-1).
	for p in range(counts.shape[1]-1, r-1, -1):
		counts[:, r-1] += 2**(p-r+1) * counts[:, p]
	
	counts = counts[:, :r]
	assert 2**(r + c) == power2(counts).sum(), l2(power2(counts).sum())

	# Walk from the highest kept power downwards and push any excess one power lower.
	for p in range(r-1, 0, -1):
		column_sum = counts[:, p].sum()
		if column_sum > target:
			k = np.where(np.cumsum(counts[:, p]) >= target)[0][0]	# First index where cumsum >= 2**c
			exclude = counts[:, p].copy()
			# Keep 2**c counts, set the others to zero
			counts[k:, p] = 0
			# Row k keeps just enough to reach 2**c. Plain Python ints keep this
			# exact; mixing np.int64 with np.uint64 would promote to float64.
			counts[k, p] = target - int(counts[:k, p].sum())
			assert counts[k, p] <= exclude[k]
			assert np.all(counts[:, p] <= exclude), f'{np.where(exclude < counts[:, p])=}, {k=}'
			# Distribute the others to the next lower power: each entry removed
			# at weight 2**p becomes two entries at weight 2**(p-1).
			counts[:, p-1] += 2 * (exclude - counts[:, p])
			assert counts[:, p].sum() == 2**c, f'{p=}: {l2(counts[:, p].sum())=}'
			assert counts[:, p].max() <= 2**c, f'{p=}: {l2(counts[:, p].max())=}'
		elif column_sum < target:
			raise ValueError(f'{p=}, {counts[:, p].sum()=}, {2**c=}')

	assert all(counts[:, k].sum() == 2**(c + (k==0)) for k in range(counts.shape[1])), counts.sum(axis=0)
	assert 2**(r + c) == power2(counts).sum(), l2(power2(counts).sum())
	return counts


def build_clut(values: npt.NDArray, probs: npt.NDArray, precision: int) -> tuple[npt.NDArray, npt.NDArray]:
	"""Build the compact lookup table (cLUT) for a discrete distribution.

	A probability distribution is given by its values and the respective
	weights, i.e. P(values[i]) is proportional to probs[i].

	NOTE(review): compute_size() currently returns `probs` unchanged as the
	counts, so `probs` has to hold integer counts that sum to 2**precision --
	confirm before passing real probabilities.

	Args:
		values: 1-d array of the values to place in the table.
		probs: One weight per entry of `values` (see the note above).
		precision: log2 of the total count; must equal r + c as chosen by
			compute_size().

	Returns:
		A tuple (clut, counts). `clut` has shape (r + 1, 2**c) and the dtype of
		`values`: rows 0 and 1 hold the two halves of the weight-1 entries and
		row p + 1 holds the entries of weight 2**p for p >= 1. `counts` is the
		count array returned by compute_size().

	Raises:
		AssertionError: If one of the size or total invariants is violated.
	"""
	# NOTE(review): the original remark here said counts = round(probs * 2**precision);
	# that rounding is disabled in compute_size() at the moment.
	r, c, lut_size, counts = compute_size(probs, precision)
	assert sum(counts) == 2**precision, f'{l2(sum(counts))=} != {precision=}'
	print("r, c, presicion, lut_size = ", r, c, precision, lut_size)
	assert precision == r + c
	
	# counts is 1d, one integer for every value.
	# counts_p2 is 2d with one row per value: counts_p2[i, j] is bit j of
	# counts[i], i.e. row i is the binary expansion of counts[i].
	# The exact integer bit length is used for the width: log2() of a zero
	# count is -inf, which has no defined conversion to an unsigned integer.
	# At least r columns are requested because distribute_counts() indexes
	# column r - 1.
	n_bits = max(int(np.max(counts)).bit_length(), r)
	counts_p2 = split_in_powers_of_two(counts, n_bits)
	assert 2**(r + c) == power2(counts_p2).sum(), l2(power2(counts_p2).sum())

	counts_p2 = distribute_counts(counts_p2, r, c)
	assert 2**(r + c) == power2(counts_p2).sum(), l2(power2(counts_p2).sum())

	# -1 marks cells that have not been filled yet, so `values` must not contain -1.
	clut = np.zeros((r+1, int(2**c)), dtype=values.dtype) - 1
	for p in range(r):
		x = np.repeat(values, counts_p2[:, p].astype(np.int64))	# Fill cLUT with values
		assert len(x) == 2**(c + (p == 0)), f'{l2(len(x))=} != {(c + (p == 0))=}'
		
		if p:
			clut[p+1] = x
		else:
			# Column 0 holds twice as many entries as the others; split it over two rows.
			clut[0] = x[:x.size//2]
			clut[1] = x[x.size//2:]

	assert np.all(clut != -1)
	return clut, counts


def test():
	"""Build a cLUT for a random distribution and check the counts recovered from it.

	NOTE(review): this still hands float probabilities to build_clut(), while
	compute_size() currently forwards its input unchanged as integer counts --
	the call looks stale; confirm before relying on this check.
	"""
	# Create dummy distribution. We could use the distributions from aldr, read the second line as counts, the first line as total, and then probs=counts/total.
	# Even better would be to just load counts, and give counts to build_clut().
	values = np.arange(100)
	counts = np.random.randint(0, 100, 100)
	total_before = counts.sum()
	next_power = 2**(np.ceil(np.log2(total_before)))
	counts[-1] = next_power - total_before
	n_bits = int(np.log2(counts.sum())) + 1
	print(n_bits)
	assert sum(counts) <= 2**n_bits
	weights = np.array(counts, dtype=np.float32) / sum(counts)

	# Build cLUT
	table, expected_counts = build_clut(values, weights, n_bits)

	# Rows 1.. are weighted by powers of two (via power2), row 0 counts once.
	weighted_rows = table[1:].T
	first_row = table[0]
	clut_counts = [
		power2(weighted_rows == value).sum() + (first_row == value).sum()
		for value in values
	]

	recovered = np.array(clut_counts, dtype=np.uint32)
	assert np.allclose(expected_counts, recovered), f'{counts=}\n{clut_counts=}'

	# Build nLUT
	nlut = np.repeat(values, clut_counts)


def main():
	"""Command-line entry point: read counts from a file and build the cLUT.

	The first command-line argument is the path of a text file with one number
	per line. Each number is parsed with float() and truncated with int();
	blank lines are ignored. Prints a usage or error message and exits with
	status 1 when the argument is missing or the file holds no numbers.
	"""
	import sys
	if len(sys.argv) < 2:
		print(f"Usage: {sys.argv[0]} inputfile")
		sys.exit(1)
	# The context manager closes the file even if parsing a line fails.
	with open(sys.argv[1]) as infile:
		probs = np.array([int(float(line)) for line in infile if line.strip()])
	if probs.size == 0:
		# log2 of an empty sum would be -inf and cannot be turned into a precision.
		print(f"{sys.argv[1]}: no counts found")
		sys.exit(1)
	values = np.arange(probs.size)

	precision = int(np.log2(probs.sum()))
	print(sum(probs))

	clut, lut_counts = build_clut(values, probs, precision)


# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
	main()