''' In this experiment, we look at the convergence of various quantization methods and compare it with that of full gradient descent.
(a) We use N machines and give each of them a batch of size n/N.
(b) When there are N > 2 machines, one of them, chosen at random as the leader, acts as the collector of the quantized gradients and sends the aggregate back, including an estimate of y. '''

import argparse
import os
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from quantization import QSGD, LQSGD, RLQSGD, HadamardQuantizer 
from libsvm.svmutil import * 


def dme(quantizer, gb):
	''' Distributed mean estimation over quantized batch gradients.

	quantizer: object exposing compress(vector) and returning a vector of the same length.
	gb: 2-D numpy array with one batch gradient per row (one row per worker).

	Every row is compressed, the compressed rows are averaged, and, when there are
	more than two workers, the average itself is compressed once more (the extra
	quantization models the leader sending the aggregate back).
	Returns the resulting length-d estimate. '''
	num_workers, dim = gb.shape
	mean_estimate = np.zeros(dim)
	# accumulate row by row so the summation order is fixed
	for worker_gradient in gb:
		mean_estimate += quantizer.compress(worker_gradient)
	mean_estimate /= num_workers
	if num_workers <= 2:
		# with at most two workers no final broadcast is quantized
		return mean_estimate
	return quantizer.compress(mean_estimate)

def dmeRLQ(quantizer, gb, D):
	''' Distributed mean estimation for quantizers whose compress takes an extra argument.

	quantizer: object exposing compress(vector, D).
	gb: 2-D numpy array with one batch gradient per row (one row per worker).
	D: forwarded unchanged to every compress call (callers in this file pass a random sign vector).

	Every row is compressed, the compressed rows are averaged, and, when there are
	more than two workers, the average is compressed once more with the same D.
	Returns the resulting length-d estimate. '''
	num_workers, dim = gb.shape
	mean_estimate = np.zeros(dim)
	# accumulate row by row so the summation order is fixed
	for worker_gradient in gb:
		mean_estimate += quantizer.compress(worker_gradient, D)
	mean_estimate /= num_workers
	if num_workers <= 2:
		# with at most two workers no final broadcast is quantized
		return mean_estimate
	return quantizer.compress(mean_estimate, D)

def numpy_dataset(y, x): # return the dataset in numpy array form
	''' Convert a LIBSVM-style problem into dense float64 numpy arrays.

	y: sequence of n numeric labels.
	x: sequence of n dicts mapping feature index -> value; features that are
	   absent from a row are treated as zero (the usual sparse convention).

	Returns (A, b): A has shape (n, d) and b has shape (n, 1). The columns of A
	correspond to the feature indices that occur anywhere in x, in ascending
	order, so d is the number of distinct indices seen. For fully dense rows
	with keys in ascending order this is the same matrix as copying each row's
	values positionally. An empty dataset yields A of shape (0, 0) and b of shape (0, 1). '''
	n = len(x)
	# place values by feature index rather than by position in the dict, so
	# rows that omit zero-valued features stay aligned with the others
	feature_ids = sorted({idx for row in x for idx in row})
	column_of = {idx: col for col, idx in enumerate(feature_ids)}
	A = np.zeros((n, len(feature_ids)), dtype=np.float64)
	for i, row in enumerate(x):
		for idx, value in row.items():
			A[i, column_of[idx]] = value
	b = np.array(y, dtype=np.float64).reshape(n, 1)
	return A, b

def grad(A,w,b):
	''' Gradient with respect to w of the mean squared error ||A w - b||^2 / n.

	A: (n, d) data matrix, w: weight vector, b: targets; n is the number of rows of A.
	Returns 2 * (A^T A w - A^T b) / n. '''
	num_rows = A.shape[0]
	A_t = A.T
	gram = A_t @ A
	normal_residual = gram @ w - A_t @ b
	return 2 * normal_residual / num_rows

