import numpy as np
import sklearn.utils
from scipy.stats import norm
import math


class RACE():
	"""Array of ``R`` reservoir samplers indexed by a rehashed LSH code.

	``add`` hashes a vector with a caller-supplied hash object, folds all of
	the resulting hash values into one bucket in ``[0, R)``, and offers the
	vector (or its ``dataID``) to that bucket's reservoir, which keeps at
	most ``M`` items via reservoir sampling.
	"""

	def __init__(self, num_reservoirs, reservoir_size, debug = False):
		self.R = num_reservoirs  # number of buckets, one reservoir each
		self.M = reservoir_size  # capacity of each reservoir
		self.reservoirs = [[] for _ in range(self.R)]
		self.counts = np.zeros(self.R)  # items routed to each bucket so far
		self.debug = debug

	def _bucket(self, hashes):
		"""Fold a hash code (scalar or array) into a bucket index in [0, R)."""
		# Hash the bytes of the whole int32 code so that every hash value
		# contributes to the bucket choice. Hashing elementwise and keeping
		# only element [0] would ignore all but the first hash value.
		key = np.atleast_1d(np.asarray(hashes, dtype=np.int32))
		return sklearn.utils.murmurhash3_32(key.tobytes(), positive=True) % self.R

	def add(self, vector, hashfn, dataID = None):
		"""Route ``vector`` to its bucket and offer it to that reservoir.

		``hashfn`` must expose ``hash(vector)``. If ``dataID`` is not None it
		is stored instead of the vector itself.
		"""
		rehash = self._bucket(hashfn.hash(vector))
		self.counts[rehash] += 1
		appendee = vector if dataID is None else dataID

		# Reservoir sampling (Algorithm R): fill until full, then the n-th
		# item replaces a uniformly chosen slot with probability M / n.
		n = int(self.counts[rehash])
		if n <= self.M:
			self.reservoirs[rehash].append(appendee)
		else:
			index = np.random.randint(0, n)
			if index < self.M:
				self.reservoirs[rehash][index] = appendee

		if self.debug:
			print('--------------------------------------------------------------------------')
			print(vector,'---> @',rehash)
			for idx,reservoir in enumerate(self.reservoirs):
				print(idx,reservoir,self.counts[idx])

	def samples(self):
		"""Return every stored item, reservoir by reservoir, as one flat list."""
		return [item for sublist in self.reservoirs for item in sublist]

	def clear(self):
		"""Empty every reservoir and reset the per-bucket counts."""
		self.counts = np.zeros(self.R)
		self.reservoirs = [[] for _ in range(self.R)]


