''' In this experiment, we look at the convergence of various quantization methods and compare it with that of full gradient descent (dashed line).
We use 2 machines and give each of them a batch of size n/2.

Rotated LQSGD (RLQSGD) is the main focus of this file.
'''

import argparse
import os
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from quantization import QSGD, LQSGD, HadamardQuantizer, RLQSGD 

def dme(quantizer, g0, g1): # only two machines now
	'''Distributed mean estimation for two machines.

	Each machine's gradient is compressed independently with ``quantizer`` (g0 first,
	then g1), and the two compressed vectors are averaged.
	'''
	first = quantizer.compress(g0)
	second = quantizer.compress(g1)
	total = first + second
	return total / 2

def grad(A,w,b):
	'''Gradient of the mean squared error ``||A @ w - b||^2 / n`` with respect to ``w``.

	``n`` is the number of rows of ``A``. As used in this file, ``A`` is (n, d), ``w`` is
	(d, 1) and ``b`` is (n, 1), so the result is (d, 1).

	Uses ``A.T @ (A @ w - b)``, which is mathematically identical to
	``A.T @ A @ w - A.T @ b``. The latter evaluates ``(A.T @ A) @ w`` and so forms a
	d x d Gram matrix on every call (O(n*d^2)); this form is O(n*d).
	'''
	n = A.shape[0]
	residual = A @ w - b
	return 2 * (A.T @ residual) / n