def max_diff(gb):
	''' Largest absolute coordinate-wise difference between any two rows of gb.

	gb: 2-D numpy array with one vector per row.

	For each coordinate the largest pairwise gap is (column max - column min),
	so the result is the maximum of that range over all columns; this gives the
	same value as comparing every pair of rows but in a single pass.
	Returns 0 when gb has no rows or no columns. If gb contains NaN the result
	is NaN (rather than being silently reported as 0), so comparisons against a
	bound such as max_diff(gb) < y evaluate to False. '''
	N, d = gb.shape
	if N == 0 or d == 0:
		return 0
	column_range = np.max(gb, axis=0) - np.min(gb, axis=0)
	return np.max(column_range)

def experiment(A,b,N,n,d,lr,qlevel,batch_size,alpha,alpha_r,iterations,seed):
	''' Run gradient descent on the least-squares problem once per method and record convergence.

	The methods are, in this order: RLQSGD, full gradient, LQSGD, Hadamard, QSGD.
	Each method starts from the same initial weight and runs `iterations` steps.

	A, b: data matrix and targets (main passes the (n, d) and (n, 1) arrays from numpy_dataset).
	N: number of workers; n, d: number of data points and features.
	lr: step size; qlevel: quantization level passed to every quantizer.
	batch_size: number of rows each worker uses per iteration.
	alpha, alpha_r: multipliers used to update the LQSGD bound y and the RLQSGD bound yr.
	iterations: number of descent steps per method; seed: value given to np.random.seed.

	Returns (dist, cost): two dicts keyed by method name, each holding one value
	per iteration -- ||w - w_star|| and the loss ||A w - b||^2 / n, both measured
	before the step of that iteration is applied.

	Raises AssertionError if, for LQSGD or RLQSGD, the spread of the batch
	gradients is not strictly below the current bound (y or yr). '''
	# seeds 
	np.random.seed(seed)
	
	# print to stdout
	print("-----")
	print("0. seed = ", seed)
	print("1. data points (S) = {}, d = {}, workers (n) = {}".format(n,d,N))
	print("2. qlevel = ", qlevel)
	print("3. alpha = ", alpha)
	print("4. alpha_r = ", alpha_r)
	print("-----")

	# optimal w_star: least-squares solution via the normal equations (explicit inverse of A^T A)
	w_star = LA.inv(A.T@A) @ A.T @ b
	
	# quantizers for variance calculations
	# NOTE(review): qsgdl2 is constructed but never used in the loop below
	rlqsgd = RLQSGD(dimension = d, qlevel = qlevel)
	lqsgd = LQSGD(dimension = d, qlevel = qlevel)
	qsgd = QSGD(k = qlevel)
	qsgdl2 = QSGD(k = qlevel, L2 = True)
	hadamard = HadamardQuantizer(k = qlevel)

	# arrays for plotting: one list per method, appended to once per iteration
	dist = {"RLQSGD":[], "LQSGD":[], "QSGD":[], "full_gradient":[], "hadamard":[]} 
	cost = {"RLQSGD":[], "LQSGD":[], "QSGD":[], "full_gradient":[], "hadamard":[]} 

	# for setting lattice parameters (yr, the RLQSGD counterpart, is first assigned at iteration 0)
	y = None
		
	# NOTE(review): the original remark here said "hardcode for 2 machines", but the code below uses N throughout
	for quantizer in [rlqsgd, "full_gradient", lqsgd, hadamard, qsgd]: # the plain string marks the unquantized baseline
		# presumably each quantizer's .name matches a key of dist/cost -- verify against the quantization module
		name = quantizer.name if type(quantizer) is not str else quantizer 
		print(name)
		
		# initialize weight: every method starts from the same point
		w = -1000*np.ones((d,1)).astype(np.float64)

		for it in range(iterations):
			# to shuffle and get the batches
			indices = np.arange(n)
			np.random.shuffle(indices)

			# random signing for the rlqsgd (drawn on every iteration for every method, so all methods consume the RNG alike)
			after_pad = len(rlqsgd.pad(np.zeros(d)))
			D = np.sign(np.random.rand(after_pad) - 0.5)
			
			# get a random disjoint set of batch gradients and flatten them
			# worker i uses shuffled rows [i*batch_size, (i+1)*batch_size); rows beyond N*batch_size are unused this iteration
			gb = np.zeros((N,d))
			hdgb = np.zeros((N,after_pad))
			for i in range(N):
				gb[i] = grad(A[indices[i*batch_size:(i+1)*batch_size]],w,b[indices[i*batch_size:(i+1)*batch_size]]).flatten()
				hdgb[i] = rlqsgd.HD(gb[i],D) 	

			# if 0th iteration, set y from the exact batch gradients with a 10% margin
			if(it == 0):
				if(quantizer == lqsgd):
					y = 1.1 * max_diff(gb)
				if(quantizer == rlqsgd):
					yr = 1.1 * max_diff(hdgb) # yr denotes differences between HDv_i and HDv_j
			
			# checking if successful decoding in the lattices: the spread must stay strictly below the bound
			if(quantizer == lqsgd):
				assert(max_diff(gb) < y)
			if(quantizer == rlqsgd):
				assert(max_diff(hdgb) < yr)

			# setting lattice parameter using current y 
			# NOTE(review): the denominator is zero for qlevel == 2 when N > 2, and for qlevel == 1 otherwise
			if(quantizer == lqsgd):
				if(N > 2):
					lqsgd.set_side(2*y/(qlevel-2))
				else:
					lqsgd.set_side(2*y/(qlevel-1)) # since the final broadcast is not needed
			if(quantizer == rlqsgd):
				if(N > 2):
					rlqsgd.set_side(2*yr/(qlevel-2))
				else:
					rlqsgd.set_side(2*yr/(qlevel-1)) # since the final broadcast is not needed

			# getting gradient, allow other methods to use full gradient in first step
			# (the lattice methods quantize from iteration 0; Hadamard and QSGD take the exact full gradient at iteration 0)
			if(quantizer == lqsgd): 
				gradient = dme(quantizer,gb)
			elif (quantizer == rlqsgd):
				gradient = dmeRLQ(quantizer,gb,D)
			else:
				gradient = (grad(A,w,b) if (quantizer == "full_gradient" or (it == 0)) else dme(quantizer,gb))
			
			# dynamically setting side for lqsgd using quantized gradients
			# the bound for the next iteration is alpha (resp. alpha_r) times the spread of the re-quantized batch gradients
			if(quantizer == lqsgd):
				Qg = np.zeros((N,d))
				for i in range(N):
					Qg[i] = lqsgd.compress(gb[i])
				y = alpha*max_diff(Qg)
			if(quantizer == rlqsgd):
				hdQg = np.zeros((N,after_pad))
				for i in range(N):
					hdQg[i] = rlqsgd.HD(rlqsgd.compress(gb[i],D),D)
				yr = alpha_r*max_diff(hdQg)
			
			# ||w-w_star||, loss (both recorded for the current w, i.e. before the update below)
			loss = LA.norm(A@w - b)**2/n
			dist[name].append(LA.norm(w-w_star))
			cost[name].append(loss)

			# step using reshaped gradient, avoid broadcasting errors
			w -= lr*gradient.reshape(d,1) 

			# print stats
			if(it % 100 == 0):
				print("iteration = {}, loss = {}".format(it,loss))

	return dist, cost

