import numpy as np
import numpy.linalg as LA

''' All quantizers take NumPy arrays as input, and their compress methods return NumPy arrays. '''

''' QSGD '''
class QSGD():
	"""Stochastic uniform quantizer (QSGD-style) with ``k`` levels.

	By default the ``k`` levels are spaced evenly between ``min(a)`` and
	``max(a)``.  With ``L2=True`` the grid is anchored at 0 with spacing
	``||a||_2 / (k - 1)``; negative entries are rounded on the same grid
	extended below 0.
	"""

	def __init__(self, k, L2 = False):
		# k: number of quantization levels (compress requires k >= 2)
		# L2: use the [0, ||a||_2] range instead of [min(a), max(a)]
		self.name = "QSGDl2" if L2 else "QSGD"
		self.k = k
		self.L2 = L2

	def compress(self, a):
		"""Return an unbiased stochastic quantization of ``a``.

		Each entry is rounded to one of its two neighbouring grid points, with
		probabilities chosen so that the expected value equals the entry.

		Args:
			a: 1-D array-like of numbers.

		Returns:
			A float ``np.ndarray`` with the same length as ``a``.

		Raises:
			ValueError: if ``self.k < 2`` (the grid spacing would divide by zero
				and every output would silently be NaN).
		"""
		if self.k < 2:
			raise ValueError("QSGD needs k >= 2 quantization levels")
		a = np.asarray(a, dtype=float)
		if a.size == 0:
			return np.zeros(0)

		fmin = 0 if self.L2 else np.min(a)
		fmax = LA.norm(a) if self.L2 else np.max(a)

		if fmax - fmin == 0:
			# Degenerate range: every entry maps to fmin and no randomness is drawn.
			return np.full(len(a), fmin, dtype=float)

		unit = (fmax - fmin) / (self.k - 1)
		# floor(t + U) with U ~ Uniform[0, 1) rounds t up with probability frac(t),
		# which makes the quantizer unbiased.  One uniform is drawn per entry, in order.
		v = np.floor((a - fmin) / unit + np.random.rand(len(a)))
		return fmin + v * unit

''' Hadamard Quantizer '''
class HadamardQuantizer():
	"""Stochastic ``k``-level quantizer applied after a random Hadamard rotation.

	The input is zero-padded to a power-of-two length, multiplied by a random
	+/-1 sign vector ``D``, and rotated with the orthonormal Walsh-Hadamard
	transform ``H``.  The rotated vector is quantized to ``k`` evenly spaced
	levels between its min and max, then mapped back with ``H`` and ``D``.
	"""

	def __init__(self, k):
		# k: number of quantization levels (compress requires k >= 2)
		self.name = "hadamard"
		self.k = k

	def hadamard_transform(self, a):
		"""Return the orthonormal fast Walsh-Hadamard transform of ``a``.

		Each butterfly stage is scaled by 1/sqrt(2), so the transform is
		orthonormal and is its own inverse.  ``len(a)`` must be a power of two.
		The input is not modified, and the result is always a float array.
		"""
		b = np.array(a, dtype=float)  # float copy: an int array would truncate the butterflies
		n = len(b)
		scale = np.sqrt(2)
		h = 1
		while h < n:
			# Blocks of 2*h entries: entry j in a block pairs with entry j + h.
			blocks = b.reshape(-1, 2, h)
			top, bottom = blocks[:, 0, :], blocks[:, 1, :]
			b = np.stack(((top + bottom) / scale, (top - bottom) / scale), axis=1).reshape(n)
			h *= 2
		return b

	def pad(self, a):
		"""Return ``a`` zero-padded to the next power-of-two length (at least 1)."""
		n = 1
		while n < len(a):
			n *= 2
		res = np.zeros(n)
		res[:len(a)] = a
		return res

	def compress(self, a):
		"""Return an unbiased quantization of ``a`` computed in a randomly rotated basis.

		Args:
			a: 1-D array-like of numbers.

		Returns:
			Float array of length ``len(a)``.

		Raises:
			ValueError: if ``self.k < 2`` (the grid spacing would divide by zero).
		"""
		if self.k < 2:
			raise ValueError("HadamardQuantizer needs k >= 2 quantization levels")
		b = self.pad(a)
		# Random +/-1 signs.  Comparing with 0.5 (instead of np.sign(r - 0.5))
		# can never produce a 0 sign, which would erase a coordinate.
		D = np.where(np.random.rand(len(b)) < 0.5, -1.0, 1.0)
		b = self.hadamard_transform(b * D)

		fmin = np.min(b)
		fmax = np.max(b)
		if fmax - fmin != 0:
			unit = (fmax - fmin) / (self.k - 1)
			# Stochastic rounding: floor(t + U), U ~ Uniform[0, 1), is unbiased.
			v = np.floor((b - fmin) / unit + np.random.rand(len(b)))
			b = fmin + v * unit
		# else: all rotated entries already equal fmin, so there is nothing to quantize.

		# H is its own inverse and D * D == 1, so this undoes the rotation.
		b = self.hadamard_transform(b) * D
		return b[:len(a)]


