import numpy as np
import numpy.typing as npt
from scipy.stats import norm


def l2(x):
	"""Format ``x`` as ``'<x with thousands separators>=2^<log2(x)>'``.

	The dtype is appended in parentheses only when ``x`` is an ``np.ndarray``.
	"""
	suffix = f'({x.dtype})' if isinstance(x, np.ndarray) else ''
	return f'{x:,}=2^{np.log2(x)}' + suffix


def log2(x):
	"""Return the base-2 logarithm of ``x`` truncated to a Python ``int``."""
	return int(np.log2(x))


def log2up(x):
	"""Return the base-2 logarithm of ``x`` rounded up (NumPy float result)."""
	return np.ceil(np.log2(x))


def log2down(x):
	"""Return the base-2 logarithm of ``x`` rounded down (NumPy float result)."""
	return np.floor(np.log2(x))


def power2(x):
	"""Scale column ``j`` of the 2-D array ``x`` by ``2**j``.

	The weights are built as ``uint64``; ``x`` must have at least two dimensions
	because the number of weights is taken from ``x.shape[1]``.
	"""
	return x * (1 << np.arange(x.shape[1], dtype=np.uint64))

def get_float16_range():
	"""Return every non-NaN float16 value, widened to float32.

	All 2**16 bit patterns are enumerated in ascending bit-pattern order
	(so the result is NOT sorted by value: the non-negative numbers come first,
	followed by the negative ones), reinterpreted as float16 and then converted
	to float32. NaN patterns are dropped; +inf and -inf are kept.
	"""
	# Build the patterns in a wider type first so the exclusive upper bound
	# 2**16 never has to be represented as a uint16.
	bit_patterns = np.arange(1 << 16, dtype=np.uint32).astype(np.uint16)
	# Reinterpret each 16-bit pattern as ONE half-precision float. Viewing the
	# buffer as float32 would fuse neighbouring patterns into 32-bit values.
	numbers = bit_patterns.view(np.float16).astype(np.float32)
	return numbers[~np.isnan(numbers)]

def F_bimodal(x, mu1=-2.0, sigma1=1.0, 
				mu2=3.0, sigma2=1.0, weight1=0.5
				):
	"""Evaluate the CDF of a two-component normal mixture at ``x``.

	The first component N(mu1, sigma1) is weighted by ``weight1`` and the
	second component N(mu2, sigma2) by ``1 - weight1``. ``x`` is passed straight
	to ``scipy.stats.norm``'s ``cdf``.
	"""
	first_cdf = norm.cdf(x, loc=mu1, scale=sigma1)
	second_cdf = norm.cdf(x, loc=mu2, scale=sigma2)
	return weight1 * first_cdf + (1.0 - weight1) * second_cdf


def bitwise_counts(nums):
	"""Count, for every bit position, how many entries of ``nums`` have it set.

	Args:
		nums: Non-empty integer array (or sequence convertible to one).

	Returns:
		A list of Python ints; element ``i`` is the number of entries whose
		bit ``i`` is set. Its length is the bit length of ``nums.max()``, so an
		all-zero input yields an empty list.

	Raises:
		AssertionError: If ``nums`` is empty.
	"""
	values = np.asarray(nums)
	assert values.size > 0
	# Find the maximum number of bits needed
	max_bits = int(values.max()).bit_length()

	# Shift and mask with scalars of the array's own dtype: mixing e.g. uint64
	# data with Python ints is subject to NumPy's version-dependent promotion
	# rules, while same-dtype operands always stay integral.
	scalar = values.dtype.type
	one = scalar(1)
	return [
		int(np.count_nonzero((values >> scalar(bit)) & one))
		for bit in range(max_bits)
	]


