import numpy as np 
from .srp_hash import SRPHash

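# This implements Sign-ALSH from the paper
# "Improved Asymmetric Locality Sensitive Hashing (ALSH) for Maximum Inner Product Search (MIPS)"
# (A. Shrivastava and P. Li, UAI 2015).
# We call it an "asymmetric series hash" because it relies on a telescoping series sum:
# appending the terms 1/2 - ||x||^(2^i), i = 1..m, gives ||P(x)||^2 = m/4 + ||x||^(2^(m+1)).
# For ||x|| < 1 the residual term vanishes as m grows, so the transformed norms become
# nearly constant and sign-random-projection collisions approximate the inner product.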

class AsymmetricSeriesHash(SRPHash):
	"""Sign-ALSH hash for maximum inner product search (MIPS).

	Data vectors x are mapped to
		P(x) = [x, 1/2 - ||x||^2, 1/2 - ||x||^4, ..., 1/2 - ||x||^(2^m)]
	and model/query vectors theta to
		Q(theta) = [theta, 0, ..., 0],
	both of length d + m. They are then hashed with the signed random
	projections provided by SRPHash. Because Q(theta) is zero in the
	appended coordinates, Q(theta) . P(x) == theta . x exactly, while
	||P(x)||^2 = m/4 + ||x||^(2^(m+1)).

	NOTE(review): Sign-ALSH assumes data are rescaled so that ||x|| < 1
	(and queries normalized). This class does not rescale its inputs, so
	that is presumably the caller's responsibility -- confirm.

	Args:
		m: number of norm terms appended to each vector; the transformed
			dimension is d + m.
		**kwargs: forwarded unchanged to SRPHash.__init__.
	"""

	def __init__(self, m, **kwargs):
		# self.m must exist before the base constructor runs, because
		# _init_projections reads it (presumably SRPHash.__init__ calls
		# _init_projections -- confirm).
		self.m = m
		super().__init__(**kwargs)

	def _init_projections(self):
		"""Draw the Gaussian projection matrix for the (d + m)-dim space.

		Sets self.W with shape (N * p, d + m) and
		self.powersOfTwo = [1, 2, 4, ..., 2^(p-1)] (presumably used by
		SRPHash to pack p sign bits into one integer code -- confirm).
		"""
		# Seed the global NumPy RNG so the projections are reproducible
		# for a given self.seed.
		np.random.seed(self.seed)
		self.W = np.random.normal(size=(self.N * self.p, self.d + self.m))
		self.powersOfTwo = np.array([2**i for i in range(self.p)])

	@staticmethod
	def _float_dtype(arr):
		"""Return arr's dtype if it is inexact (float/complex), else float64.

		The appended 1/2 - ||x||^(2^i) terms are fractional, so storing
		them in an integer array would silently truncate them.
		"""
		if np.issubdtype(arr.dtype, np.inexact):
			return arr.dtype
		return np.float64

	def _norm_terms(self, norms):
		"""Return the m appended Sign-ALSH terms 1/2 - norms^(2^i), i = 1..m.

		norms may be a scalar (result shape (m,)) or a 1-D array of N norms
		(result shape (N, m)).
		"""
		# Float exponents avoid integer overflow of 2**i for large m.
		exponents = 2.0 ** np.arange(1, self.m + 1)
		return 0.5 - np.asarray(norms)[..., None] ** exponents

	def hashModel(self, theta):
		"""Hash one model/query vector theta (length d) via Q(theta).

		Q(theta) is theta zero-padded to length d + m.
		Returns whatever SRPHash._hash returns for the transformed vector.
		"""
		theta_tf = np.zeros(self.d + self.m, dtype=theta.dtype)
		theta_tf[:self.d] = theta
		return self._hash(theta_tf)

	def hashData(self, x):
		"""Hash one data vector x (length d) via the Sign-ALSH transform P(x).

		The input x is not modified.
		Returns whatever SRPHash._hash returns for the transformed vector.
		"""
		# Every slot of x_tf is assigned below, so np.empty is safe here.
		x_tf = np.empty(self.d + self.m, dtype=self._float_dtype(x))
		x_tf[:self.d] = x
		x_tf[self.d:] = self._norm_terms(np.linalg.norm(x))
		return self._hash(x_tf)

	def hashModelMulti(self, theta):
		"""Hash a batch of model/query vectors theta (shape (N, d)) via Q.

		Returns whatever SRPHash._hashMulti returns for the (N, d + m)
		transformed batch.
		"""
		theta_tf = np.zeros((theta.shape[0], self.d + self.m), dtype=theta.dtype)
		# Slice with :self.d rather than :-self.m, which is empty when m == 0.
		theta_tf[:, :self.d] = theta
		return self._hashMulti(theta_tf)

	def hashDataMulti(self, X):
		"""Hash a batch of data vectors X (shape (N, d)) via P, row by row.

		Returns whatever SRPHash._hashMulti returns for the (N, d + m)
		transformed batch.
		"""
		# Every slot of X_tf is assigned below, so np.empty is safe here.
		X_tf = np.empty((X.shape[0], self.d + self.m), dtype=self._float_dtype(X))
		# Slice with :self.d rather than :-self.m, which is empty when m == 0.
		X_tf[:, :self.d] = X
		X_tf[:, self.d:] = self._norm_terms(np.linalg.norm(X, axis=1))
		return self._hashMulti(X_tf)
