''' In this experiment, we look at the convergence of various quantization methods and compare it with that of the full gradient (dashed line). We
use 2 machines and give each of them a batch of size n/2.
'''

import argparse
import os
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from quantization import QSGD, RLQSGD, LQSGD, HadamardQuantizer 

def dme(quantizer, g0, g1):
	''' Distributed mean estimate for exactly two machines.

	Each gradient is passed through `quantizer.compress` (g0 first, then g1)
	and the two compressed results are averaged.
	'''
	compressed_first = quantizer.compress(g0)
	compressed_second = quantizer.compress(g1)
	return (compressed_first + compressed_second)/2

def grad(A,w,b):
	''' Gradient of the mean squared residual ||A w - b||^2 / n with respect to w.

	A : (n, d) data matrix; n is taken from A.shape[0].
	w : weights, shaped so that A @ w matches b.
	b : targets.

	Returns 2 * A^T (A w - b) / n, with the same shape as w.
	'''
	n = A.shape[0]
	# Form the residual first: A.T @ A @ w evaluates left to right and would
	# build the d x d Gram matrix (O(n d^2)) on every call; this is O(n d).
	residual = A @ w - b
	return 2*(A.T @ residual)/n

def _local_gradients(A, w, b, lr, batch_size, rep, need_full):
	''' Sum the gradients of `rep` local steps for both workers.

	Every repetition draws a fresh shuffle of the row indices; worker 0 uses the
	first `batch_size` rows and worker 1 uses the remaining rows. Each worker
	steps its own copy of `w` with learning rate `lr` and accumulates the
	gradients it saw. The same is done on the full data set only when
	`need_full` is true; otherwise zeros are returned in its place.

	`w` itself is not modified. Returns the flattened sums (g0, g1, g).
	'''
	n, d = A.shape
	g0, g1, g = np.zeros((d,1)), np.zeros((d,1)), np.zeros((d,1))
	w0, w1, wg = np.copy(w), np.copy(w), np.copy(w)

	for _ in range(rep):
		# to shuffle and get the batches
		indices = np.arange(n)
		np.random.shuffle(indices)
		first, second = indices[0:batch_size], indices[batch_size:]

		t0 = grad(A[first],w0,b[first])
		g0 += t0
		w0 -= lr*t0.reshape(d,1)

		t1 = grad(A[second],w1,b[second])
		g1 += t1
		w1 -= lr*t1.reshape(d,1)

		# grad() is deterministic, so skipping it does not disturb the random stream
		if need_full:
			tg = grad(A,wg,b)
			g += tg
			wg -= lr*tg.reshape(d,1)

	return g0.flatten(), g1.flatten(), g.flatten()


