''' In this experiment, we measure the variance of several quantized mean-estimation methods with respect
to the full gradient g. We use 2 machines and give each of them a batch of size n/2. The descent step is
shared by all methods and is taken along the full gradient.
'''

import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
import numpy.linalg as LA
from quantization import QSGD, LQSGD, RLQSGD, HadamardQuantizer   

def variance_calc(sample_vectors, expected_vector):
	''' Return the mean squared Euclidean distance between the samples and expected_vector.

	sample_vectors: sized collection of vectors
	expected_vector: vector each sample is compared against
	'''
	# one squared L2 distance per sample, accumulated in iteration order
	squared_errors = (LA.norm(expected_vector - sample)**2 for sample in sample_vectors)
	return sum(squared_errors)/len(sample_vectors)

def dme(quantizer, g0, g1): # only two machines now
	''' Compress g0 and then g1 with quantizer.compress and return the mean of the two results. '''
	compressed_first = quantizer.compress(g0)
	compressed_second = quantizer.compress(g1)
	return (compressed_first + compressed_second)/2

def grad(A,w,b):
	''' Return 2*(A^T A w - A^T b)/n with n = A.shape[0].

	This is the gradient in w of the mean squared residual ||Aw - b||^2 / n.
	'''
	num_rows = A.shape[0]
	At = A.T
	# same left-to-right product order as (A^T A) w so the floating point result is unchanged
	normal_residual = At@A@w - At@b
	return 2*normal_residual/num_rows


def random_batch_gradient(A,w,b,batch_size):
	''' Return the flattened gradient computed on batch_size randomly chosen rows of (A, b).

	The rows are picked by shuffling all row indices with numpy's global random state and
	keeping the first batch_size of them.
	'''
	order = np.arange(A.shape[0])
	np.random.shuffle(order)
	chosen = order[:batch_size]
	batch_gradient = grad(A[chosen],w,b[chosen])
	return batch_gradient.flatten()

def _split_batch_gradients(A,w,b,batch_size):
	''' Shuffle the row indices and return the flattened gradients of the two resulting batches.

	The first batch holds the first batch_size shuffled rows, the second batch the remaining rows.
	'''
	indices = np.arange(A.shape[0])
	np.random.shuffle(indices)
	first, second = indices[0:batch_size], indices[batch_size:]
	return grad(A[first],w,b[first]).flatten(), grad(A[second],w,b[second]).flatten()


def experiment(n,N,d,lr,qlevel,batch_size,alpha,iterations,seed,mode):
	''' Run full-gradient descent on a random least-squares problem and record estimator variances.

	Parameters:
		n: number of data points (rows of A)
		N: number of workers; only printed here, the data is always split into two batches
		d: dimension of the weight vector
		lr: step size of the full-gradient step
		qlevel: quantization level handed to every quantizer
		batch_size: number of rows in the first batch; the second batch gets the remaining rows
		alpha: multiplier used when the RLQSGD/LQSGD side is re-set after every iteration
		iterations: number of descent steps
		seed: seed for numpy's global random state
		mode: accepted for the caller's sake, not used in this function

	Returns:
		dict mapping "<method>_var_list" to a list with one entry per iteration; each entry is the
		mean of ||estimate - g||^2 over the repeats, where g is the full gradient
	'''
	# seeds
	np.random.seed(seed)

	# print to stdout
	print("-----")
	print("0. seed = ", seed)
	print("1. data points (n) = {}, d = {}, workers (N) = {}, batch_size = {}".format(n,d,N,batch_size))
	print("2. qlevel = ", qlevel)
	print("3. alpha = ", alpha)
	print("-----")

	# generating random dataset
	A = np.random.randn(n,d).astype(np.float64)
	w_star = np.random.randn(d,1).astype(np.float64)
	b = A @ w_star

	# quantizers for variance calculations
	rlqsgd = RLQSGD(dimension = d, qlevel = qlevel, side = 0)
	lqsgd = LQSGD(dimension = d, qlevel = qlevel, side = 0)
	qsgd = QSGD(k = qlevel)
	hadamard = HadamardQuantizer(k = qlevel)
	qsgdl2 = QSGD(k = qlevel, L2 = True)

	# one list per method for plotting; "input" is basically sigma i.e E[||g1-g||^2]
	methods = ["input","rlqsgd","lqsgd","qsgd","hadamard","qsgdl2"]
	result = {name + "_var_list": [] for name in methods}

	# number of fresh draws used to estimate each variance
	repeats = 5

	# initialize weights
	w = np.zeros((d,1)).astype(np.float64)

	''' Start of experiment '''
	for it in range(iterations):
		# get the full gradient and the two batch gradients, all flattened
		g = grad(A,w,b).flatten()
		g0,g1 = _split_batch_gradients(A,w,b,batch_size)

		# random signing for the rlqsgd
		D = np.sign(np.random.rand(len(rlqsgd.pad(g0))) - 0.5)

		# setting side in the first iteration alone
		if it == 0:
			rlqsgd.set_side(2.1 * LA.norm(rlqsgd.HD(g0-g1,D),np.inf)/(qlevel-1))
			lqsgd.set_side(2.1 * LA.norm(g0-g1,np.inf)/(qlevel-1))

		# Update: made everyone to actually calculate variance
		samples = {name: [] for name in methods}
		for _ in range(repeats):
			t0,t1 = _split_batch_gradients(A,w,b,batch_size)
			samples["input"].append(t0)
			samples["qsgd"].append(dme(qsgd,t0,t1))
			samples["hadamard"].append(dme(hadamard,t0,t1))
			samples["qsgdl2"].append(dme(qsgdl2,t0,t1))
			# NOTE(review): rlqsgd and lqsgd are fed the fixed pair g0,g1 in every repeat while the
			# other quantizers see the fresh pair t0,t1 -- confirm this asymmetry is intended
			# for rlqsgd
			avg, rlqsgd_diff = rlqsgd.average(g0,g1,D)
			samples["rlqsgd"].append(avg)
			# for lqsgd
			avg, lqsgd_diff = lqsgd.average(g0,g1)
			samples["lqsgd"].append(avg)

		for name in methods:
			result[name + "_var_list"].append(variance_calc(samples[name], g))

		# the estimated y is alpha/2 * ||Qg0-Qg1||_inf, from which we get side = 2y/q-1
		# (the diff values come from the last repeat)
		rlqsgd.set_side(alpha * rlqsgd_diff/(qlevel-1))
		lqsgd.set_side(alpha * lqsgd_diff/(qlevel-1))

		# progress report every tenth iteration
		if it % 10 == 0:
			loss = LA.norm(A@w - b)**2/n
			print("iteration = {}, loss = {}".format(it,loss))

		# take the step using the full gradient
		w -= lr*g.reshape(d,1)

	return result