class RACECounts():
	"""Count sketch made of ``R`` repeated ACE arrays with ``M`` cells each.

	``add``/``remove``/``query`` take one hash value per repetition: the
	value at position ``i`` addresses cell ``(i, int(value) % M)``.
	"""

	def __init__(self, repetitions, hash_range, debug = False):
		# ``debug`` is accepted but currently unused.
		self.R = repetitions  # number of ACE arrays (rows of ``counts``)
		self.M = hash_range   # cells per ACE array (columns of ``counts``)
		self.counts = np.zeros((self.R,self.M))

	def _cells(self, hashvalues):
		"""Yield the (row, column) cell addressed by each hash value."""
		for idx, hashvalue in enumerate(hashvalues):
			# int() truncates floats; Python's % keeps the column non-negative.
			yield idx, int(hashvalue) % self.M

	def add(self, hashvalues):
		"""Increment the addressed cell of each repetition by one."""
		for row, col in self._cells(hashvalues):
			self.counts[row, col] += 1

	def remove(self, hashvalues):
		"""Decrement the addressed cell of each repetition by one (undoes ``add``)."""
		for row, col in self._cells(hashvalues):
			self.counts[row, col] -= 1

	def clear(self):
		"""Reset every counter to zero."""
		self.counts = np.zeros((self.R,self.M))

	def query(self, hashvalues):
		"""Return the sum of the addressed cells divided by ``R``.

		The divisor is always ``R``, even if fewer hash values are supplied.
		"""
		total = 0
		for row, col in self._cells(hashvalues):
			total = total + self.counts[row, col]
		return total/self.R

	def print(self):
		"""Print the counter grid, one repetition per line."""
		for i,row in enumerate(self.counts):
			print(i,'| \t',end = '')
			for thing in row:
				print(str(int(thing)).rjust(2),end = '|')
			print('\n',end = '')

	def counts(self):
		# NOTE(review): unreachable on instances. ``__init__`` assigns the
		# ``self.counts`` array, which shadows this method; read the
		# ``counts`` attribute directly instead.
		return self.counts

	@staticmethod
	def _byte_width(bits):
		"""Return the smallest of 1, 2 or 4 bytes that holds ``bits`` bits."""
		if bits <= 8:
			return 1
		if bits <= 16:
			return 2
		return 4

	def min_size(self,reps):
		"""Estimate the bytes needed to store the first ``reps`` ACE arrays.

		Each row is charged the cheaper of a sparse encoding (one counter
		plus one cell index per non-zero cell) and a dense encoding (``M``
		counters).
		"""
		# A counter holding value mx needs ceil(log2(mx + 1)) bits: 256 needs
		# 9 bits, and an all-zero sketch needs 0 (log2(0) would be -inf).
		mx = np.max(self.counts)
		counter_cost = self._byte_width(np.ceil(np.log2(mx + 1)))
		# Cell indices lie in [0, M), which takes ceil(log2(M)) bits.
		hash_cost = self._byte_width(np.ceil(np.log2(self.M)))
		dense_cost = self.M*counter_cost

		size = 0
		for r, ace_cost in enumerate(np.count_nonzero(self.counts,axis = 1)):
			if r >= reps:  # rows 0 .. reps-1 only
				break
			sparse_cost = ace_cost*(counter_cost + hash_cost)
			size += min(sparse_cost,dense_cost)
		return size

	def equivalent_size(self,reps, sampl, K):
		"""Estimate the storage cost of the ACE rows around the sampled rows.

		For every row index ``i`` in ``sampl``, the rows ``i-K .. i+K``
		(both ends inclusive, clipped to ``[0, R-1]``) are selected. Each row
		in the union is charged the minimum of ``nnz * (log2(M) + 8)``,
		``8 * M`` and ``log2(M)``. ``reps`` is accepted but unused.
		"""
		aces = np.empty(0)
		last = self.counts.shape[0]-1
		for i in sampl:
			st = max(0,i-K)
			en = min(i+K,last)
			# en is an inclusive bound (clipped to the last row index), so the
			# range must stop at en + 1. Otherwise row i is dropped when K == 0
			# and the last row can never be selected.
			aces = np.union1d(aces, np.arange(st,en + 1))
		aces = aces.astype(int)
		ace_costs = np.count_nonzero(self.counts[aces],axis = 1)

		counter_cost = 8
		hash_cost = np.log2(self.M)
		# NOTE(review): this floor caps every row's cost at a single index;
		# the intent is unclear from here -- confirm.
		min_cost = 1*hash_cost
		dense_cost = self.M*counter_cost
		size = 0
		for ace_cost in ace_costs:
			sparse_cost = ace_cost*(hash_cost + counter_cost)
			size += min(sparse_cost,dense_cost, min_cost)
		return size

class Reservoir():
	"""Single reservoir sampler (Algorithm R) holding at most ``M`` items."""

	def __init__(self, reservoir_size):
		self.reservoir = []
		self.M = reservoir_size  # capacity
		self.N = 0               # number of items offered so far

	def add(self, vector, dataID = None):
		"""Offer one item; stores ``dataID`` instead of ``vector`` when given.

		The first ``M`` items are kept unconditionally; afterwards the n-th
		item replaces a uniformly chosen slot with probability ``M / n``.
		"""
		self.N += 1
		appendee = vector if dataID is None else dataID
		if self.N <= self.M:
			self.reservoir.append(appendee)
		else:
			index = np.random.randint(0,self.N)
			if index < self.M:
				self.reservoir[index] = appendee

	def samples(self):
		"""Return the live list of sampled items (not a copy)."""
		return self.reservoir

	def clear(self):
		"""Forget all sampled items and reset the offer counter."""
		self.N = 0
		self.reservoir = []


class L2LSH():
	"""Euclidean LSH: ``N`` quantized random projections of width ``r``.

	Hash j of ``x`` is ``floor((W[j] . x + b[j]) / r)``, with Gaussian
	projection rows ``W`` and offsets ``b`` drawn uniformly from ``[0, r)``.
	"""

	def __init__(self, N, d, r):
		# N: number of hash functions; d: input dimensionality;
		# r: bucket width ("bandwidth") of each quantizer.
		self.N, self.d, self.r = N, d, r
		# Draw order (W first, then b) fixes how the global NumPy RNG stream is consumed.
		projections = np.random.normal(size = (N,d))
		offsets = np.random.uniform(low = 0,high = r,size = N)
		self.W = projections
		self.b = offsets

	def hash(self,x):
		"""Return the ``N`` bucket indices of ``x`` (as floats, from ``np.floor``)."""
		projected = np.dot(self.W, x) + self.b
		return np.floor(projected / self.r)