def experiment(n,N,d,lr,qlevel,batch_size,alpha,iterations,seed,mode):
	'''Run two-machine gradient descent on a synthetic least-squares problem, once per method.

	A random noiseless regression problem is generated (A ~ N(0,1) of shape (n, d),
	w_star ~ N(0,1) of shape (d, 1), b = A @ w_star). Gradient descent is then run from
	w = 0 for each method in turn: full gradient, RLQSGD, QSGD, Hadamard and LQSGD. Every
	iteration shuffles the rows and splits them into two disjoint batches: the first
	``batch_size`` shuffled rows and all remaining rows. The two batch gradients are then
	combined by the method under test.

	The global numpy RNG is seeded once at the top and is NOT re-seeded between methods,
	so each method sees a different sequence of shuffles and random signs.

	Args:
		n: number of data points (rows of A).
		N: number of workers; only printed here. The loop always uses exactly two batches.
		d: dimension of the weight vector.
		lr: step size.
		qlevel: quantization level handed to every quantizer.
		batch_size: size of the first batch; the second batch gets the remaining rows.
		alpha: scale factor for the side length that LQSGD/RLQSGD re-estimate every iteration.
		iterations: number of gradient steps per method.
		seed: seed for ``np.random``.
		mode: unused in this function.

	Returns:
		(dist, cost): dicts keyed by method name. Each value is a list with one entry per
		iteration: ||w - w_star|| (dist) and ||A w - b||^2 / n (cost). Both are measured at
		the current w, *before* that iteration's update.
	'''
	# seed the global numpy RNG once for the whole experiment
	np.random.seed(seed)
	
	# print to stdout
	print("-----")
	print("0. seed = ", seed)
	print("1. data points(n) = {}, d = {}, workers(N) = {}, batch_size = {}".format(n,d,N,batch_size))
	print("2. qlevel = ", qlevel)
	print("3. alpha = ", alpha)
	print("-----")

	# generating the random dataset (noiseless: b lies exactly in the range of A)
	A = np.random.randn(n,d).astype(np.float64)
	w_star = np.random.randn(d,1).astype(np.float64)
	b = A @ w_star 
	
	# quantizers under comparison
	qsgd = QSGD(k = qlevel)
	hadamard = HadamardQuantizer(k = qlevel)
	rlqsgd = RLQSGD(dimension = d, qlevel = qlevel)
	lqsgd = LQSGD(dimension = d, qlevel = qlevel)
	
	# per-method curves for plotting
	# NOTE(review): assumes each quantizer's .name is one of these keys
	# ("LQSGD", "RLQSGD", "QSGD", "hadamard"); TODO confirm in quantization.py
	cost = {"LQSGD":[], "RLQSGD":[], "QSGD":[], "full_gradient":[], "hadamard":[]}
	dist = {"LQSGD":[], "RLQSGD":[], "QSGD":[], "full_gradient":[], "hadamard":[]}
	
	for quantizer in ["full_gradient", rlqsgd, qsgd, hadamard, lqsgd]: 
		name = quantizer.name if type(quantizer) is not str else quantizer 
		print(name)
		
		# initialize weights at the origin for every method
		w = np.zeros((d,1)).astype(np.float64)

		for it in range(iterations):
			# to shuffle and get the batches
			indices = np.arange(n)
			np.random.shuffle(indices)
			
			# get a random disjoint set of batch gradients and flatten them
			g0,g1 = grad(A[indices[0:batch_size]],w,b[indices[0:batch_size]]),grad(A[indices[batch_size:]],w,b[indices[batch_size:]])
			g0,g1 = g0.flatten(),g1.flatten()

			# random sign vector consumed by rlqsgd.HD (presumably the diagonal of the randomized
			# rotation; confirm in quantization.py). It is drawn on every iteration for EVERY method,
			# so it also advances the RNG during the non-RLQSGD runs. Moving it into the RLQSGD
			# branch would change every other method's shuffles.
			D = np.sign(np.random.rand(len(rlqsgd.pad(g0))) - 0.5)

			# first iteration only: set the side from the true gradient difference,
			# 2.1 * ||g0 - g1||_inf / (qlevel - 1) (for RLQSGD measured after the HD transform)
			if(quantizer == rlqsgd and it == 0):
				quantizer.set_side(2.1 * LA.norm(rlqsgd.HD(g0-g1,D),np.inf)/(qlevel-1))
			if(quantizer == lqsgd and it == 0):
				quantizer.set_side(2.1 * LA.norm(g0-g1,np.inf)/(qlevel-1))

			# getting gradient, loss and ||w-w_star||
			# average() returns (averaged gradient, diff); per the set_side note below, diff is
			# presumably ||Q(g0) - Q(g1)||_inf (TODO confirm in quantization.py)
			if(quantizer == rlqsgd):
				gradient, diff = quantizer.average(g0,g1,D) 
			elif(quantizer == lqsgd): 
				gradient, diff = quantizer.average(g0,g1)
			else: # so as not to unknowingly make the comparison unfair, the first step of qsgd and hadamard uses the full gradient
				gradient = (grad(A,w,b) if (quantizer == "full_gradient" or it == 0) else dme(quantizer,g0,g1))
			# metrics are recorded at the current w, before this iteration's update
			loss = LA.norm(A@w - b)**2/n
			cost[name].append(loss)
			dist[name].append(np.linalg.norm(w-w_star))

			# using quantized gradients to set side for the next iteration
			if(quantizer == rlqsgd or quantizer == lqsgd):
				quantizer.set_side(alpha * diff/(qlevel-1)) # the estimated y is alpha/2 * ||Qg0-Qg1||_inf, from which we get side = 2y/(q-1)
			
			# take the step using reshaped gradient (quantizers work on flat vectors; w is (d,1))
			w -= lr*gradient.reshape(d,1) 

			# print stats
			if(it % 10 == 0):
				print("iteration = {}, loss = {}".format(it,loss))

	# assert(wrong == 0)
	return dist, cost