def compute_size(counts: np.ndarray[int]) -> tuple[int, int, int]:
	"""Derive the table parameters ``(r, c, lut_size)`` from integer counts.

	``sum(counts)`` must be a power of two, ``2**N`` (asserted).

	Returns:
		r: The largest candidate ``e`` in ``1..len(bit_counts)`` that, for every
			bit position, satisfies ``cumsum <= 2**(N - e + 1 + bit)`` or
			``e <= bit``, where ``cumsum`` is the running total of
			``2**bit * bit_count`` up to and including that bit.
		c: ``N - r``.
		lut_size: ``2**N`` divided by the greatest common divisor of ``counts``.
	"""
	total = sum(counts)
	N = log2(total)
	assert 2**N == total

	per_bit = bitwise_counts(counts)
	feasible = list(range(1, len(per_bit) + 1))
	weighted_total = 0

	for bit, occurrences in enumerate(per_bit):
		weighted_total += 2**bit * occurrences
		# Keep only the candidates that still fit once this bit is included.
		feasible = [
			e for e in feasible
			if (weighted_total <= 2**(N - e + 1 + bit)) or (e <= bit)
		]

	lut_size = int(2**N) // int(np.gcd.reduce(counts))

	r = max(feasible)
	return r, N - r, lut_size


def split_in_powers_of_two(counts: npt.NDArray[np.uint64], m: int) -> npt.NDArray[np.uint64]:
	"""Expand every count into its ``m`` lowest binary digits.

	Args:
		counts: 1-D array of non-negative integers (any integer or bool dtype).
		m: Number of bits to extract per count.

	Returns:
		A ``uint64`` array of shape ``(len(counts), m)`` holding the binary
		representation of each count from least to most significant bit
		(left-to-right).

	Raises:
		TypeError: If ``counts`` does not have an integer (or bool) dtype.
		ValueError: If ``counts`` contains negative values.
	"""
	values = np.asarray(counts)
	if values.dtype.kind not in 'iub':
		raise TypeError(f'counts must have an integer dtype, got {values.dtype}')
	if values.dtype.kind == 'i' and (values < 0).any():
		raise ValueError('counts must be non-negative')
	# Work in uint64 throughout: combining a signed array with uint64 masks
	# would promote to float64, for which bitwise AND is undefined.
	values = values.astype(np.uint64)
	masks = np.uint64(1) << np.arange(int(m), dtype=np.uint64)
	return ((values[:, None] & masks) > 0).astype(np.uint64)


def distribute_counts(counts: npt.NDArray[np.uint64], r: int, c: int) -> npt.NDArray[np.uint64]:
	"""Rebalance per-bit counts so every bit column fills its share of the table.

	``counts[i, p]`` is the number of entries of weight ``2**p`` assigned to
	value ``i``. The weighted row totals (``sum_p 2**p * counts[i, p]``) are
	preserved; only the split across bit columns changes.

	Args:
		counts: 2-D array, one row per value and one column per bit
			(least significant first). The caller's array is not modified.
		r: Number of bit columns to keep.
		c: Columns ``1..r-1`` must end up summing to ``2**c`` and column 0 to
			``2**(c + 1)``.

	Returns:
		An array of shape ``(counts.shape[0], r)`` with the redistributed counts.

	Raises:
		AssertionError: If a column cannot be brought to its required total.
	"""
	# Work on a private copy so the caller's array is left untouched.
	counts = np.array(counts, copy=True)

	# Fold every column above r-1 into column r-1: one entry of weight 2**p is
	# worth 2**(p-r+1) entries of weight 2**(r-1).
	for p in range(counts.shape[1]-1, r-1, -1):
		counts[:, r-1] += 2**(p-r+1) * counts[:, p]

	counts = counts[:, :r]
	capacity = 2**c

	for p in range(r-1, 0, -1):
		column = counts[:, p]	# view: writes go straight into counts
		if column.sum() <= capacity:
			# Exactly full, or under-filled (which the final assert reports).
			continue

		before = column.copy()
		# First row at which the running total reaches the capacity.
		k = int(np.flatnonzero(np.cumsum(column) >= capacity)[0])
		kept_before_k = int(column[:k].sum())
		# Keep exactly 2**c entries: rows before k unchanged, row k trimmed,
		# rows after k emptied. Plain integer arithmetic avoids the float64
		# round-trip caused by mixing int64 and uint64 scalars.
		column[k:] = 0
		column[k] = capacity - kept_before_k
		assert column[k] <= before[k]
		# Push the removed entries to the next lower bit. Each entry of weight
		# 2**p becomes TWO entries of weight 2**(p-1), hence the factor 2.
		counts[:, p-1] += 2 * (before - column)

	assert all(counts[:, k].sum() == 2**(c + (k==0)) for k in range(counts.shape[1])), counts.sum(axis=0)
	return counts


