''' This experiment measures how well a machine can communicate an unbiased estimate of g_i using a sublinear number of bits.

We will just have 2 machines M0 and M1. We will simulate and compare the following: 
(a) cubic lattices 
(b) cross polytope scheme 

We will measure the error in the quantized gradient that M1 receives from M0, i.e. ||Q(g0) - g0||^2
'''

import argparse
import os
import numpy as np
import numpy.linalg as LA
import matplotlib.pyplot as plt
from quantization import LQSGD, vQSGD 


def variance_calc(sample_vectors, expected_vector):
    """Return the mean squared distance between the samples and a reference.

    For every entry of ``sample_vectors`` the difference to
    ``expected_vector`` is measured with ``np.linalg.norm`` (default order),
    squared, and the results are averaged over ``len(sample_vectors)``.
    """
    squared_errors = (
        np.linalg.norm(expected_vector - sample)**2 for sample in sample_vectors
    )
    return sum(squared_errors)/len(sample_vectors)

def grad(A,w,b):
	"""Return 2*(A^T A w - A^T b)/n, where n is the number of rows of A.

	This is the expression the experiment uses as the gradient of the
	least-squares objective for the rows in ``A`` and targets in ``b``.
	"""
	num_rows = A.shape[0]
	A_t = A.T
	# keep the (A^T A) w association so results match bit-for-bit
	difference = A_t @ A @ w - A_t @ b
	return 2*difference/num_rows

def experiment(n,N,d,lr,batch_size,alpha,B,repetition,iterations,seed):
	"""Run one seeded least-squares descent and record quantization variances.

	A random n x d design matrix ``A`` and a random ``w_star`` are drawn, and
	the targets are ``b = A @ w_star``. Starting from ``w = 0`` the weights are
	updated with the full gradient for ``iterations`` steps. At every step the
	data is shuffled and split into two disjoint batches whose gradients are
	``g0`` (first ``batch_size`` rows) and ``g1`` (remaining rows).

	Parameters:
		n: number of data points (rows of ``A``).
		N: number of workers; only echoed in the printed header.
		d: dimension of the weight vector.
		lr: step size of the full-gradient update.
		batch_size: number of shuffled rows used for ``g0``.
		alpha: multiplier applied to ``||g0-g1||_inf`` to obtain ``y``.
		B: bit budget used in the cubic-lattice side length.
		repetition: forwarded to ``vQSGD``.
		iterations: number of descent steps.
		seed: value passed to ``np.random.seed``.

	Returns:
		dict with two lists of length ``iterations``:
		``"lqsgd_cubic_var_list"`` holds the closed-form value
		``d*side**2/12`` with ``side = (4/(2**(B/d) - 1))*y``, and
		``"vqsgd_var_list"`` holds the empirical mean of
		``||vqsgd.compress(g0) - g0||^2`` over 20 compressions.

	Raises:
		AssertionError: if ``||g0-g1||_inf`` exceeded ``y`` in any iteration.
	"""
	# seed
	np.random.seed(seed)

	# print to stdout
	print("-----")
	print("1. data points (S) = {}, d = {}, workers (n) = {}, batch_size = {}".format(n,d,N,batch_size))
	print("2. update rule for y ~ {}*||g0-g1||".format(alpha))
	print("3. bits = {}".format(B))
	print("-----")

	# is normalization required, I think they will already be
	A = np.random.randn(n,d).astype(np.float64)
	w_star = np.random.randn(d,1).astype(np.float64)
	b = A @ w_star

	# quantizers for variance calculations
	vqsgd = vQSGD(repetition = repetition, dimension = d)

	# arrays for plotting
	result = {
		"lqsgd_cubic_var_list": [],
		"vqsgd_var_list": []
	}

	# parameters used in experiment
	y = None # parameter for the cubic lattice, would be related to infinity norm
	wrong = 0 # number of iterations in which ||g0-g1||_inf exceeded y

	# number of vQSGD compressions averaged per iteration (loop-invariant)
	times = 20

	# initialize weights
	w = np.zeros((d,1))

	for it in range(iterations):
		# to shuffle and get the batches
		indices = np.arange(n)
		np.random.shuffle(indices)
		first_batch = indices[0:batch_size]
		second_batch = indices[batch_size:]

		# full gradient plus the gradients of the two disjoint batches (flattened)
		g = grad(A,w,b)
		g0 = grad(A[first_batch],w,b[first_batch]).flatten()
		g1 = grad(A[second_batch],w,b[second_batch]).flatten()

		# infinity-norm gap between the batch gradients; computed once and
		# reused for initialising y, the bound check and the periodic refresh
		gap = LA.norm(g0-g1,np.inf)

		# if 0th iteration, set side
		if it == 0:
			y = alpha*gap

		# checking if ||g0-g1||_inf < y
		if gap > y:
			wrong += 1

		# calculating variances in the methods
		vqsgd_var = variance_calc([vqsgd.compress(g0) for _ in range(times)], g0)
		side = (4/(2**(B/d) - 1))*y
		lqsgd_cubic_var = d*side**2/12

		result["lqsgd_cubic_var_list"].append(lqsgd_cubic_var)
		result["vqsgd_var_list"].append(vqsgd_var)

		# setting parameters once in 5 iterations
		if it % 5 == 0:
			y = alpha*gap

		# print stats
		if it % 50 == 1:
			loss = LA.norm(A@w - b)**2/n
			print("iteration = {}, loss = {}, errors = {}".format(it,loss,wrong))

		# descend along full gradient
		w -= lr*g

	# explicit raise instead of `assert` so the check survives `python -O`
	if wrong != 0:
		raise AssertionError("||g0-g1||_inf exceeded y in {} of {} iterations".format(wrong,iterations))
	return result


