import numpy as np
import torch
import pdb

from .arrays import to_np, to_torch

class QuantileDiscretizer:
	'''
		Per-dimension quantile discretizer.

		Bin edges (`thresholds`) are read from the column-wise sorted data at a
		fixed stride, so each bin holds roughly the same number of points.
		Bin `i` of a dimension spans `[thresholds[i], thresholds[i+1]]` and is
		reconstructed as the midpoint of that interval.

		Attributes
			data       : the array passed to the constructor ( indexed as [ n_points x dim ] )
			N          : requested number of bins per dimension
			thresholds : [ n_edges x dim ] bin edges; the last row is the column-wise max of `data`
			diffs      : [ (n_edges - 1) x dim ] widths of consecutive bins

		NOTE(review): `n_edges` is `ceil(len(data) / ceil(len(data) / N)) + 1`.
		That equals `N + 1` when `len(data)` is large relative to `N`, but can be
		smaller for small datasets, in which case `N` overstates the bin count.
	'''

	def __init__(self, data, N):
		'''
			data : [ n_points x dim ] array
			N    : number of bins per dimension

			Builds the thresholds and then runs the `_test` self-check, which
			prints to stdout and asserts on failure.
		'''
		self.data = data
		self.N = N

		n_points_per_bin = int(np.ceil(len(data) / N))
		# each column is sorted independently; every `n_points_per_bin`-th
		# value becomes the left edge of a bin
		obs_sorted = np.sort(data, axis=0)
		thresholds = obs_sorted[::n_points_per_bin, :]
		# the column-wise maximum closes the last bin
		maxs = data.max(axis=0, keepdims=True)

		self.thresholds = np.concatenate([thresholds, maxs], axis=0)
		self.diffs = self.thresholds[1:] - self.thresholds[:-1]
		self._test()

	def __call__(self, x):
		'''
			Discretize `x`, reconstruct it, and report the round-trip error.

			returns : ( indices, recon, error ) where `error` is the largest
			          absolute reconstruction error per dimension
		'''
		# convert up front so the error below is plain numpy arithmetic,
		# matching what `discretize` does internally
		if torch.is_tensor(x):
			x = to_np(x)

		indices = self.discretize(x)
		recon = self.reconstruct(indices)
		error = np.abs(recon - x).max(0)
		return indices, recon, error

	def _test(self):
		'''
			Self-check on 1000 randomly drawn rows of `self.data`:
			reconstruction error must not exceed the widest bin of each
			dimension, and discretizing a reconstruction must return the
			same indices.
		'''
		print('[ utils/discretization ] Testing...', end=' ', flush=True)
		inds = np.random.randint(0, len(self.data), size=1000)
		X = self.data[inds]
		indices = self.discretize(X)
		recon = self.reconstruct(indices)
		error = np.abs(X - recon).max(0)
		assert (error <= self.diffs.max(axis=0)).all()
		# bin midpoints must map back to their own bin
		indices_2 = self.discretize(recon)
		assert (indices == indices_2).all()
		print('✓')

	def discretize(self, x, subslice=(None, None)):
		'''
			x        : [ B x D ] or [ D ] values ( numpy array or torch tensor )
			subslice : ( start, end ) column range of `thresholds` that the
			           D columns of `x` correspond to; default is all columns

			returns : [ B x D ] integer bin indices, clipped to [ 0, N - 1 ]
		'''
		if torch.is_tensor(x):
			x = to_np(x)

		if x.ndim == 1:
			x = x[None]

		start, end = subslice
		thresholds = self.thresholds[:, start:end]

		# gt[t, b, d] is True when x[b, d] is at or above edge t of dimension d
		gt = x[None] >= thresholds[:,None]
		indices = largest_nonzero_index(gt, dim=0)

		# values at or above the last edge come back as the edge count;
		# fold them into the last valid bin
		if indices.min() < 0 or indices.max() >= self.N:
			indices = np.clip(indices, 0, self.N - 1)

		return indices

	def reconstruct(self, indices, subslice=(None, None)):
		'''
			indices  : [ B x D ] or [ D ] integer bin indices ( numpy array or torch tensor )
			subslice : ( start, end ) column range of `thresholds` to use

			returns : [ B x D ] midpoints of the selected bins

			Out-of-range indices are reported on stdout and clipped to [ 0, N - 1 ].
		'''

		if torch.is_tensor(indices):
			indices = to_np(indices)

		if indices.ndim == 1:
			indices = indices[None]

		if indices.min() < 0 or indices.max() >= self.N:
			print(f'[ utils/discretization ] indices out of range: ({indices.min()}, {indices.max()}) | N: {self.N}')
			indices = np.clip(indices, 0, self.N - 1)

		start, end = subslice
		thresholds = self.thresholds[:, start:end]

		# bin i is bounded by edges i and i + 1
		left = np.take_along_axis(thresholds, indices, axis=0)
		right = np.take_along_axis(thresholds, indices + 1, axis=0)
		recon = (left + right) / 2.
		return recon

	#---------------------------- wrappers for planning ----------------------------#

	def expectation(self, probs, subslice):
		'''
			probs    : [ B x N ] per-bin probabilities ( numpy array or torch tensor )
			subslice : column selector applied as `thresholds[:, subslice]`

			returns : expected value under `probs`, computed as the average of
			          the expected left edge and the expected right edge
		'''

		if torch.is_tensor(probs):
			probs = to_np(probs)

		thresholds = self.thresholds[:, subslice]
		left  = probs @ thresholds[:-1]
		right = probs @ thresholds[1:]

		avg = (left + right) / 2.
		return avg

	def percentile(self, probs, percentile, subslice):
		'''
			probs      : [ B x N ] per-bin probabilities ( numpy array or torch tensor )
			percentile : cumulative-mass level to look up
			subslice   : column selector applied as `thresholds[:, subslice]`

			returns : midpoint of the first bin whose cumulative probability
			          exceeds `percentile`; if no bin does, the midpoint of
			          the last bin
		'''

		if torch.is_tensor(probs):
			probs = to_np(probs)

		thresholds = self.thresholds[:, subslice]
		cumulative = np.cumsum(probs, axis=-1)
		valid = cumulative > percentile

		# argmax over booleans returns the first True along the bin axis
		inds = np.argmax(valid, axis=-1)
		# argmax also returns 0 when nothing is True; in that case the
		# requested level is at or above all of the mass, so use the last bin
		inds = np.where(valid.any(axis=-1), inds, valid.shape[-1] - 1)

		# bin i is bounded by edges i and i + 1, as in `reconstruct` and
		# `expectation`; indexing with `inds - 1` would pick the previous bin
		# and wrap around to the last edge when `inds` is 0
		left = thresholds[inds]
		right = thresholds[inds + 1]
		avg = (left + right) / 2.
		return avg

	#------------------------------- value estimates -------------------------------#

	def _prepare_value_probs(self, probs):
		'''
			Shared preamble of `value_expectation` and `value_fn`: convert to
			numpy, drop the last entry of the final axis, and check that N
			bins remain.

			returns : ( probs, return_torch ) where `return_torch` records
			          whether the input was a torch tensor
		'''
		return_torch = torch.is_tensor(probs)
		if return_torch:
			probs = to_np(probs)

		# the final axis carries one extra entry beyond the N bins
		probs = probs[:, :, :-1]
		assert probs.shape[-1] == self.N
		return probs, return_torch

	def value_expectation(self, probs):
		'''
			probs : [ B x 2 x (N + 1) ] ( numpy array or torch tensor );
			        `probs[:, 0]` is scored against the second-to-last
			        threshold column and `probs[:, 1]` against the last

			returns : ( rewards, next_values ) expectations, as torch tensors
			          if `probs` was a torch tensor
		'''

		probs, return_torch = self._prepare_value_probs(probs)

		rewards = self.expectation(probs[:, 0], subslice=-2)
		next_values = self.expectation(probs[:, 1], subslice=-1)

		if return_torch:
			rewards = to_torch(rewards)
			next_values = to_torch(next_values)

		return rewards, next_values

	def value_fn(self, probs, percentile):
		'''
			probs      : [ B x 2 x (N + 1) ] ( numpy array or torch tensor )
			percentile : the string 'mean' to use `value_expectation`, or
			             anything `float()` accepts to use `percentile`

			returns : ( rewards, next_values ), as torch tensors if `probs`
			          was a torch tensor
		'''
		if percentile == 'mean':
			return self.value_expectation(probs)

		percentile = float(percentile)
		probs, return_torch = self._prepare_value_probs(probs)

		rewards = self.percentile(probs[:, 0], percentile, subslice=-2)
		next_values = self.percentile(probs[:, 1], percentile, subslice=-1)

		if return_torch:
			rewards = to_torch(rewards)
			next_values = to_torch(next_values)

		return rewards, next_values

def largest_nonzero_index(x, dim):
	'''
		Index of the last True entry of boolean array `x` along axis `dim`.

		x   : boolean array with at least two dimensions
		dim : axis to reduce over ( negative values count from the end )

		returns : integer array with axis `dim` removed, holding
			- the position of the last True along `dim`,
			- 0 where every entry along `dim` is False,
			- N ( the length of axis `dim` ) where every entry is True.
	'''
	N = x.shape[dim]

	# weights 1..N laid out along `dim` and broadcastable against `x`;
	# the largest weighted entry is then the last True
	shape = [1] * x.ndim
	shape[dim] = N
	arange = (np.arange(N) + 1).reshape(shape)

	# reduce along `dim` itself rather than a fixed axis 0
	inds = np.argmax(x * arange, axis=dim)
	lt_mask = (~x).all(axis=dim)
	gt_mask = (x).all(axis=dim)

	# argmax over an all-zero slice already yields 0; kept explicit
	inds[lt_mask] = 0
	# all-True slices are marked with N, one past the last position
	inds[gt_mask] = N

	return inds