def experiment(n,N,d,lr,qlevel,batch_size,alpha,iterations,seed,mode,rep):
	''' Run local SGD on a random least-squares problem once per quantizer.

	A random (n, d) matrix A and a random w_star are drawn after seeding NumPy
	with `seed`, and b = A @ w_star. For "full_gradient", QSGD, LQSGD, Hadamard
	and RLQSGD (in that order) the weights start at zero and `iterations`
	outer steps are taken, each built from `rep` local steps per worker.

	n          : number of data points
	N          : number of workers (only printed; the code is written for two)
	d          : dimension
	lr         : learning rate for local and outer steps
	qlevel     : quantization level passed to every quantizer
	batch_size : rows given to worker 0; worker 1 gets the rest
	alpha      : scaling used when updating the LQSGD/RLQSGD side
	iterations : number of outer iterations
	seed       : seed for np.random
	mode       : accepted but not used here
	rep        : number of local steps summed per outer iteration

	Returns (dist, cost, quant_error): dicts keyed by method name, each holding
	one value per iteration (||w - w_star||, ||A w - b||^2 / n, and
	||gradient - (g0 + g1)/2|| respectively), recorded before the step.
	'''
	# seeds
	np.random.seed(seed)

	# print to stdout
	print("-----")
	print("0. seed = ", seed)
	print("1. data points(n) = {}, d = {}, workers(N) = {}, batch_size = {}".format(n,d,N,batch_size))
	print("2. qlevel = ", qlevel)
	print("3. alpha = ", alpha)
	print("-----")

	# generating the random dataset
	A = np.random.randn(n,d).astype(np.float64)
	w_star = np.random.randn(d,1).astype(np.float64)
	b = A @ w_star

	# quantizers for variance calculations
	rlqsgd = RLQSGD(dimension = d, qlevel = qlevel)
	lqsgd = LQSGD(dimension = d, qlevel = qlevel)
	qsgd = QSGD(k = qlevel)
	hadamard = HadamardQuantizer(k = qlevel)

	# arrays for plotting
	methods = ("RLQSGD", "LQSGD", "QSGD", "full_gradient", "hadamard")
	quant_error = {method: [] for method in methods}
	cost = {method: [] for method in methods}
	dist = {method: [] for method in methods}

	for quantizer in ["full_gradient", qsgd, lqsgd, hadamard, rlqsgd]:
		# the only string entry is the unquantized baseline
		is_full = isinstance(quantizer, str)
		is_rlqsgd = quantizer is rlqsgd
		is_lqsgd = quantizer is lqsgd
		name = quantizer if is_full else quantizer.name
		print(name)

		# initialize weights
		w = np.zeros((d,1)).astype(np.float64)

		for it in range(iterations):
			# to not unknowingly make it unfair, the first step for qsgd and hadamard
			# is taken using the full gradient; lqsgd/rlqsgd never read it
			use_full = (is_full or it == 0) and not (is_rlqsgd or is_lqsgd)
			g0, g1, g = _local_gradients(A, w, b, lr, batch_size, rep, use_full)

			# random signing for the rlqsgd; drawn for every quantizer so the
			# sequence of random numbers is the same regardless of the method
			D = np.sign(np.random.rand(len(rlqsgd.pad(g0))) - 0.5)

			# setting side in the first iteration alone
			if it == 0:
				if is_rlqsgd:
					quantizer.set_side(2.1 * LA.norm(rlqsgd.HD(g0-g1,D),np.inf)/(qlevel-1))
				elif is_lqsgd:
					quantizer.set_side(2.1 * LA.norm(g0-g1,np.inf)/(qlevel-1))

			# getting gradient, loss and ||w-w_star||
			if is_rlqsgd:
				gradient, diff = quantizer.average(g0,g1,D)
			elif is_lqsgd:
				gradient, diff = quantizer.average(g0,g1)
			elif use_full:
				gradient = g
			else:
				gradient = dme(quantizer,g0,g1)

			qe = LA.norm(gradient-(g0+g1)/2) # quantization error
			loss = LA.norm(A@w - b)**2/n
			quant_error[name].append(qe)
			cost[name].append(loss)
			dist[name].append(np.linalg.norm(w-w_star))

			# using quantized gradients to set side for the next iteration
			if is_rlqsgd or is_lqsgd:
				quantizer.set_side(alpha * diff/(qlevel-1)) # the estimated y is alpha/2 * ||Qg0-Qg1||_inf, from which we get side = 2y/q-1

			# take the step using reshaped gradient
			w -= lr*gradient.reshape(d,1)

			# print stats
			if it % 10 == 0:
				print("iteration = {}, loss = {}".format(it,loss))

	return dist, cost, quant_error


def _plot_metric(series, iterations, curves, ylabel, title, save, outdir, filename):
	''' Draw one metric-vs-iteration figure, then save it or show it.

	series     : dict mapping method name to a per-iteration array
	iterations : number of points on the x axis
	curves     : list of (method name, legend label, extra plt.plot kwargs)
	save       : when true, write `filename` inside `outdir` (created if missing);
	             otherwise open an interactive window
	'''
	xs = list(range(iterations))
	for field, label, style in curves:
		plt.plot(xs, series[field], label=label, **style)
	plt.xlabel('iteration')
	plt.ylabel(ylabel)
	plt.title(title, fontsize=15)
	plt.legend()
	if save:
		# exist_ok avoids the check-then-create race of isdir() + makedirs()
		os.makedirs(outdir, exist_ok=True)
		plt.savefig(os.path.join(outdir, filename))
	else:
		plt.show()
	plt.close()