''' Hypercube lattice without error correction. Note that this is only a simulation, i.e. we assume
that encoding and decoding always succeed. '''
class LQSGD():
	"""Dithered hypercube-lattice quantizer, simulated without error correction.

	``compress`` adds uniform dither of width ``side`` to every coordinate.
	``average`` simulates two parties exchanging such quantized vectors and
	checks the mod-``q`` decoding condition.
	"""

	def __init__(self, dimension, qlevel, side = None):
		self.name = "LQSGD"
		self.d = dimension  # length of the dither vector drawn in compress
		self.q = qlevel  # mod q coloring
		self.side = side  # lattice cell side length; must be set before compress is called

	def compress(self, a):
		"""Return ``a`` plus independent Uniform[-side/2, side/2) dither per coordinate.

		Exactly ``self.d`` uniforms are drawn, so ``a`` is expected to have length ``self.d``.
		"""
		g1 = np.array(a)
		u1 = (np.random.rand(self.d) - 0.5) * self.side
		return g1 + u1

	def set_side(self, side):
		self.side = side

	def average(self, a, b):
		"""Simulate averaging two vectors after lattice quantization.

		Returns:
			(avg, diff): the mean of the two quantized vectors, and the inf-norm
			of their difference.

		Raises:
			ValueError: if ``2 * ||a - b||_inf >= (q - 1) * side``, i.e. mod-q
				decoding would fail for the current ``side``.  Raising (instead of
				exiting the process with status 0) keeps the failure visible.
		"""
		g0, g1 = np.array(a), np.array(b)
		# Quantize before checking, so the random stream is consumed in the same order either way.
		qg0, qg1 = self.compress(g0), self.compress(g1)

		if 2 * LA.norm(g0 - g1, np.inf) >= (self.q - 1) * self.side:
			raise ValueError("Decode Error LQSGD, side is wrong")

		avg = (qg0 + qg1) / 2
		diff = LA.norm(qg0 - qg1, np.inf)
		return avg, diff



