import numpy as np
from .srp_hash import SRPHash

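# This is Simple-ALSH (asymmetric LSH for maximum inner product search), cited as:
# Improved Asymmetric Locality Sensitive Hashing (ALSH) for Maximum Inner Product Search (MIPS)
# NOTE(review): the transforms used below -- data x -> [x; sqrt(1 - ||x||^2)] and
# query q -> [q; 0] -- match Simple-LSH from Neyshabur & Srebro, "On Symmetric and
# Asymmetric LSHs for Inner Product Search" (ICML 2015); the cited paper
# (Shrivastava & Li, UAI 2015) proposes Sign-ALSH. Confirm the intended citation.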

class AsymmetricSimpleHash(SRPHash):
	"""Asymmetric signed-random-projection hash for inner product search.

	Data vectors x are augmented to [x; sqrt(1 - ||x||^2)] and model/query
	vectors theta to [theta; 0]; the augmented (d + 1)-vectors are then hashed
	with ``self._hash`` / ``self._hashMulti`` inherited from SRPHash.

	The data transform assumes every data vector has ||x|| <= 1 (scale the
	data set beforehand). Values of 1 - ||x||^2 that come out negative --
	e.g. from floating-point rounding on unit-norm vectors -- are clipped to 0
	instead of producing NaN.
	"""

	def __init__(self, **kwargs):
		super().__init__(**kwargs)

	def _init_projections(self):
		"""Draw the Gaussian projection matrix and the powers-of-two table.

		Sets ``self.W`` with shape (N_projections, d + 1) -- one extra column for
		the coordinate appended by the hash transforms -- and
		``self.powersOfTwo`` = [1, 2, 4, ..., 2**(p - 1)] (presumably used by
		SRPHash to pack sign bits into integer codes; confirm in srp_hash).
		"""
		# Seeds NumPy's *global* RNG so the projections are reproducible for a
		# given self.seed (note: this also affects other users of np.random).
		np.random.seed(self.seed)
		self.W = np.random.normal(size=(self.N_projections, self.d + 1))
		self.powersOfTwo = np.array([2**i for i in range(self.p)])

	@staticmethod
	def _norm_complement(norms):
		"""Return sqrt(1 - norms**2), with the radicand clipped at 0.

		Clipping keeps vectors whose norm rounds to just above 1 from turning
		into NaN; works elementwise on scalars and arrays.
		"""
		return np.sqrt(np.maximum(1 - norms**2, 0))

	def hashModel(self, theta):
		"""Hash one model/query vector using the query transform [theta; 0].

		Parameters
		----------
		theta : np.ndarray
			Vector of length ``self.d``.

		Returns
		-------
		The result of ``self._hash`` on the augmented (d + 1)-vector.
		"""
		theta_tf = np.empty(self.d + 1, dtype=theta.dtype)
		theta_tf[:self.d] = theta
		theta_tf[-1] = 0
		return self._hash(theta_tf)

	def hashData(self, x):
		"""Hash one data vector using the data transform [x; sqrt(1 - ||x||^2)].

		Parameters
		----------
		x : np.ndarray
			Vector of length ``self.d``; expected to satisfy ||x|| <= 1.

		Returns
		-------
		The result of ``self._hash`` on the augmented (d + 1)-vector.
		"""
		x_tf = np.empty(self.d + 1, dtype=x.dtype)
		x_tf[:self.d] = x
		x_tf[-1] = self._norm_complement(np.linalg.norm(x))
		return self._hash(x_tf)

	def hashModelMulti(self, theta):
		"""Hash a batch of model/query vectors (one per row) as [theta_i; 0].

		Parameters
		----------
		theta : np.ndarray
			2-D array of shape (n, d).

		Returns
		-------
		The result of ``self._hashMulti`` on the (n, d + 1) augmented array.
		"""
		theta_tf = np.hstack((theta, np.zeros((theta.shape[0], 1))))
		return self._hashMulti(theta_tf)

	def hashDataMulti(self, X):
		"""Hash a batch of data vectors (one per row) as [x_i; sqrt(1 - ||x_i||^2)].

		Parameters
		----------
		X : np.ndarray
			2-D array of shape (n, d); each row expected to satisfy ||x_i|| <= 1.

		Returns
		-------
		The result of ``self._hashMulti`` on the (n, d + 1) augmented array.
		"""
		extra = self._norm_complement(np.linalg.norm(X, axis=1))
		# extra[:, np.newaxis] turns the (n,) vector into an (n, 1) column.
		X_tf = np.hstack((X, extra[:, np.newaxis]))
		return self._hashMulti(X_tf)
