import numpy as np
from .srp_hash import SRPHash

# This scheme is unpublished, but it is similar to Simple-ALSH from the following paper:
# "Improved Asymmetric Locality Sensitive Hashing (ALSH) for Maximum Inner Product Search (MIPS)".
# We call this "asymmetric ball hash" because both the points and the query must lie within
# the ball ||X|| <= 1.

class AsymmetricBallHash(SRPHash):
	"""Asymmetric hash for data and model vectors that lie in the unit ball.

	Before hashing, every d-dimensional vector is lifted to d+2 dimensions:

		data  x      ->  [x,     sqrt(1 - ||x||^2),  0]
		model theta  ->  [theta, 0,                  sqrt(1 - ||theta||^2)]

	The two padding coordinates occupy different slots, so the inner product of
	a lifted data vector with a lifted model vector equals x . theta, while each
	lifted vector has unit norm (provided the input norm is at most 1).

	The lifted vectors are handed to ``_hash`` / ``_hashMulti``, which are not
	defined in this class (presumably inherited from SRPHash, as are the
	attributes ``d``, ``p``, ``seed`` and ``N_projections`` -- confirm there).
	"""

	def __init__(self, **kwargs):
		"""Forward all keyword arguments unchanged to the SRPHash constructor."""
		super().__init__(**kwargs)

	def _init_projections(self):
		"""Draw the Gaussian projection matrix for the lifted (d+2)-dim space.

		Sets ``self.W`` with shape (N_projections, d+2) and ``self.powersOfTwo``
		holding 2**0 .. 2**(p-1).
		"""
		# NOTE: this seeds NumPy's *global* legacy RNG, so it also affects any
		# later np.random.* call made elsewhere in the process.
		np.random.seed(self.seed)
		self.W = np.random.normal(size = (self.N_projections,self.d+2))
		self.powersOfTwo = np.array([2**i for i in range(self.p)])

	@staticmethod
	def _ball_padding(norms):
		"""Return sqrt(1 - norms**2) with the radicand clamped at zero.

		``norms`` may be a scalar or an array of Euclidean norms.

		A vector on the boundary of the unit ball can have a computed norm that
		is a hair above 1 because of floating-point rounding. Without the clamp
		the radicand is then slightly negative and np.sqrt yields NaN, which
		would contaminate the whole lifted vector. For norms <= 1 the clamp is a
		no-op, so the result is identical to the unclamped expression.

		Inputs that are genuinely outside the ball also get a padding of 0; they
		still violate the ||x|| <= 1 precondition and are not re-normalized.
		"""
		return np.sqrt(np.maximum(1 - norms**2, 0))

	def hashModel(self, theta):
		"""Hash a single model (query) vector.

		theta: 1-D array with ``self.d`` entries and norm at most 1.
		Returns whatever ``self._hash`` returns for the lifted vector.
		"""
		lifted = np.empty(self.d+2, dtype = theta.dtype)
		lifted[:self.d] = theta
		# The model padding lives in the LAST slot; the data slot stays empty.
		lifted[-2] = 0
		lifted[-1] = self._ball_padding(np.linalg.norm(theta))
		return self._hash(lifted)

	def hashData(self, x):
		"""Hash a single data vector.

		x: 1-D array with ``self.d`` entries and norm at most 1.
		Returns whatever ``self._hash`` returns for the lifted vector.
		"""
		lifted = np.empty(self.d+2, dtype = x.dtype)
		lifted[:self.d] = x
		# The data padding lives in the SECOND-TO-LAST slot; the model slot stays empty.
		lifted[-2] = self._ball_padding(np.linalg.norm(x))
		lifted[-1] = 0
		return self._hash(lifted)

	def hashModelMulti(self, theta):
		"""Hash a batch of model (query) vectors, one per row of ``theta``.

		theta: 2-D array; norms are taken along axis 1, so each row is one vector.
		Returns whatever ``self._hashMulti`` returns for the lifted batch.
		"""
		n_rows = theta.shape[0]
		padding = self._ball_padding(np.linalg.norm(theta, axis = 1))
		# Column order matches hashModel: [theta, 0, padding].
		lifted = np.hstack((theta, np.zeros((n_rows,1)), padding.reshape((n_rows,1))))
		return self._hashMulti(lifted)

	def hashDataMulti(self, X):
		"""Hash a batch of data vectors, one per row of ``X``.

		X: 2-D array; norms are taken along axis 1, so each row is one vector.
		Returns whatever ``self._hashMulti`` returns for the lifted batch.
		"""
		n_rows = X.shape[0]
		padding = self._ball_padding(np.linalg.norm(X, axis = 1))
		# Column order matches hashData: [X, padding, 0].
		lifted = np.hstack((X, padding.reshape((n_rows,1)), np.zeros((n_rows,1))))
		return self._hashMulti(lifted)
