import numpy as np
from .abstract_hash import AbstractHash

class SRPHash(AbstractHash):
	"""Signed random projection (SimHash) locality-sensitive hash family.

	Each of the ``N`` output hashes is built from ``p`` sign bits. Each bit
	comes from the sign of one Gaussian random projection, and the bits are
	packed into an integer code in ``[0, 2**p)``.

	NOTE(review): ``self.N``, ``self.d`` and ``self.seed`` are read but not
	set in this class; presumably ``AbstractHash.__init__`` sets them from
	``kwargs`` -- confirm against the base class.
	"""

	def __init__(self, p:int, **kwargs):
		"""Create the hash family.

		Args:
			p: number of sign bits concatenated into each output hash.
			**kwargs: forwarded unchanged to ``AbstractHash.__init__``.
		"""
		self.p = p
		super().__init__(**kwargs)
		# one projection per bit: N output hashes x p bits each
		self.N_projections = self.N * self.p

		self._init_projections()

	def _init_projections(self):
		"""Draw the Gaussian projection matrix and the bit-packing weights."""
		# NOTE: this reseeds NumPy's *global* RNG, so it also affects any
		# np.random calls the caller makes afterwards.
		np.random.seed(self.seed)
		# row k is the projection vector for bit (k % p) of output hash (k // p)
		self.W = np.random.normal(size = (self.N_projections,self.d))
		# weights 1, 2, 4, ... used to pack p bits into one integer code
		self.powersOfTwo = np.array([2**i for i in range(self.p)])

	def _hash(self,x):
		"""Hash a single vector.

		Args:
			x: vector of length ``d`` (it is multiplied as ``W @ x``).

		Returns:
			Float array of length ``N``; entry i is the integer code in
			``[0, 2**p)`` formed from the p sign bits of output hash i.
		"""
		# bit = 1 when the projection is strictly positive, else 0
		h = np.sign( np.dot(self.W,x) )
		h = np.clip( h, 0, 1)
		# group the N*p bits into N rows of p bits each
		h = np.reshape(h,(self.N,self.p))
		return np.dot(h,self.powersOfTwo)

	def _hashMulti(self,X):
		"""Hash a batch of vectors.

		Args:
			X: 2-D array with one length-``d`` vector per row.

		Returns:
			Float array of shape ``(X.shape[0], N)``; row k holds the N hash
			codes of ``X[k]``, the same values ``_hash(X[k])`` produces.
		"""
		NData = X.shape[0]
		# (N*p, NData): one column of projections per input vector
		h = np.sign( np.dot(self.W,X.T) )
		h = np.clip( h, 0, 1)
		h = np.reshape(h,(self.N,self.p,NData))
		# pack the p bits of every (hash, sample) pair -> (NData, N)
		return np.einsum('ijk,j->ki',h,self.powersOfTwo)

	def kernel(self, x, y):
		"""Return the probability that ``x`` and ``y`` get the same hash code.

		With Gaussian projections, one sign bit agrees for ``x`` and ``y``
		with probability ``1 - theta/pi``, where ``theta`` is the angle
		between them. The ``p`` bits of one output hash use independent
		projections, so the code collides with probability
		``(1 - theta/pi) ** p``. This holds for every angle in ``[0, pi]``,
		including obtuse ones (negative inner product).

		Zero vectors: ``_hash`` maps a zero projection to bit 0. A zero
		vector therefore matches a non-zero vector on each bit with
		probability 1/2, and always matches another zero vector.

		Args:
			x, y: vectors of equal length.

		Returns:
			Collision probability in ``[0, 1]``.
		"""
		x = np.asarray(x, dtype=float)
		y = np.asarray(y, dtype=float)
		norm_x = np.linalg.norm(x)
		norm_y = np.linalg.norm(y)
		if norm_x == 0 and norm_y == 0:
			return 1.0
		if norm_x == 0 or norm_y == 0:
			return 0.5 ** self.p
		# clip guards arccos against round-off pushing |cos| slightly past 1
		cos_sim = np.clip(np.inner(x, y) / (norm_x * norm_y), -1.0, 1.0)
		return (1.0 - np.arccos(cos_sim) / np.pi) ** self.p