def build_clut(values: npt.NDArray, probs: npt.NDArray, precision: int|None) -> tuple[npt.NDArray, npt.NDArray, int, int]:
	"""Build a compressed lookup table (cLUT) for a discrete distribution.

	A probability distribution is given by its values and the respective
	probabilities, i.e. ``P(values[i]) = probs[i]``.

	Args:
		values: The outcomes of the distribution.
		probs: With a truthy ``precision``: probabilities summing to 1, which are
			rounded to integer counts summing to ``2**precision`` (common factors
			of two are divided out afterwards). Otherwise: integer counts whose
			sum is already a power of two.
		precision: Number of bits used to quantise ``probs``; ``None`` (or 0)
			means ``probs`` already holds integer counts.

	Returns:
		A tuple ``(clut, counts, r, c)``: the flattened table of ``r + 1`` rows
		with ``2**c`` entries each, the integer count per value, and the two
		size parameters obtained from ``compute_size``.

	Raises:
		AssertionError: If the inputs violate the conditions above or an
			internal consistency check fails.
	"""
	if precision:
		# TRANSFORM PROBABILITIES TO INTEGER-COUNTS:
		assert np.allclose(probs.sum(), np.ones(1)), probs.sum()
		total = 1 << int(precision)	# 2 ** precision
		# PERFORM SUM-PRESERVING ROUNDING:
		raw_counts = probs * total
		floored = np.floor(raw_counts)
		remainder = raw_counts - floored
		floored = floored.astype(np.uint64)
		# Indices ordered by decreasing fractional part.
		idxs = np.argsort(-remainder)

		# Exact integer bookkeeping. The deficit is negative when probs sums to
		# slightly more than 1; casting that to an unsigned type would turn it
		# into an astronomically large loop count.
		deficit = total - int(floored.sum())
		if deficit > 0:
			# Hand out the missing counts round-robin, largest remainders first.
			whole, extra = divmod(deficit, len(idxs))
			floored += np.uint64(whole)
			floored[idxs[:extra]] += np.uint64(1)
		else:
			# Take surplus counts back, smallest remainders first, never
			# pushing an entry below zero.
			surplus = -deficit
			order = idxs[::-1]
			while surplus > 0:
				takers = order[floored[order] > 0][:surplus]
				floored[takers] -= np.uint64(1)
				surplus -= len(takers)

		# Divide out common factors of two so the counts are as small as possible.
		while np.all((floored & np.uint64(1)) == 0):
			floored = floored >> np.uint64(1)
		counts = floored
	else:
		assert np.issubdtype(probs.dtype, np.integer)
		assert np.log2(sum(probs)) == log2(sum(probs)), f'Sum of probs {l2(sum(probs)) = } is not a power of 2!'
		counts = probs	
	N = log2(sum(counts))

	r, c, _ = compute_size(counts)
	assert N == r + c

	# DISTRIBUTE COUNTS:
	# counts is 1d, one integer for every value.
	# counts_p2 is 2d, each row belongs to one value.
	# counts_p2[j, i] is equal to the i-th bit of counts[j]. I.e. the j-th row is the binary expansion of counts[j]
	# (least significant bit first).
	counts_p2 = split_in_powers_of_two(
		counts, 
		log2up(counts.max()).astype(np.uint32) + 1
	)

	counts_p2 = distribute_counts(counts_p2, r, c)
	assert 2**(r + c) == power2(counts_p2).sum(), l2(power2(counts_p2).sum())

	# Placeholder contents only: every one of the r+1 rows is overwritten below.
	clut = np.zeros((r+1, int(2**c)), dtype=values.dtype) - 1
	for p in range(r):
		x = np.repeat(values, counts_p2[:, p].astype('int64'))	# Fill cLUT with values
		assert len(x) == 2**(c + (p == 0)), f'{l2(len(x))=} != {(c + (p == 0))=}'
		
		if p:
			clut[p+1] = x
		else:
			# Bit 0 carries twice as many entries and therefore spans two rows.
			clut[0] = x[:x.size//2]
			clut[1] = x[x.size//2:]
	
	return clut.ravel(), counts, r, c