def main():
	''' Parse the command line, average the experiment over the seeds and plot.

	Runs `experiment` once per seed (seeds are 0, 10, 20, ...), averages the
	per-iteration curves, optionally applies log10, and produces two figures:
	loss vs iteration and quantization error vs iteration. LQSGD is computed
	but not plotted; the full-gradient curve is omitted from the error plot.
	The `--l2` flag is parsed but not used here.
	'''
	parser = argparse.ArgumentParser(description='Superlinear Convergence')
	parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
	parser.add_argument('--qlevel', type=int, default=8, metavar='QLEVEL', help='qlevel (default: 8)')
	parser.add_argument('--iterations', type=int, default=20, metavar='ITER', help='iterations (default: 20)')
	parser.add_argument('--rep', type=int, default=10, metavar='REP', help='interval of taking sum (default: 10)')
	parser.add_argument('--nseeds', type=int, default=1, metavar='NSEEDS', help='number of seeds (default: 1)')
	parser.add_argument('--alpha', type=float, default=1.5, metavar='ALPHA', help='alpha (default: 1.5)')
	parser.add_argument('--n', type=int, default=2**13, metavar='n', help='number of data points (default: 2^13 = 8192)')
	parser.add_argument('--d', type=int, default=100, metavar='d', help='dimension (default: 100)')
	parser.add_argument('--outdir', type=str, default='out', metavar='OUTDIR', help='directory to save pics (default: out)')
	parser.add_argument('--save', action='store_true')
	parser.add_argument('--log-scale', action='store_true')
	parser.add_argument('--l2', action='store_true')
	parser.add_argument('--mode', type=str, default="quantized", metavar='MODE', help='mode: central or qunatized')
	args = parser.parse_args()

	# with no seeds the averages below would be 0/0 and every plot would be NaN
	if args.nseeds < 1:
		parser.error('--nseeds must be at least 1')

	# parameters for experiment
	n = args.n # number of data points
	d = args.d # dimension
	N = 2 # number of worker nodes
	batch_size = int(n/2) # batches you want to use for the compare
	iterations = args.iterations
	lr = args.lr
	mode = args.mode

	# parameters for quantizations
	qlevel = args.qlevel
	alpha = args.alpha

	# repeating the experiment for args.nseeds seeds
	fields = ["RLQSGD","LQSGD","QSGD","full_gradient","hadamard"]
	seeds = [10*i for i in range(args.nseeds)]

	dist_avg = {field: np.zeros(iterations) for field in fields}
	cost_avg = {field: np.zeros(iterations) for field in fields}
	quant_error_avg = {field: np.zeros(iterations) for field in fields}

	for seed in seeds:
		dist, cost, quant_error = experiment(n,N,d,lr,qlevel,batch_size,alpha,iterations,seed,mode,args.rep)
		for field in fields:
			dist_avg[field] += np.array(dist[field])
			cost_avg[field] += np.array(cost[field])
			quant_error_avg[field] += np.array(quant_error[field])

	for field in fields:
		dist_avg[field] /= len(seeds)
		cost_avg[field] /= len(seeds)
		quant_error_avg[field] /= len(seeds)

	# if scale apply corresponding transformation
	if args.log_scale:
		for field in fields:
			dist_avg[field] = np.log10(dist_avg[field])
			cost_avg[field] = np.log10(cost_avg[field])
			quant_error_avg[field] = np.log10(quant_error_avg[field])

	''' Plots '''
	file_suffix = '_log' if args.log_scale else ''
	title_values = (n,N,d,lr,batch_size,qlevel,args.rep)
	quantized_curves = [
		("RLQSGD", 'RLQSGD (cubic)', {}),
		("QSGD", 'QSGD', {}),
		("hadamard", 'Hadamard', {}),
	]

	# (1) loss vs iteration
	_plot_metric(
		cost_avg,
		iterations,
		quantized_curves + [("full_gradient", 'GD', {'linestyle': 'dashed'})],
		'log10(loss)' if args.log_scale else 'loss',
		'Local SGD Convergence: S = {}, n = {}\nd = {}, lr = {}, batch = {}, q = {}, rep = {}'.format(*title_values),
		args.save,
		args.outdir,
		'convergence_S_{}_d_{}{}.pdf'.format(args.n,args.d,file_suffix),
	)

	# (2) quantization error vs iteration
	_plot_metric(
		quant_error_avg,
		iterations,
		quantized_curves,
		'log10(quantization error)' if args.log_scale else 'quantization error',
		'Local SGD quantization error: S = {}, n = {}\nd = {}, lr = {}, batch = {}, q = {}, rep = {}'.format(*title_values),
		args.save,
		args.outdir,
		'qe_S_{}_d_{}{}.pdf'.format(args.n,args.d,file_suffix),
	)
	

# Run the experiment only when this file is executed as a script, not when imported.
if __name__ == '__main__':
	main()