def main():
	'''Parse CLI arguments, run ``experiment`` once per seed, and plot the seed-averaged loss curves.

	Seeds are 0, 10, 20, ... (``--nseeds`` of them). The plot shows at most the last 30
	iterations. It is displayed interactively unless ``--save`` is given, in which case it is
	written to ``out/convergence_S_<n>_d_<d>[_log].pdf``.
	'''
	parser = argparse.ArgumentParser(description='Superlinear Convergence')
	parser.add_argument('--lr', type=float, default=0.8, metavar='LR', help='learning rate (default: 0.8)')
	parser.add_argument('--qlevel', type=int, default=8, metavar='QLEVEL', help='qlevel (default: 8)')
	parser.add_argument('--iterations', type=int, default=40, metavar='ITER', help='iterations (default: 40)')
	parser.add_argument('--nseeds', type=int, default=1, metavar='NSEEDS', help='number of seeds (default: 1)')
	parser.add_argument('--alpha', type=float, default=3.5, metavar='ALPHA', help='alpha (default: 3.5)')
	parser.add_argument('--n', type=int, default=8192, metavar='n', help='number of data points (default: 8192)')
	parser.add_argument('--d', type=int, default=100, metavar='d', help='dimension (default: 100)')
	parser.add_argument('--save', action='store_true')
	parser.add_argument('--log-scale', action='store_true')
	parser.add_argument('--l2', action='store_true')  # NOTE(review): parsed but not used below
	parser.add_argument('--mode', type=str, default="quantized", metavar='MODE', help='mode: central or quantized')
	args = parser.parse_args()
	if args.nseeds < 1:
		# averaging over zero seeds would divide by zero and silently plot all-NaN curves
		parser.error('--nseeds must be at least 1')
	
	# parameters for experiment
	n = args.n # number of data points
	d = args.d # dimension
	N = 2 # number of worker nodes
	batch_size = n // 2 # each of the two workers gets half of the data
	iterations = args.iterations
	lr = args.lr
	mode = args.mode
	
	# parameters for quantizations
	qlevel = args.qlevel
	alpha = args.alpha
	
	# repeat the experiment for --nseeds seeds and average the curves
	fields = ["LQSGD","RLQSGD","QSGD","full_gradient","hadamard"]
	seeds = [10*i for i in range(args.nseeds)]
	
	dist_avg = {field: np.zeros(iterations) for field in fields}
	cost_avg = {field: np.zeros(iterations) for field in fields}
	for seed in seeds:
		dist, cost = experiment(n,N,d,lr,qlevel,batch_size,alpha,iterations,seed,mode)
		for field in fields:
			dist_avg[field] += np.array(dist[field])
			cost_avg[field] += np.array(cost[field])
	for field in fields:
		dist_avg[field] /= len(seeds)
		cost_avg[field] /= len(seeds)

	# log scale is applied after averaging (log of the mean, not mean of the logs)
	if args.log_scale:
		for field in fields:
			dist_avg[field] = np.log2(dist_avg[field])
			cost_avg[field] = np.log2(cost_avg[field])
	
	''' Plots '''
	# (1) loss (or log2(loss)) vs iteration, restricted to the last 30 iterations
	iteration = list(range(iterations))
	window = slice(max(0, iterations-30), iterations)
	curves = [
		("LQSGD", 'LQSGD (cubic)', {}),
		("RLQSGD", 'RLQSGD (cubic)', {}),
		("QSGD", 'QSGD', {}),
		("hadamard", 'Hadamard', {}),
		("full_gradient", 'GD', {'linestyle': 'dashed'}),
	]
	for field, label, style in curves:
		plt.plot(iteration[window], cost_avg[field][window], label=label, **style)
	plt.xlabel('iteration')
	plt.ylabel('log2(loss)' if args.log_scale else 'loss')
	# the title uses S for the number of data points and n for the number of workers
	plt.title('Regression convergence: S = {}, n = {}\nd = {}, lr = {}, batch = {}, qlevel = {}'.format(n,N,d,lr,batch_size,qlevel),fontsize=15)
	plt.legend()
	if args.save:
		os.makedirs('out', exist_ok=True)
		suffix = '_log' if args.log_scale else ''
		plt.savefig('out/convergence_S_{}_d_{}{}.pdf'.format(args.n, args.d, suffix))
	else:
		plt.show()
	plt.close()
	

# run the experiment only when executed as a script, not when imported
if __name__ == '__main__':
	main()