class SRP():
	"""Signed random projection LSH over ``N`` random hyperplanes."""

	def __init__(self, N, d):
		# N: number of random hyperplanes, i.e. output bits per hash.
		# d: dimensionality of the input vectors.
		self.N = N
		self.d = d
		# Row j of W is the Gaussian normal vector of hyperplane j.
		self.W = np.random.normal(size = (N,d))
		# Weight of bit j when the bits are packed into one integer code.
		self.powersOfTwo = np.array([1 << j for j in range(self.N)])

	def _bits(self, x):
		"""Return one float per hyperplane: 1.0 where W.x > 0, else 0.0."""
		return np.clip(np.sign(np.dot(self.W, x)), 0, 1)

	def hash(self,x):
		"""Pack all ``N`` sign bits of ``x`` into one code (bit j weighs 2**j)."""
		return np.dot(self._bits(x), self.powersOfTwo)

	def hash_independent(self,x,p = 1):
		"""Split the ``N`` sign bits of ``x`` into independent ``p``-bit codes.

		For ``p <= 1`` the raw bit array is returned. Otherwise the bits are
		reshaped into rows of ``p`` (``N`` must be divisible by ``p``) and the
		two-element list ``[codes, bit_rows]`` is returned, where ``codes[k]``
		packs row ``k`` with bit ``j`` weighted by ``2**j``.
		"""
		bits = self._bits(x)
		if not p > 1:
			return(bits)
		bit_rows = np.reshape(bits, (-1, p))
		weights = np.array([1 << j for j in range(p)])
		codes = np.dot(bit_rows, weights)
		return [codes, bit_rows]

	# NOTE(review): unfinished draft kept for reference. As written it is not
	# valid Python (a parameter without a default follows ``p = 1``) and it
	# calls ``perturb``, which is not defined in this file.
	# def hash_independent_withnear(self,x,p = 1, fx):
	# 	h = np.sign( np.dot(self.W,x) )
	# 	h = np.clip( h, 0, 1)
	# 	Codes = []
	# 	if p > 1:
	# 		h = np.reshape(h,(-1,p)) # num/3 cols p rows
	# 		n_hashes = h.shape[0]
	# 		powersOfTwo = np.array([2**i for i in range(p)])
	# 		codes = np.zeros(n_hashes)
	# 		for idx,hi in enumerate(h):
	# 			codes[idx] = np.dot(hi,powersOfTwo)
	# 		# return codes
	# 		Codes.append(codes)
	#
	# 		# selct a random row > flip bit (1 edit dist)
	# 		for i in range(0,100):
	# 			[h_p, x] = perturb(h, i)
	# 			n_hashes = h_p.shape[0]
	# 			codes = np.zeros(n_hashes)
	# 			for idx,hi in enumerate(h_p):
	# 				codes[idx] = np.dot(hi,powersOfTwo)
	# 			# return codes
	# 			Codes.append(codes)
	# 	return Codes
	# 	else:
	# 		return(h)


def P_L2(c,w):
    """Collision probability of Euclidean (p-stable) LSH with bucket width ``w``.

    Evaluates, elementwise over ``c``::

        p(c) = 1 - 2*Phi(-w/c) - 2/(sqrt(2*pi)*(w/c)) * (1 - exp(-w**2/(2*c**2)))

    where ``Phi`` is the standard normal CDF (``scipy.stats.norm.cdf``).

    Parameters
    ----------
    c : float or array_like
        Distance(s) between point pairs; any shape is accepted.
    w : float
        Bucket width of the hash.

    Returns
    -------
    numpy.float64 for scalar ``c``, otherwise an ndarray shaped like ``c``.
    Where the closed form cannot be evaluated, its limits are used:
    ``c == 0`` gives 1.0, and ``w / c == 0`` (``w == 0`` or ``c == inf``)
    gives 0.0.
    """
    c_arr = np.asarray(c, dtype=float)
    # Vectorized over c so inputs of any shape work; the divide/invalid cases
    # are silenced here and replaced by their limits just below.
    with np.errstate(divide='ignore', invalid='ignore'):
        ratio = w / c_arr
        p = (1 - 2*norm.cdf(-w/c_arr)
             - 2.0/(math.sqrt(2*math.pi)*ratio)
             * (1 - np.exp(-0.5*(w**2)/(c_arr**2))))
    p = np.where(ratio == 0, 0.0, p)   # w == 0 or c == inf: never collide
    p = np.where(c_arr == 0, 1.0, p)   # identical points always collide
    return p[()] if p.ndim == 0 else p
