import numpy as np 
from .hashes.abstract_hash import AbstractHash

class STORM():
	"""Array of counters indexed by locality-sensitive hash values.

	The sketch holds ``R`` rows (one per repetition / ACE), each ``W``
	counters wide. Adding an item asks the LSH object for one hash value
	per row, reduces each value modulo ``W`` and increments the selected
	counter. Querying looks up the counters selected by the hashes of a
	model and returns their normalised sum.
	"""

	# The LSH annotation is quoted so it is resolved lazily instead of being
	# evaluated when the class body is executed.
	def __init__(self, repetitions:int, hash_range:int, LSH:"AbstractHash", dtype = np.int32):
		"""Create an empty sketch.

		Args:
			repetitions: number of ACEs (rows) in the array.
			hash_range: range of each ACE (width of each row).
			LSH: object providing ``hashData``, ``hashDataMulti`` and
				``hashModel``.
			dtype: NumPy dtype used for the counters.
		"""
		self.dtype = dtype
		self.R = repetitions # number of ACEs (rows) in the array
		self.W = hash_range  # range of each ACE (width of each row)
		self.LSH = LSH # LSH function
		self.counts = np.zeros((self.R, self.W), dtype = self.dtype)

	def add(self, x):
		"""Insert one item: increment one counter in each row hashed by ``x``."""
		for row, hashvalue in enumerate(self.LSH.hashData(x)):
			self.counts[row, int(hashvalue) % self.W] += 1

	def addMulti(self, X):
		"""Insert a batch of items in one pass.

		``self.LSH.hashDataMulti(X)`` is converted to an integer array and
		indexed as ``[item, row]``. An empty batch leaves the sketch
		unchanged.
		"""
		allhashes = np.array(self.LSH.hashDataMulti(X), dtype = int) % self.W
		if allhashes.size == 0:
			# Nothing to add; an empty result is 1-D and cannot be indexed by row.
			return
		for i in range(self.R):
			histogram = np.bincount(allhashes[:, i], minlength = self.W)
			# bincount yields platform integers; cast explicitly so the in-place
			# add also works for unsigned counter dtypes.
			self.counts[i, :] += histogram.astype(self.dtype, copy = False)

	def clear(self):
		"""Reset every counter to zero by allocating a fresh array."""
		self.counts = np.zeros((self.R, self.W), dtype = self.dtype)

	def query(self, theta):
		"""Return the normalised sum of the counters selected by ``theta``.

		The counters addressed by ``self.LSH.hashModel(theta)`` (one per row)
		are summed and divided by ``R * N``, where ``N`` is the total count
		divided by ``R``. An empty sketch returns ``0.0`` instead of
		dividing zero by zero.
		"""
		hashvalues = self.LSH.hashModel(theta)
		total = np.sum(self.counts)
		if total == 0:
			return 0.0
		N = total / self.R
		mean = 0
		for row, hashvalue in enumerate(hashvalues):
			# .item() gives a Python scalar so the running sum cannot wrap
			# around in a narrow counter dtype.
			mean += self.counts[row, int(hashvalue) % self.W].item()
		return mean / (self.R * N)

	def __call__(self, theta):
		"""Shorthand for :meth:`query`."""
		return self.query(theta)

	def print(self):
		"""Write the counter array to stdout, one row per line."""
		for i, row in enumerate(self.counts):
			cells = ''.join(str(int(value)).rjust(2) + '|' for value in row)
			print(i, '| \t' + cells)

	@staticmethod
	def _bytes_for_bits(bits):
		"""Return the storage width in bytes (1, 2 or 4) for a ``bits``-bit value."""
		if bits <= 8:
			return 1
		if bits <= 16:
			return 2
		return 4

	def min_size(self):
		"""Estimate the smallest size of the sketch in bytes.

		Each row is costed both ways and the cheaper one is taken:
		sparse (one counter plus one bucket index per non-zero counter) or
		dense (``W`` counters). Counter and index widths are 1, 2 or 4 bytes
		depending on how many bits the largest counter and the largest
		bucket index (``W - 1``) need.
		"""
		if self.counts.size == 0:
			return 0
		largest = max(int(np.max(self.counts)), 0)
		# bit_length() is exact; ceil(log2(v)) under-counts at powers of two
		# (256 needs 9 bits, not 8) and is undefined for 0.
		counter_cost = self._bytes_for_bits(largest.bit_length())
		hash_cost = self._bytes_for_bits((int(self.W) - 1).bit_length())

		occupied = np.count_nonzero(self.counts, axis = 1)
		sparse_costs = occupied * (counter_cost + hash_cost)
		dense_cost = self.W * counter_cost
		return int(np.minimum(sparse_costs, dense_cost).sum())