def main():
	"""Parse the command line, run the experiment per seed, and plot the averages.

	The per-iteration variance lists returned by ``experiment`` are averaged
	over the seeds ``0, 10, 20, ...`` (``--nseeds`` of them) and plotted for
	the last 250 iterations at most. With ``--save`` the figure is written to
	the ``out`` directory as a PDF; otherwise it is shown interactively.

	Invalid arguments (fewer than one seed, fewer than two data points, or a
	dimension too small to afford a single vQSGD repetition) terminate the
	program through ``parser.error``.
	"""
	parser = argparse.ArgumentParser(description='Sublinear Variance')
	parser.add_argument('--lr', type=float, default=0.01, metavar='LR', help='learning rate (default: 0.01)')
	parser.add_argument('--iterations', type=int, default=60, metavar='ITER', help='iterations (default: 60)')
	parser.add_argument('--nseeds', type=int, default=1, metavar='NSEEDS', help='number of seeds (default: 1)')
	parser.add_argument('--n', type=int, default=2**13, metavar='n', help='number of data points (default: 2^13 = 8192)')
	parser.add_argument('--d', type=int, default=2**8, metavar='d', help='dimension (default: 2^8 = 256)')
	parser.add_argument('--alpha', type=float, default=1.6, metavar='ALPHA', help='multiplier on ||g0-g1||_inf used to set y (default: 1.6)')
	parser.add_argument('--save', action='store_true', help='save the plot under out/ instead of showing it')
	args = parser.parse_args()

	# reject inputs that would otherwise produce NaN plots or a division by zero
	if args.nseeds < 1:
		parser.error('--nseeds must be at least 1')
	if args.n < 2:
		parser.error('--n must be at least 2 so that both batches are non-empty')
	if args.d < 1:
		parser.error('--d must be at least 1')

	# parameters for experiment
	n = args.n # number of data points
	d = args.d # dimension
	N = 2 # number of worker nodes
	batch_size = n // 2 # batches you want to use for the compare
	B = d // 2 # target bits
	iterations = args.iterations
	lr = args.lr

	# parameters for quantizations (from notes)
	lg2d = int(np.ceil(np.log2(2*d)))
	print(lg2d)
	repetition = B // lg2d
	if repetition < 1:
		# B would become 0, which makes the cubic-lattice side length divide by zero
		parser.error('--d = {} is too small: the bit budget {} does not cover one repetition of {} bits'.format(d,B,lg2d))
	# round the bit budget down to a whole number of repetitions
	B = repetition * lg2d
	alpha = args.alpha

	# repeating the experiment for each seed and accumulating the results
	fields = ["lqsgd_cubic_var_list","vqsgd_var_list"]
	seeds = [10*i for i in range(args.nseeds)]
	result_avg = {field: np.zeros(iterations) for field in fields}

	for seed in seeds:
		result = experiment(n,N,d,lr,batch_size,alpha,B,repetition,iterations,seed)
		for field in fields:
			result_avg[field] += np.array(result[field])

	for field in fields:
		result_avg[field] /= len(seeds)

	# Plots
	iteration = list(range(iterations))
	# (1) variance wrt full gradient for the two quantization methods
	start_iter = max(0,iterations - 250)
	plt.plot(iteration[start_iter:],result_avg["lqsgd_cubic_var_list"][start_iter:],label='LQSGD (cubic)')
	plt.plot(iteration[start_iter:],result_avg["vqsgd_var_list"][start_iter:],label='vQSGD, repetition = {}'.format(repetition))
	plt.xlabel('iteration')
	plt.ylabel('$||Q(g_0) - g_0||^2$')
	plt.title('Variance of vqsgd vs lqsgd at bits <= {}, S = {}\nn = {}, d = {}, batch_size = {}'.format(B,n,N,d,batch_size),fontsize=15)
	plt.legend()
	if args.save:
		# exist_ok avoids the check-then-create race of isdir + makedirs
		os.makedirs('out', exist_ok=True)
		plt.savefig('out/variance_alpha_{}_S_{}_d_{}_repetition_{}.pdf'.format(args.alpha, args.n, args.d, repetition))
	else:
		plt.show()
	plt.close()


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == '__main__':
	main()