def main():
	''' Parse the command line, run the experiment once per seed, average the variance curves and plot them.

	The figure is written to out/ as a pdf when --save is given, otherwise it is shown on screen.
	'''
	parser = argparse.ArgumentParser(description='Superlinear Variance')
	parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
	parser.add_argument('--qlevel', type=int, default=8, metavar='S', help='qlevel (default: 8)')
	parser.add_argument('--iterations', type=int, default=40, metavar='ITER', help='iterations (default: 40)')
	parser.add_argument('--nseeds', type=int, default=1, metavar='NSEEDS', help='number of seeds (default: 1)')
	parser.add_argument('--alpha', type=float, default=3, metavar='ALPHA', help='alpha (default: 3)')
	parser.add_argument('--n', type=int, default=8192, metavar='n', help='number of data points (default: 8192)')
	parser.add_argument('--d', type=int, default=100, metavar='d', help='dimension (default: 100)')
	parser.add_argument('--save', action='store_true')
	parser.add_argument('--log-scale', action='store_true')
	parser.add_argument('--l2', action='store_true')
	parser.add_argument('--mode', type=str, default="quantized", metavar='MODE', help='mode: central or qunatized')
	args = parser.parse_args()

	# without at least one seed the average below would divide by zero and plot only NaNs
	if args.nseeds < 1:
		parser.error('--nseeds must be at least 1')

	# parameters for experiment
	n = args.n # number of data points
	d = args.d # dimension
	N = 2 # number of worker nodes
	batch_size = n // 2 # batches you want to use for the compare
	iterations = args.iterations
	lr = args.lr

	# parameters for quantizations
	qlevel = args.qlevel
	alpha = args.alpha

	# repeating the experiment once per seed (seeds are 0, 10, 20, ...)
	fields = ["input_var_list","rlqsgd_var_list","lqsgd_var_list","qsgd_var_list","hadamard_var_list","qsgdl2_var_list"]
	seeds = [10*i for i in range(args.nseeds)]
	# every field gets its own zero vector, so the accumulators never share storage
	result_avg = {field: np.zeros(iterations) for field in fields}

	for seed in seeds:
		result = experiment(n,N,d,lr,qlevel,batch_size,alpha,iterations,seed,args.mode)
		for field in fields:
			result_avg[field] += np.array(result[field])

	for field in fields:
		result_avg[field] /= len(seeds)

	# if scale apply corresponding transformation
	if args.log_scale:
		for field in fields:
			result_avg[field] = np.log10(result_avg[field])

	''' Plots '''
	iteration = list(range(iterations))

	# (1) variance wrt full gradient for the quantization methods, last 50 iterations at most
	start_iter = max(0,iterations - 50)
	curves = [
		("input_var_list", 'Input variance'),
		("rlqsgd_var_list", 'RLQSGD (cubic)'),
		("lqsgd_var_list", 'LQSGD (cubic)'),
		("qsgd_var_list", 'QSGD, qlevel = {}'.format(qlevel)),
		("hadamard_var_list", 'Hadamard, qlevel = {}'.format(qlevel)),
	]
	if args.l2:
		curves.append(("qsgdl2_var_list", 'QSGD-L2, qlevel = {}'.format(qlevel)))
	for field, label in curves:
		plt.plot(iteration[start_iter:],result_avg[field][start_iter:],label=label)
	plt.xlabel('iteration')
	# backslashes are escaped explicitly; a bare backslash-l is an invalid escape sequence
	plt.ylabel('$\\log_{10}||EST - \\nabla||_2^2$' if args.log_scale else '$||EST - \\nabla||_2^2$')
	# NOTE(review): the title shows the data point count as "S" and the worker count as "n" -- confirm the labels
	plt.title('Regression Variance\nS = {}, n = {}, d = {}, batch_size = {}'.format(n,N,d,batch_size), fontsize=15)
	plt.legend()

	if args.save:
		os.makedirs('out', exist_ok=True)
		suffix = '_log' if args.log_scale else ''
		plt.savefig('out/variance_S_{}_d_{}{}.pdf'.format(args.n,args.d,suffix))
	else:
		plt.show()
	plt.close()

# run the experiment only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
	main()