''' Hypercube lattice with Hadamard rotation'''
class RLQSGD():
	"""Dithered hypercube-lattice quantizer applied in a randomly rotated basis.

	Vectors are zero-padded to a power-of-two length, rotated by ``H D``
	(sign vector ``D``, then the orthonormal Walsh-Hadamard transform ``H``),
	dithered with uniform noise of width ``side``, and rotated back.  This is
	a simulation: decoding is assumed to succeed whenever the check in
	``average`` passes.
	"""

	def __init__(self, dimension, qlevel, side = None):
		self.name = "RLQSGD"
		self.d = dimension  # stored for callers; not read by the methods of this class
		self.q = qlevel  # mod q coloring
		self.side = side  # lattice cell side length; must be set before compress is called

	def hadamard_transform(self, a):
		"""Return the orthonormal fast Walsh-Hadamard transform of ``a``.

		Each butterfly stage is scaled by 1/sqrt(2), so the transform is
		orthonormal and is its own inverse.  ``len(a)`` must be a power of two.
		The input is not modified, and the result is always a float array.
		"""
		b = np.array(a, dtype=float)  # float copy: an int array would truncate the butterflies
		n = len(b)
		scale = np.sqrt(2)
		h = 1
		while h < n:
			# Blocks of 2*h entries: entry j in a block pairs with entry j + h.
			blocks = b.reshape(-1, 2, h)
			top, bottom = blocks[:, 0, :], blocks[:, 1, :]
			b = np.stack(((top + bottom) / scale, (top - bottom) / scale), axis=1).reshape(n)
			h *= 2
		return b

	def pad(self, a):
		"""Return ``a`` zero-padded to the next power-of-two length (at least 1)."""
		n = 1
		while n < len(a):
			n *= 2
		res = np.zeros(n)
		res[:len(a)] = a
		return res

	def compress(self, a, D):  # a common D is shared by all machines for convenience
		"""Return ``a`` with uniform dither added in the rotated basis.

		Args:
			a: 1-D array-like to quantize.
			D: sign vector with the padded length of ``a``.  The round trip
				relies on ``D * D == 1`` (presumably +/-1 entries; confirm with the caller).

		Returns:
			Float array of length ``len(a)``.
		"""
		g1 = np.array(a)
		b = self.hadamard_transform(self.pad(g1) * D)
		# One Uniform[-side/2, side/2) dither per padded coordinate.
		b = b + (np.random.rand(len(b)) - 0.5) * self.side
		b = self.hadamard_transform(b) * D
		return b[:len(a)]

	def HD(self, v, D):  # used for the `diff` that average() returns
		"""Return ``H (D * pad(v))``, i.e. ``v`` expressed in the rotated basis."""
		return self.hadamard_transform(self.pad(v) * D)

	def average(self, a, b, D):
		"""Simulate averaging two vectors quantized with the shared rotation ``D``.

		Returns:
			(avg, diff): the mean of the two quantized vectors, and the inf-norm
			of their difference in the rotated basis.

		Raises:
			ValueError: if ``2 * ||H D (a - b)||_inf >= (q - 1) * side``, i.e.
				mod-q decoding would fail for the current ``side``.  Raising
				(instead of exiting the process with status 0) keeps the failure visible.
		"""
		g0, g1 = np.array(a), np.array(b)
		# Quantize before checking, so the random stream is consumed in the same order either way.
		qg0, qg1 = self.compress(g0, D), self.compress(g1, D)

		if 2 * LA.norm(self.HD(g0 - g1, D), np.inf) >= (self.q - 1) * self.side:
			raise ValueError("Decode Error RLQSGD, side is wrong")

		avg = (qg0 + qg1) / 2
		diff = LA.norm(self.HD(qg0 - qg1, D), np.inf)
		return avg, diff

	def set_side(self, side):
		self.side = side

	
''' vQSGD using cross polytope scheme '''
class vQSGD():
	"""Cross-polytope vector quantizer (vQSGD) averaged over repeated samples.

	The normalized gradient ``x = a / ||a||_2`` lies in the unit ball, so it is
	a convex combination of the 2d cross-polytope vertices ``+/-sqrt(d) e_i``
	(plus the centre, whose weight ``gamma`` is spread evenly over all
	vertices).  Each sample picks one vertex with those weights.
	"""

	def __init__(self, repetition, dimension):
		self.name = "vQSGD"
		self.repetition = repetition  # number of independent vertex samples averaged per call
		self.d = dimension  # expected length of the gradients passed to compress

	def compress(self, a):
		"""Return an unbiased estimate of ``a`` built from cross-polytope vertices.

		Each of the ``repetition`` samples contributes ``+/-sqrt(d) * ||a||_2``
		to a single coordinate, and the samples are averaged.

		Args:
			a: 1-D array-like of length ``self.d``.

		Returns:
			Float array of length ``self.d``.

		Raises:
			ValueError: if ``self.repetition < 1`` (the average would be NaN).
		"""
		if self.repetition < 1:
			raise ValueError("vQSGD needs repetition >= 1")
		grad = np.array(a, dtype=float)
		d = self.d

		B = LA.norm(grad)
		if B == 0:
			# Zero gradient: x = grad / B would be NaN and np.random.choice would reject the weights.
			return np.zeros(d)
		x = grad / B  # project onto the unit sphere
		root_d = np.sqrt(d)

		# ||x||_2 = 1 implies ||x||_1 <= sqrt(d), so gamma >= 0 mathematically.  Clamp the
		# tiny negative values rounding can produce, because np.random.choice rejects them.
		gamma = max(1 - (LA.norm(x, 1) / root_d), 0.0)
		base = gamma / (2 * d)

		# The first d weights belong to +sqrt(d) e_i, the last d to -sqrt(d) e_i.
		prob = np.concatenate((
			np.where(x >= 0, x, 0) / root_d + base,
			np.where(x < 0, -x, 0) / root_d + base,
		))

		# Draw all samples at once; bincount gives how often each vertex was chosen.
		index = np.random.choice(2 * d, size=self.repetition, p=prob)
		counts = np.bincount(index, minlength=2 * d)
		Qgrad = (counts[:d] - counts[d:]) * (root_d * B) / self.repetition
		return Qgrad