def main():
	''' Command-line entry point: run the experiment for each seed and plot the averaged loss.

	Parses the options, loads the LIBSVM-format dataset, runs experiment() once
	per seed (seeds are 0, 10, 20, ...), averages the per-iteration distance and
	loss over the seeds, optionally takes log10, and plots loss against iteration
	for LQSGD, QSGD, Hadamard and full gradient descent (last 300 iterations at
	most). With --save the figure is written under out/, otherwise it is shown.

	Exits with a usage error if --workers or --nseeds is below 1, or if there
	are more workers than data points (each worker would get an empty batch). '''
	parser = argparse.ArgumentParser(description='Superlinear Convergence for general number of workers')
	parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
	parser.add_argument('--qlevel', type=int, default=8, metavar='QLEVEL', help='qlevel (default: 8)')
	parser.add_argument('--iterations', type=int, default=60, metavar='ITER', help='iterations (default: 60)')
	parser.add_argument('--dataset', type=str, default='datasets/cpusmall_scale', metavar='DATASET', help='dataset (default: cpusmall_scale)')
	parser.add_argument('--workers', type=int, default=2, metavar='WORKERS', help='workers (default: 2)')
	parser.add_argument('--alpha', type=float, default=3, metavar='ALPHA', help='alpha (default: 3)')
	parser.add_argument('--alpha_r', type=float, default=4, metavar='ALPHA_R', help='alpha_r (default: 4)')
	parser.add_argument('--nseeds', type=int, default=1, metavar='NSEEDS', help='number of seeds (default: 1)')
	parser.add_argument('--save', action='store_true')
	parser.add_argument('--log-scale', action='store_true')
	args = parser.parse_args()

	# both values are used as divisors below, so reject them up front with a clear message
	if args.workers < 1:
		parser.error('--workers must be at least 1')
	if args.nseeds < 1:
		parser.error('--nseeds must be at least 1')

	# dataset
	y, x = svm_read_problem(args.dataset)
	A, b = numpy_dataset(y,x)
	# A = (A - np.mean(A,0))/np.std(A,0) # normalization

	# parameters for experiment
	n, d = A.shape
	N = args.workers # number of worker nodes
	batch_size = n // N # rows per worker; leftover rows are not assigned to any worker
	if batch_size < 1:
		parser.error('--workers ({}) exceeds the number of data points ({})'.format(N, n))
	iterations = args.iterations
	lr = args.lr

	# parameters for quantizations
	qlevel = args.qlevel
	alpha = args.alpha
	alpha_r = args.alpha_r

	# repeating the experiment for each seed and accumulating the curves
	fields = ["RLQSGD","LQSGD","QSGD","full_gradient","hadamard"]
	seeds = [10*i for i in range(args.nseeds)]
	dist_avg = {field: np.zeros(iterations) for field in fields}
	cost_avg = {field: np.zeros(iterations) for field in fields}

	for seed in seeds:
		dist, cost = experiment(A,b,N,n,d,lr,qlevel,batch_size,alpha,alpha_r,iterations,seed)
		for field in fields:
			dist_avg[field] += np.array(dist[field])
			cost_avg[field] += np.array(cost[field])

	for field in fields:
		dist_avg[field] /= len(seeds)
		cost_avg[field] /= len(seeds)

	if args.log_scale:
		for field in fields:
			dist_avg[field] = np.log10(dist_avg[field])
			cost_avg[field] = np.log10(cost_avg[field])

	''' Plots '''
	iteration = list(range(iterations))
	# (1) loss vs iteration, restricted to the last 300 iterations
	start_iter, end_iter = max(0,iterations-300), iterations
	window = slice(start_iter, end_iter)
	# RLQSGD is averaged above but its curve is not drawn; to plot it add
	# ("RLQSGD", 'RLQSGD (cubic)', {}) at the front of this list.
	curves = [
		("LQSGD", 'LQSGD (cubic)', {}),
		("QSGD", 'QSGD', {}),
		("hadamard", 'Hadamard', {}),
		("full_gradient", 'GD', {'linestyle': 'dashed'}),
	]
	for field, label, line_style in curves:
		plt.plot(iteration[window], cost_avg[field][window], label=label, **line_style)
	plt.xlabel('iteration')
	# raw string: the backslash belongs to the LaTeX command, not to a Python escape
	plt.ylabel(r'$\log_{10}$(loss)' if args.log_scale else 'loss')
	plt.title('Real dataset Regression: S = {}, n = {}, d = {}\nlr = {}, batch = {}, qlevel = {}'.format(n,N,d,lr,batch_size,qlevel), fontsize=15)
	plt.legend()
	if args.save:
		os.makedirs('out', exist_ok=True)
		plt.savefig('out/convergence_lr_{}_n_{}_q_{}_alpha_{}.pdf'.format(args.lr,args.workers,args.qlevel,args.alpha))  
	else:
		plt.show()
	plt.close()
	

# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
	main()
